From ab486e534897f6789dfac3d1cd9e78e7178fcc24 Mon Sep 17 00:00:00 2001 From: Gyuho Lee Date: Tue, 2 Jan 2018 08:46:32 -0800 Subject: [PATCH] pkg/transport: implement "Proxy" Signed-off-by: Gyuho Lee --- pkg/transport/proxy.go | 801 ++++++++++++++++++++++++++++++++++++ pkg/transport/proxy_test.go | 606 +++++++++++++++++++++++++++ 2 files changed, 1407 insertions(+) create mode 100644 pkg/transport/proxy.go create mode 100644 pkg/transport/proxy_test.go diff --git a/pkg/transport/proxy.go b/pkg/transport/proxy.go new file mode 100644 index 000000000..8af76d46b --- /dev/null +++ b/pkg/transport/proxy.go @@ -0,0 +1,801 @@ +// Copyright 2018 The etcd Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package transport + +import ( + "fmt" + "io" + mrand "math/rand" + "net" + "net/http" + "net/url" + "os" + "strings" + "sync" + "time" + + humanize "github.com/dustin/go-humanize" + "google.golang.org/grpc/grpclog" +) + +// Proxy defines proxy layer that simulates common network faults, +// such as latency spikes, packet drop/corruption, etc.. +type Proxy interface { + // From returns proxy source address in "scheme://host:port" format. + From() string + // To returns proxy destination address in "scheme://host:port" format. + To() string + + // Ready returns when proxy is ready to serve. + Ready() <-chan struct{} + // Done returns when proxy has been closed. + Done() <-chan struct{} + // Error sends errors while serving proxy. + Error() <-chan error + // Close closes listener and transport. + Close() error + + // DelayAccept adds latency ± random variable to accepting new incoming connections. + DelayAccept(latency, rv time.Duration) + // UndelayAccept removes sending latencies. + UndelayAccept() + // LatencyAccept returns current latency on accepting new incoming connections. + LatencyAccept() time.Duration + // DelayTx adds latency ± random variable to "sending" layer. + DelayTx(latency, rv time.Duration) + // UndelayTx removes sending latencies. + UndelayTx() + // LatencyTx returns current send latency. + LatencyTx() time.Duration + // DelayRx adds latency ± random variable to "receiving" layer. + DelayRx(latency, rv time.Duration) + // UndelayRx removes "receiving" latencies. + UndelayRx() + // LatencyRx returns current receive latency. + LatencyRx() time.Duration + + // PauseAccept stops accepting new connections. + PauseAccept() + // UnpauseAccept removes pause operation on accepting new connections. + UnpauseAccept() + // PauseTx stops "forwarding" packets. + PauseTx() + // UnpauseTx removes "forwarding" pause operation. + UnpauseTx() + // PauseRx stops "receiving" packets to client. + PauseRx() + // UnpauseRx removes "receiving" pause operation. + UnpauseRx() + + // BlackholeTx drops all incoming packets before "forwarding". + BlackholeTx() + // UnblackholeTx removes blackhole operation on "sending". + UnblackholeTx() + // BlackholeRx drops all incoming packets to client. + BlackholeRx() + // UnblackholeRx removes blackhole operation on "receiving". + UnblackholeRx() + + // CorruptTx corrupts incoming packets from the listener. + CorruptTx(f func(data []byte) []byte) + // UncorruptTx removes corrupt operation on "forwarding". + UncorruptTx() + // CorruptRx corrupts incoming packets to client. + CorruptRx(f func(data []byte) []byte) + // UncorruptRx removes corrupt operation on "receiving". + UncorruptRx() + + // ResetListener closes and restarts listener. + ResetListener() error +} + +type proxy struct { + from, to url.URL + tlsInfo TLSInfo + dialTimeout time.Duration + bufferSize int + retryInterval time.Duration + logger grpclog.LoggerV2 + + readyc chan struct{} + donec chan struct{} + errc chan error + + closeOnce sync.Once + closeWg sync.WaitGroup + + listenerMu sync.RWMutex + listener net.Listener + + latencyAcceptMu sync.RWMutex + latencyAccept time.Duration + latencyTxMu sync.RWMutex + latencyTx time.Duration + latencyRxMu sync.RWMutex + latencyRx time.Duration + + corruptTxMu sync.RWMutex + corruptTx func(data []byte) []byte + corruptRxMu sync.RWMutex + corruptRx func(data []byte) []byte + + acceptMu sync.Mutex + pauseAcceptc chan struct{} + txMu sync.Mutex + pauseTxc chan struct{} + blackholeTxc chan struct{} + rxMu sync.Mutex + pauseRxc chan struct{} + blackholeRxc chan struct{} +} + +// ProxyConfig defines proxy configuration. +type ProxyConfig struct { + From url.URL + To url.URL + TLSInfo TLSInfo + DialTimeout time.Duration + BufferSize int + RetryInterval time.Duration + Logger grpclog.LoggerV2 +} + +var ( + defaultDialTimeout = 3 * time.Second + defaultBufferSize = 48 * 1024 + defaultRetryInterval = 10 * time.Millisecond + defaultLogger = grpclog.NewLoggerV2WithVerbosity(os.Stderr, os.Stderr, os.Stderr, 0) +) + +// NewProxy returns a proxy implementation with no iptables/tc dependencies. +// The proxy layer overhead is <1ms. +func NewProxy(cfg ProxyConfig) Proxy { + p := &proxy{ + from: cfg.From, + to: cfg.To, + tlsInfo: cfg.TLSInfo, + dialTimeout: cfg.DialTimeout, + bufferSize: cfg.BufferSize, + retryInterval: cfg.RetryInterval, + logger: cfg.Logger, + + readyc: make(chan struct{}), + donec: make(chan struct{}), + errc: make(chan error, 16), + + pauseAcceptc: make(chan struct{}), + pauseTxc: make(chan struct{}), + blackholeTxc: make(chan struct{}), + pauseRxc: make(chan struct{}), + blackholeRxc: make(chan struct{}), + } + if p.dialTimeout == 0 { + p.dialTimeout = defaultDialTimeout + } + if p.bufferSize == 0 { + p.bufferSize = defaultBufferSize + } + if p.retryInterval == 0 { + p.retryInterval = defaultRetryInterval + } + if p.logger == nil { + p.logger = defaultLogger + } + close(p.pauseAcceptc) + close(p.pauseTxc) + close(p.pauseRxc) + + if strings.HasPrefix(p.from.Scheme, "http") { + p.from.Scheme = "tcp" + } + if strings.HasPrefix(p.to.Scheme, "http") { + p.to.Scheme = "tcp" + } + + var ln net.Listener + var err error + if !p.tlsInfo.Empty() { + ln, err = NewListener(p.from.Host, p.from.Scheme, &p.tlsInfo) + } else { + ln, err = net.Listen(p.from.Scheme, p.from.Host) + } + if err != nil { + p.errc <- err + p.Close() + return p + } + p.listener = ln + + p.closeWg.Add(1) + go p.listenAndServe() + p.logger.Infof("started proxying [%s -> %s]", p.From(), p.To()) + return p +} + +func (p *proxy) From() string { + return fmt.Sprintf("%s://%s", p.from.Scheme, p.from.Host) +} + +func (p *proxy) To() string { + return fmt.Sprintf("%s://%s", p.to.Scheme, p.to.Host) +} + +// TODO: implement packet reordering from multiple TCP connections +// buffer packets per connection for awhile, reorder before transmit +// - https://github.com/coreos/etcd/issues/5614 +// - https://github.com/coreos/etcd/pull/6918#issuecomment-264093034 + +func (p *proxy) listenAndServe() { + defer p.closeWg.Done() + + p.logger.Infof("listen %q", p.From()) + close(p.readyc) + + for { + p.acceptMu.Lock() + pausec := p.pauseAcceptc + p.acceptMu.Unlock() + select { + case <-pausec: + case <-p.donec: + return + } + + p.latencyAcceptMu.RLock() + lat := p.latencyAccept + p.latencyAcceptMu.RUnlock() + if lat > 0 { + select { + case <-time.After(lat): + case <-p.donec: + return + } + } + + p.listenerMu.RLock() + ln := p.listener + p.listenerMu.RUnlock() + + in, err := ln.Accept() + if err != nil { + select { + case p.errc <- err: + select { + case <-p.donec: + return + default: + } + case <-p.donec: + return + } + if p.logger.V(5) { + p.logger.Errorf("listener accept error %q", err.Error()) + } + + if strings.HasSuffix(err.Error(), "use of closed network connection") { + select { + case <-time.After(p.retryInterval): + case <-p.donec: + return + } + if p.logger.V(5) { + p.logger.Errorf("listener is closed; retry listen %q", p.From()) + } + + if err = p.ResetListener(); err != nil { + select { + case p.errc <- err: + select { + case <-p.donec: + return + default: + } + case <-p.donec: + return + } + p.logger.Errorf("failed to reset listener %q", err.Error()) + } + } + + continue + } + + var out net.Conn + if !p.tlsInfo.Empty() { + var tp *http.Transport + tp, err = NewTransport(p.tlsInfo, p.dialTimeout) + if err != nil { + select { + case p.errc <- err: + select { + case <-p.donec: + return + default: + } + case <-p.donec: + return + } + continue + } + out, err = tp.Dial(p.to.Scheme, p.to.Host) + } else { + out, err = net.Dial(p.to.Scheme, p.to.Host) + } + if err != nil { + select { + case p.errc <- err: + select { + case <-p.donec: + return + default: + } + case <-p.donec: + return + } + if p.logger.V(5) { + p.logger.Errorf("dial error %q", err.Error()) + } + continue + } + + go func() { + // read incoming bytes from listener, dispatch to outgoing connection + p.transmit(out, in) + out.Close() + in.Close() + }() + go func() { + // read response from outgoing connection, write back to listener + p.receive(in, out) + in.Close() + out.Close() + }() + } +} + +func (p *proxy) transmit(dst io.Writer, src io.Reader) { p.ioCopy(dst, src, true) } +func (p *proxy) receive(dst io.Writer, src io.Reader) { p.ioCopy(dst, src, false) } +func (p *proxy) ioCopy(dst io.Writer, src io.Reader, proxySend bool) { + buf := make([]byte, p.bufferSize) + for { + nr, err := src.Read(buf) + if err != nil { + if err == io.EOF { + return + } + // connection already closed + if strings.HasSuffix(err.Error(), "read: connection reset by peer") { + return + } + if strings.HasSuffix(err.Error(), "use of closed network connection") { + return + } + select { + case p.errc <- err: + select { + case <-p.donec: + return + default: + } + case <-p.donec: + return + } + if p.logger.V(5) { + p.logger.Errorf("read error %q", err.Error()) + } + return + } + if nr == 0 { + return + } + data := buf[:nr] + + var pausec chan struct{} + var blackholec chan struct{} + if proxySend { + p.txMu.Lock() + pausec = p.pauseTxc + blackholec = p.blackholeTxc + p.txMu.Unlock() + } else { + p.rxMu.Lock() + pausec = p.pauseRxc + blackholec = p.blackholeRxc + p.rxMu.Unlock() + } + select { + case <-pausec: + case <-p.donec: + return + } + blackholed := false + select { + case <-blackholec: + blackholed = true + case <-p.donec: + return + default: + } + if blackholed { + if p.logger.V(5) { + if proxySend { + p.logger.Infof("dropped %s [%s -> %s]", humanize.Bytes(uint64(nr)), p.From(), p.To()) + } else { + p.logger.Infof("dropped %s [%s <- %s]", humanize.Bytes(uint64(nr)), p.From(), p.To()) + } + } + continue + } + + var lat time.Duration + if proxySend { + p.latencyTxMu.RLock() + lat = p.latencyTx + p.latencyTxMu.RUnlock() + } else { + p.latencyRxMu.RLock() + lat = p.latencyRx + p.latencyRxMu.RUnlock() + } + if lat > 0 { + select { + case <-time.After(lat): + case <-p.donec: + return + } + } + + if proxySend { + p.corruptTxMu.RLock() + if p.corruptTx != nil { + data = p.corruptTx(data) + } + p.corruptTxMu.RUnlock() + } else { + p.corruptRxMu.RLock() + if p.corruptRx != nil { + data = p.corruptRx(data) + } + p.corruptRxMu.RUnlock() + } + + var nw int + nw, err = dst.Write(data) + if err != nil { + if err == io.EOF { + return + } + select { + case p.errc <- err: + select { + case <-p.donec: + return + default: + } + case <-p.donec: + return + } + if p.logger.V(5) { + if proxySend { + p.logger.Errorf("write error while sending (%q)", err.Error()) + } else { + p.logger.Errorf("write error while receiving (%q)", err.Error()) + } + } + return + } + + if nr != nw { + select { + case p.errc <- io.ErrShortWrite: + select { + case <-p.donec: + return + default: + } + case <-p.donec: + return + } + if proxySend { + p.logger.Errorf("write error while sending (%q); read %d bytes != wrote %d bytes", io.ErrShortWrite.Error(), nr, nw) + } else { + p.logger.Errorf("write error while receiving (%q); read %d bytes != wrote %d bytes", io.ErrShortWrite.Error(), nr, nw) + } + return + } + + if p.logger.V(5) { + if proxySend { + p.logger.Infof("transmitted %s [%s -> %s]", humanize.Bytes(uint64(nr)), p.From(), p.To()) + } else { + p.logger.Infof("received %s [%s <- %s]", humanize.Bytes(uint64(nr)), p.From(), p.To()) + } + } + } +} + +func (p *proxy) Ready() <-chan struct{} { return p.readyc } +func (p *proxy) Done() <-chan struct{} { return p.donec } +func (p *proxy) Error() <-chan error { return p.errc } +func (p *proxy) Close() (err error) { + p.closeOnce.Do(func() { + close(p.donec) + p.listenerMu.Lock() + if p.listener != nil { + err = p.listener.Close() + p.logger.Infof("closed proxy listener on %q", p.From()) + } + p.listenerMu.Unlock() + }) + p.closeWg.Wait() + return err +} + +func (p *proxy) DelayAccept(latency, rv time.Duration) { + if latency <= 0 { + return + } + d := computeLatency(latency, rv) + p.latencyAcceptMu.Lock() + p.latencyAccept = d + p.latencyAcceptMu.Unlock() + p.logger.Infof("set accept latency %v(%v±%v) [%s -> %s]", d, latency, rv, p.From(), p.To()) +} + +func (p *proxy) UndelayAccept() { + p.latencyAcceptMu.Lock() + d := p.latencyAccept + p.latencyAccept = 0 + p.latencyAcceptMu.Unlock() + p.logger.Infof("removed accept latency %v [%s -> %s]", d, p.From(), p.To()) +} + +func (p *proxy) LatencyAccept() time.Duration { + p.latencyAcceptMu.RLock() + d := p.latencyAccept + p.latencyAcceptMu.RUnlock() + return d +} + +func (p *proxy) DelayTx(latency, rv time.Duration) { + if latency <= 0 { + return + } + d := computeLatency(latency, rv) + p.latencyTxMu.Lock() + p.latencyTx = d + p.latencyTxMu.Unlock() + p.logger.Infof("set transmit latency %v(%v±%v) [%s -> %s]", d, latency, rv, p.From(), p.To()) +} + +func (p *proxy) UndelayTx() { + p.latencyTxMu.Lock() + d := p.latencyTx + p.latencyTx = 0 + p.latencyTxMu.Unlock() + p.logger.Infof("removed transmit latency %v [%s -> %s]", d, p.From(), p.To()) +} + +func (p *proxy) LatencyTx() time.Duration { + p.latencyTxMu.RLock() + d := p.latencyTx + p.latencyTxMu.RUnlock() + return d +} + +func (p *proxy) DelayRx(latency, rv time.Duration) { + if latency <= 0 { + return + } + d := computeLatency(latency, rv) + p.latencyRxMu.Lock() + p.latencyRx = d + p.latencyRxMu.Unlock() + p.logger.Infof("set receive latency %v(%v±%v) [%s <- %s]", d, latency, rv, p.From(), p.To()) +} + +func (p *proxy) UndelayRx() { + p.latencyRxMu.Lock() + d := p.latencyRx + p.latencyRx = 0 + p.latencyRxMu.Unlock() + p.logger.Infof("removed receive latency %v [%s <- %s]", d, p.From(), p.To()) +} + +func (p *proxy) LatencyRx() time.Duration { + p.latencyRxMu.RLock() + d := p.latencyRx + p.latencyRxMu.RUnlock() + return d +} + +func computeLatency(lat, rv time.Duration) time.Duration { + if rv == 0 { + return lat + } + if rv < 0 { + rv *= -1 + } + if rv > lat { + rv = lat / 10 + } + now := time.Now() + mrand.Seed(int64(now.Nanosecond())) + sign := 1 + if now.Second()%2 == 0 { + sign = -1 + } + return lat + time.Duration(int64(sign)*mrand.Int63n(rv.Nanoseconds())) +} + +func (p *proxy) PauseAccept() { + p.acceptMu.Lock() + p.pauseAcceptc = make(chan struct{}) + p.acceptMu.Unlock() + p.logger.Infof("paused accepting new connections [%s -> %s]", p.From(), p.To()) +} + +func (p *proxy) UnpauseAccept() { + p.acceptMu.Lock() + select { + case <-p.pauseAcceptc: // already unpaused + case <-p.donec: + p.acceptMu.Unlock() + return + default: + close(p.pauseAcceptc) + } + p.acceptMu.Unlock() + p.logger.Infof("unpaused accepting new connections [%s -> %s]", p.From(), p.To()) +} + +func (p *proxy) PauseTx() { + p.txMu.Lock() + p.pauseTxc = make(chan struct{}) + p.txMu.Unlock() + p.logger.Infof("paused transmit listen [%s -> %s]", p.From(), p.To()) +} + +func (p *proxy) UnpauseTx() { + p.txMu.Lock() + select { + case <-p.pauseTxc: // already unpaused + case <-p.donec: + p.txMu.Unlock() + return + default: + close(p.pauseTxc) + } + p.txMu.Unlock() + p.logger.Infof("unpaused transmit listen [%s -> %s]", p.From(), p.To()) +} + +func (p *proxy) PauseRx() { + p.rxMu.Lock() + p.pauseRxc = make(chan struct{}) + p.rxMu.Unlock() + p.logger.Infof("paused receive listen [%s <- %s]", p.From(), p.To()) +} + +func (p *proxy) UnpauseRx() { + p.rxMu.Lock() + select { + case <-p.pauseRxc: // already unpaused + case <-p.donec: + p.rxMu.Unlock() + return + default: + close(p.pauseRxc) + } + p.rxMu.Unlock() + p.logger.Infof("unpaused receive listen [%s <- %s]", p.From(), p.To()) +} + +func (p *proxy) BlackholeTx() { + p.txMu.Lock() + select { + case <-p.blackholeTxc: // already blackholed + case <-p.donec: + p.txMu.Unlock() + return + default: + close(p.blackholeTxc) + } + p.txMu.Unlock() + p.logger.Infof("blackholed transmit [%s -> %s]", p.From(), p.To()) +} + +func (p *proxy) UnblackholeTx() { + p.txMu.Lock() + p.blackholeTxc = make(chan struct{}) + p.txMu.Unlock() + p.logger.Infof("unblackholed transmit [%s -> %s]", p.From(), p.To()) +} + +func (p *proxy) BlackholeRx() { + p.rxMu.Lock() + select { + case <-p.blackholeRxc: // already blackholed + case <-p.donec: + p.rxMu.Unlock() + return + default: + close(p.blackholeRxc) + } + p.rxMu.Unlock() + p.logger.Infof("blackholed receive [%s <- %s]", p.From(), p.To()) +} + +func (p *proxy) UnblackholeRx() { + p.rxMu.Lock() + p.blackholeRxc = make(chan struct{}) + p.rxMu.Unlock() + p.logger.Infof("unblackholed receive [%s <- %s]", p.From(), p.To()) +} + +func (p *proxy) CorruptTx(f func([]byte) []byte) { + p.corruptTxMu.Lock() + p.corruptTx = f + p.corruptTxMu.Unlock() + p.logger.Infof("corrupting transmit [%s -> %s]", p.From(), p.To()) +} + +func (p *proxy) UncorruptTx() { + p.corruptTxMu.Lock() + p.corruptTx = nil + p.corruptTxMu.Unlock() + p.logger.Infof("stopped corrupting transmit [%s -> %s]", p.From(), p.To()) +} + +func (p *proxy) CorruptRx(f func([]byte) []byte) { + p.corruptRxMu.Lock() + p.corruptRx = f + p.corruptRxMu.Unlock() + p.logger.Infof("corrupting receive [%s <- %s]", p.From(), p.To()) +} + +func (p *proxy) UncorruptRx() { + p.corruptRxMu.Lock() + p.corruptRx = nil + p.corruptRxMu.Unlock() + p.logger.Infof("stopped corrupting receive [%s <- %s]", p.From(), p.To()) +} + +func (p *proxy) ResetListener() error { + p.listenerMu.Lock() + defer p.listenerMu.Unlock() + + if err := p.listener.Close(); err != nil { + // already closed + if !strings.HasSuffix(err.Error(), "use of closed network connection") { + return err + } + } + + var ln net.Listener + var err error + if !p.tlsInfo.Empty() { + ln, err = NewListener(p.from.Host, p.from.Scheme, &p.tlsInfo) + } else { + ln, err = net.Listen(p.from.Scheme, p.from.Host) + } + if err != nil { + return err + } + p.listener = ln + + p.logger.Infof("reset listener %q", p.From()) + return nil +} diff --git a/pkg/transport/proxy_test.go b/pkg/transport/proxy_test.go new file mode 100644 index 000000000..58f9e253d --- /dev/null +++ b/pkg/transport/proxy_test.go @@ -0,0 +1,606 @@ +// Copyright 2018 The etcd Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package transport + +import ( + "bytes" + "crypto/tls" + "fmt" + "io/ioutil" + "math/rand" + "net" + "net/http" + "net/url" + "os" + "strings" + "testing" + "time" + + "google.golang.org/grpc/grpclog" +) + +var testTLSInfo = TLSInfo{ + KeyFile: "./fixtures/server.key.insecure", + CertFile: "./fixtures/server.crt", + TrustedCAFile: "./fixtures/ca.crt", + ClientCertAuth: true, +} + +func TestProxy_Unix_Insecure(t *testing.T) { testProxy(t, "unix", false, false) } +func TestProxy_TCP_Insecure(t *testing.T) { testProxy(t, "tcp", false, false) } +func TestProxy_Unix_Secure(t *testing.T) { testProxy(t, "unix", true, false) } +func TestProxy_TCP_Secure(t *testing.T) { testProxy(t, "tcp", true, false) } +func TestProxy_Unix_Insecure_DelayTx(t *testing.T) { testProxy(t, "unix", false, true) } +func TestProxy_TCP_Insecure_DelayTx(t *testing.T) { testProxy(t, "tcp", false, true) } +func TestProxy_Unix_Secure_DelayTx(t *testing.T) { testProxy(t, "unix", true, true) } +func TestProxy_TCP_Secure_DelayTx(t *testing.T) { testProxy(t, "tcp", true, true) } +func testProxy(t *testing.T, scheme string, secure bool, delayTx bool) { + srcAddr, dstAddr := newUnixAddr(), newUnixAddr() + if scheme == "tcp" { + ln1, ln2 := listen(t, "tcp", "localhost:0", TLSInfo{}), listen(t, "tcp", "localhost:0", TLSInfo{}) + srcAddr, dstAddr = ln1.Addr().String(), ln2.Addr().String() + ln1.Close() + ln2.Close() + } else { + defer func() { + os.RemoveAll(srcAddr) + os.RemoveAll(dstAddr) + }() + } + tlsInfo := testTLSInfo + if !secure { + tlsInfo = TLSInfo{} + } + ln := listen(t, scheme, dstAddr, tlsInfo) + defer ln.Close() + + cfg := ProxyConfig{ + From: url.URL{Scheme: scheme, Host: srcAddr}, + To: url.URL{Scheme: scheme, Host: dstAddr}, + Logger: grpclog.NewLoggerV2WithVerbosity(os.Stderr, os.Stderr, os.Stderr, 5), + } + if secure { + cfg.TLSInfo = testTLSInfo + } + p := NewProxy(cfg) + <-p.Ready() + defer p.Close() + + data1 := []byte("Hello World!") + donec, writec := make(chan struct{}), make(chan []byte) + + go func() { + defer close(donec) + for data := range writec { + send(t, data, scheme, srcAddr, tlsInfo) + } + }() + + recvc := make(chan []byte) + go func() { + for i := 0; i < 2; i++ { + recvc <- receive(t, ln) + } + }() + + writec <- data1 + now := time.Now() + if d := <-recvc; !bytes.Equal(data1, d) { + t.Fatalf("expected %q, got %q", string(data1), string(d)) + } + took1 := time.Since(now) + t.Logf("took %v with no latency", took1) + + lat, rv := 50*time.Millisecond, 5*time.Millisecond + if delayTx { + p.DelayTx(lat, rv) + } + + data2 := []byte("new data") + writec <- data2 + now = time.Now() + if d := <-recvc; !bytes.Equal(data2, d) { + t.Fatalf("expected %q, got %q", string(data2), string(d)) + } + took2 := time.Since(now) + if delayTx { + t.Logf("took %v with latency %v±%v", took2, lat, rv) + } else { + t.Logf("took %v with no latency", took2) + } + + if delayTx { + p.UndelayTx() + if took1 >= took2 { + t.Fatalf("expected took1 %v < took2 %v (with latency)", took1, took2) + } + } + + close(writec) + select { + case <-donec: + case <-time.After(3 * time.Second): + t.Fatal("took too long to write") + } + + select { + case <-p.Done(): + t.Fatal("unexpected done") + case err := <-p.Error(): + t.Fatal(err) + default: + } + + if err := p.Close(); err != nil { + t.Fatal(err) + } + + select { + case <-p.Done(): + case err := <-p.Error(): + if !strings.HasPrefix(err.Error(), "accept ") && + !strings.HasSuffix(err.Error(), "use of closed network connection") { + t.Fatal(err) + } + case <-time.After(3 * time.Second): + t.Fatal("took too long to close") + } +} + +func TestProxy_Unix_Insecure_DelayAccept(t *testing.T) { testProxyDelayAccept(t, false) } +func TestProxy_Unix_Secure_DelayAccept(t *testing.T) { testProxyDelayAccept(t, true) } +func testProxyDelayAccept(t *testing.T, secure bool) { + srcAddr, dstAddr := newUnixAddr(), newUnixAddr() + defer func() { + os.RemoveAll(srcAddr) + os.RemoveAll(dstAddr) + }() + tlsInfo := testTLSInfo + if !secure { + tlsInfo = TLSInfo{} + } + scheme := "unix" + ln := listen(t, scheme, dstAddr, tlsInfo) + defer ln.Close() + + cfg := ProxyConfig{ + From: url.URL{Scheme: scheme, Host: srcAddr}, + To: url.URL{Scheme: scheme, Host: dstAddr}, + Logger: grpclog.NewLoggerV2WithVerbosity(os.Stderr, os.Stderr, os.Stderr, 5), + } + if secure { + cfg.TLSInfo = testTLSInfo + } + p := NewProxy(cfg) + <-p.Ready() + defer p.Close() + + data := []byte("Hello World!") + + now := time.Now() + send(t, data, scheme, srcAddr, tlsInfo) + if d := receive(t, ln); !bytes.Equal(data, d) { + t.Fatalf("expected %q, got %q", string(data), string(d)) + } + took1 := time.Since(now) + t.Logf("took %v with no latency", took1) + + lat, rv := 700*time.Millisecond, 10*time.Millisecond + p.DelayAccept(lat, rv) + defer p.UndelayAccept() + if err := p.ResetListener(); err != nil { + t.Fatal(err) + } + time.Sleep(200 * time.Millisecond) + + now = time.Now() + send(t, data, scheme, srcAddr, tlsInfo) + if d := receive(t, ln); !bytes.Equal(data, d) { + t.Fatalf("expected %q, got %q", string(data), string(d)) + } + took2 := time.Since(now) + t.Logf("took %v with latency %v±%v", took2, lat, rv) + + if took1 >= took2 { + t.Fatalf("expected took1 %v < took2 %v", took1, took2) + } +} + +func TestProxy_PauseTx(t *testing.T) { + scheme := "unix" + srcAddr, dstAddr := newUnixAddr(), newUnixAddr() + defer func() { + os.RemoveAll(srcAddr) + os.RemoveAll(dstAddr) + }() + ln := listen(t, scheme, dstAddr, TLSInfo{}) + defer ln.Close() + + p := NewProxy(ProxyConfig{ + From: url.URL{Scheme: scheme, Host: srcAddr}, + To: url.URL{Scheme: scheme, Host: dstAddr}, + Logger: grpclog.NewLoggerV2WithVerbosity(os.Stderr, os.Stderr, os.Stderr, 5), + }) + <-p.Ready() + defer p.Close() + + p.PauseTx() + + data := []byte("Hello World!") + send(t, data, scheme, srcAddr, TLSInfo{}) + + recvc := make(chan []byte) + go func() { + recvc <- receive(t, ln) + }() + + select { + case d := <-recvc: + t.Fatalf("received unexpected data %q during pause", string(d)) + case <-time.After(200 * time.Millisecond): + } + + p.UnpauseTx() + + select { + case d := <-recvc: + if !bytes.Equal(data, d) { + t.Fatalf("expected %q, got %q", string(data), string(d)) + } + case <-time.After(2 * time.Second): + t.Fatal("took too long to receive after unpause") + } +} + +func TestProxy_BlackholeTx(t *testing.T) { + scheme := "unix" + srcAddr, dstAddr := newUnixAddr(), newUnixAddr() + defer func() { + os.RemoveAll(srcAddr) + os.RemoveAll(dstAddr) + }() + ln := listen(t, scheme, dstAddr, TLSInfo{}) + defer ln.Close() + + p := NewProxy(ProxyConfig{ + From: url.URL{Scheme: scheme, Host: srcAddr}, + To: url.URL{Scheme: scheme, Host: dstAddr}, + Logger: grpclog.NewLoggerV2WithVerbosity(os.Stderr, os.Stderr, os.Stderr, 5), + }) + <-p.Ready() + defer p.Close() + + p.BlackholeTx() + + data := []byte("Hello World!") + send(t, data, scheme, srcAddr, TLSInfo{}) + + recvc := make(chan []byte) + go func() { + recvc <- receive(t, ln) + }() + + select { + case d := <-recvc: + t.Fatalf("unexpected data receive %q during blackhole", string(d)) + case <-time.After(200 * time.Millisecond): + } + + p.UnblackholeTx() + + // expect different data, old data dropped + data[0]++ + send(t, data, scheme, srcAddr, TLSInfo{}) + + select { + case d := <-recvc: + if !bytes.Equal(data, d) { + t.Fatalf("expected %q, got %q", string(data), string(d)) + } + case <-time.After(2 * time.Second): + t.Fatal("took too long to receive after unblackhole") + } +} + +func TestProxy_CorruptTx(t *testing.T) { + scheme := "unix" + srcAddr, dstAddr := newUnixAddr(), newUnixAddr() + defer func() { + os.RemoveAll(srcAddr) + os.RemoveAll(dstAddr) + }() + ln := listen(t, scheme, dstAddr, TLSInfo{}) + defer ln.Close() + + p := NewProxy(ProxyConfig{ + From: url.URL{Scheme: scheme, Host: srcAddr}, + To: url.URL{Scheme: scheme, Host: dstAddr}, + Logger: grpclog.NewLoggerV2WithVerbosity(os.Stderr, os.Stderr, os.Stderr, 5), + }) + <-p.Ready() + defer p.Close() + + p.CorruptTx(func(d []byte) []byte { + d[len(d)/2]++ + return d + }) + data := []byte("Hello World!") + send(t, data, scheme, srcAddr, TLSInfo{}) + if d := receive(t, ln); bytes.Equal(d, data) { + t.Fatalf("expected corrupted data, got %q", string(d)) + } + + p.UncorruptTx() + send(t, data, scheme, srcAddr, TLSInfo{}) + if d := receive(t, ln); !bytes.Equal(d, data) { + t.Fatalf("expected uncorrupted data, got %q", string(d)) + } +} + +func TestProxy_Shutdown(t *testing.T) { + scheme := "unix" + srcAddr, dstAddr := newUnixAddr(), newUnixAddr() + defer func() { + os.RemoveAll(srcAddr) + os.RemoveAll(dstAddr) + }() + ln := listen(t, scheme, dstAddr, TLSInfo{}) + defer ln.Close() + + p := NewProxy(ProxyConfig{ + From: url.URL{Scheme: scheme, Host: srcAddr}, + To: url.URL{Scheme: scheme, Host: dstAddr}, + Logger: grpclog.NewLoggerV2WithVerbosity(os.Stderr, os.Stderr, os.Stderr, 5), + }) + <-p.Ready() + defer p.Close() + + px, _ := p.(*proxy) + px.listener.Close() + time.Sleep(200 * time.Millisecond) + + data := []byte("Hello World!") + send(t, data, scheme, srcAddr, TLSInfo{}) + if d := receive(t, ln); !bytes.Equal(d, data) { + t.Fatalf("expected %q, got %q", string(data), string(d)) + } +} + +func TestProxy_ShutdownListener(t *testing.T) { + scheme := "unix" + srcAddr, dstAddr := newUnixAddr(), newUnixAddr() + defer func() { + os.RemoveAll(srcAddr) + os.RemoveAll(dstAddr) + }() + + ln := listen(t, scheme, dstAddr, TLSInfo{}) + defer ln.Close() + + p := NewProxy(ProxyConfig{ + From: url.URL{Scheme: scheme, Host: srcAddr}, + To: url.URL{Scheme: scheme, Host: dstAddr}, + Logger: grpclog.NewLoggerV2WithVerbosity(os.Stderr, os.Stderr, os.Stderr, 5), + }) + <-p.Ready() + defer p.Close() + + // shut down destination + ln.Close() + time.Sleep(200 * time.Millisecond) + + ln = listen(t, scheme, dstAddr, TLSInfo{}) + defer ln.Close() + + data := []byte("Hello World!") + send(t, data, scheme, srcAddr, TLSInfo{}) + if d := receive(t, ln); !bytes.Equal(d, data) { + t.Fatalf("expected %q, got %q", string(data), string(d)) + } +} + +func TestProxyHTTP_Insecure_DelayTx(t *testing.T) { testProxyHTTP(t, false, true) } +func TestProxyHTTP_Secure_DelayTx(t *testing.T) { testProxyHTTP(t, true, true) } +func TestProxyHTTP_Insecure_DelayRx(t *testing.T) { testProxyHTTP(t, false, false) } +func TestProxyHTTP_Secure_DelayRx(t *testing.T) { testProxyHTTP(t, true, false) } +func testProxyHTTP(t *testing.T, secure, delayTx bool) { + scheme := "tcp" + ln1, ln2 := listen(t, scheme, "localhost:0", TLSInfo{}), listen(t, scheme, "localhost:0", TLSInfo{}) + srcAddr, dstAddr := ln1.Addr().String(), ln2.Addr().String() + ln1.Close() + ln2.Close() + + mux := http.NewServeMux() + mux.HandleFunc("/hello", func(w http.ResponseWriter, req *http.Request) { + d, err := ioutil.ReadAll(req.Body) + if err != nil { + t.Fatal(err) + } + if _, err = w.Write([]byte(fmt.Sprintf("%q(confirmed)", string(d)))); err != nil { + t.Fatal(err) + } + }) + var tlsConfig *tls.Config + var err error + if secure { + tlsConfig, err = testTLSInfo.ServerConfig() + if err != nil { + t.Fatal(err) + } + } + srv := &http.Server{ + Addr: dstAddr, + Handler: mux, + TLSConfig: tlsConfig, + } + + donec := make(chan struct{}) + defer func() { + srv.Close() + <-donec + }() + go func() { + defer close(donec) + if !secure { + srv.ListenAndServe() + } else { + srv.ListenAndServeTLS(testTLSInfo.CertFile, testTLSInfo.KeyFile) + } + }() + time.Sleep(200 * time.Millisecond) + + cfg := ProxyConfig{ + From: url.URL{Scheme: scheme, Host: srcAddr}, + To: url.URL{Scheme: scheme, Host: dstAddr}, + Logger: grpclog.NewLoggerV2WithVerbosity(os.Stderr, os.Stderr, os.Stderr, 5), + } + if secure { + cfg.TLSInfo = testTLSInfo + } + p := NewProxy(cfg) + <-p.Ready() + defer p.Close() + + data := "Hello World!" + + now := time.Now() + var resp *http.Response + if secure { + tp, terr := NewTransport(testTLSInfo, 3*time.Second) + if terr != nil { + t.Fatal(terr) + } + cli := &http.Client{Transport: tp} + resp, err = cli.Post("https://"+srcAddr+"/hello", "", strings.NewReader(data)) + } else { + resp, err = http.Post("http://"+srcAddr+"/hello", "", strings.NewReader(data)) + } + if err != nil { + t.Fatal(err) + } + d, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + took1 := time.Since(now) + t.Logf("took %v with no latency", took1) + + rs1 := string(d) + exp := fmt.Sprintf("%q(confirmed)", data) + if rs1 != exp { + t.Fatalf("got %q, expected %q", rs1, exp) + } + + lat, rv := 100*time.Millisecond, 10*time.Millisecond + if delayTx { + p.DelayTx(lat, rv) + defer p.UndelayTx() + } else { + p.DelayRx(lat, rv) + defer p.UndelayRx() + } + + now = time.Now() + if secure { + tp, terr := NewTransport(testTLSInfo, 3*time.Second) + if terr != nil { + t.Fatal(terr) + } + cli := &http.Client{Transport: tp} + resp, err = cli.Post("https://"+srcAddr+"/hello", "", strings.NewReader(data)) + } else { + resp, err = http.Post("http://"+srcAddr+"/hello", "", strings.NewReader(data)) + } + if err != nil { + t.Fatal(err) + } + d, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + took2 := time.Since(now) + t.Logf("took %v with latency %v±%v", took2, lat, rv) + + rs2 := string(d) + if rs2 != exp { + t.Fatalf("got %q, expected %q", rs2, exp) + } + if took1 > took2 { + t.Fatalf("expected took1 %v < took2 %v", took1, took2) + } +} + +func newUnixAddr() string { + now := time.Now().UnixNano() + rand.Seed(now) + addr := fmt.Sprintf("%X%X.unix-conn", now, rand.Intn(35000)) + os.RemoveAll(addr) + return addr +} + +func listen(t *testing.T, scheme, addr string, tlsInfo TLSInfo) (ln net.Listener) { + var err error + if !tlsInfo.Empty() { + ln, err = NewListener(addr, scheme, &tlsInfo) + } else { + ln, err = net.Listen(scheme, addr) + } + if err != nil { + t.Fatal(err) + } + return ln +} + +func send(t *testing.T, data []byte, scheme, addr string, tlsInfo TLSInfo) { + var out net.Conn + var err error + if !tlsInfo.Empty() { + tp, terr := NewTransport(tlsInfo, 3*time.Second) + if terr != nil { + t.Fatal(terr) + } + out, err = tp.Dial(scheme, addr) + } else { + out, err = net.Dial(scheme, addr) + } + if err != nil { + t.Fatal(err) + } + if _, err = out.Write(data); err != nil { + t.Fatal(err) + } + if err = out.Close(); err != nil { + t.Fatal(err) + } +} + +func receive(t *testing.T, ln net.Listener) (data []byte) { + buf := bytes.NewBuffer(make([]byte, 0, 1024)) + for { + in, err := ln.Accept() + if err != nil { + t.Fatal(err) + } + var n int64 + n, err = buf.ReadFrom(in) + if err != nil { + t.Fatal(err) + } + if n > 0 { + break + } + } + return buf.Bytes() +}