diff --git a/proxy/tcpproxy/userspace.go b/proxy/tcpproxy/userspace.go new file mode 100644 index 000000000..33c920701 --- /dev/null +++ b/proxy/tcpproxy/userspace.go @@ -0,0 +1,149 @@ +// Copyright 2016 CoreOS, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpproxy + +import ( + "io" + "net" + "sync" + "time" +) + +type tcpProxy struct { + l net.Listener + monitorInterval time.Duration + donec chan struct{} + + mu sync.Mutex // guards the following fields + remotes []*remote + nextRemote int +} + +type remote struct { + mu sync.Mutex + addr string + inactive bool +} + +func (r *remote) inactivate() { + r.mu.Lock() + defer r.mu.Unlock() + r.inactive = true +} + +func (r *remote) tryReactivate() { + conn, err := net.Dial("tcp", r.addr) + if err != nil { + return + } + conn.Close() + r.mu.Lock() + defer r.mu.Unlock() + r.inactive = false + return +} + +func (r *remote) isActive() bool { + r.mu.Lock() + defer r.mu.Unlock() + return !r.inactive +} + +func (tp *tcpProxy) run() error { + go tp.runMonitor() + for { + in, err := tp.l.Accept() + if err != nil { + return err + } + + go tp.serve(in) + } +} + +func (tp *tcpProxy) numRemotes() int { + tp.mu.Lock() + defer tp.mu.Unlock() + return len(tp.remotes) +} + +func (tp *tcpProxy) serve(in net.Conn) { + var ( + err error + out net.Conn + ) + + for i := 0; i < tp.numRemotes(); i++ { + remote := tp.pick() + if !remote.isActive() { + continue + } + // TODO: add timeout + out, err = net.Dial("tcp", remote.addr) + if err == nil { + break + } + remote.inactivate() + } + + if out == nil { + in.Close() + return + } + + go func() { + io.Copy(in, out) + in.Close() + out.Close() + }() + + io.Copy(out, in) + out.Close() + in.Close() +} + +// pick picks a remote in round-robin fashion +func (tp *tcpProxy) pick() *remote { + tp.mu.Lock() + defer tp.mu.Unlock() + + picked := tp.remotes[tp.nextRemote] + tp.nextRemote = (tp.nextRemote + 1) % len(tp.remotes) + return picked +} + +func (tp *tcpProxy) runMonitor() { + for { + select { + case <-time.After(tp.monitorInterval): + tp.mu.Lock() + for _, r := range tp.remotes { + if !r.isActive() { + go r.tryReactivate() + } + } + tp.mu.Unlock() + case <-tp.donec: + return + } + } +} + +func (tp *tcpProxy) stop() { + // graceful shutdown? + // shutdown current connections? + tp.l.Close() + close(tp.donec) +} diff --git a/proxy/tcpproxy/userspace_test.go b/proxy/tcpproxy/userspace_test.go new file mode 100644 index 000000000..6d38d076c --- /dev/null +++ b/proxy/tcpproxy/userspace_test.go @@ -0,0 +1,73 @@ +// Copyright 2016 CoreOS, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpproxy + +import ( + "fmt" + "io/ioutil" + "net" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" +) + +func TestUserspaceProxy(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer l.Close() + + want := "hello proxy" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, want) + })) + defer ts.Close() + + u, err := url.Parse(ts.URL) + if err != nil { + t.Fatal(err) + } + + p := tcpProxy{ + l: l, + donec: make(chan struct{}), + monitorInterval: time.Second, + + remotes: []*remote{ + {addr: u.Host}, + }, + } + go p.run() + defer p.stop() + + u.Host = l.Addr().String() + + res, err := http.Get(u.String()) + if err != nil { + t.Fatal(err) + } + got, gerr := ioutil.ReadAll(res.Body) + res.Body.Close() + if gerr != nil { + t.Fatal(gerr) + } + + if string(got) != want { + t.Errorf("got = %s, want %s", got, want) + } +} diff --git a/test b/test index 8f0c2da90..1205065a0 100755 --- a/test +++ b/test @@ -28,7 +28,7 @@ ln -s ${PWD}/cmd/vendor $GOPATH/src # Hack: gofmt ./ will recursively check the .git directory. So use *.go for gofmt. PKGS=`ls pkg/*/*go | cut -f1,2 -d/ | sort | uniq` -TESTABLE_AND_FORMATTABLE="client clientv3 discovery error etcdctl/ctlv2 etcdctl/ctlv3 etcdmain etcdserver etcdserver/auth etcdserver/api/v2http etcdserver/api/v2http/httptypes $PKGS proxy/httpproxy raft snap storage storage/backend store version wal" +TESTABLE_AND_FORMATTABLE="client clientv3 discovery error etcdctl/ctlv2 etcdctl/ctlv3 etcdmain etcdserver etcdserver/auth etcdserver/api/v2http etcdserver/api/v2http/httptypes $PKGS proxy/httpproxy proxy/tcpproxy raft snap storage storage/backend store version wal" # TODO: add it to race testing when the issue is resolved # https://github.com/golang/go/issues/9946 NO_RACE_TESTABLE="rafthttp"