Fix Blackhole implementation for e2e tests

Based on Fu Wei's idea discussed in the issue [1], we employ the network traffic blocking on L7, using a forward proxy, without the need to use external tools.

[Background]

A peer will
(a) receive traffic from its peers
(b) initiate connections to its peers (via stream and pipeline).

Thus, the current mechanism of only blocking peer traffic via the peer's existing reverse proxy is insufficient, since only scenario (a) is handled, and scenario (b) is not blocked at all.

[Proposed solution]

We introduce a forward proxy for each peer, which will be proxying all the connections initiated from a peer to its peers.

We will remove the current use of the reverse proxy. Because the forward proxy holds the information of the destination, we can block all in and out traffic that is initiated from a peer to others, without having to resort to external tools, such as iptables.

The modified architecture will look something like this:
```
A --- A's forward proxy ----- B
   ^ newly introduced
```

It's verified that the blocking of traffic is complete, compared to previous solutions attempted in PRs [2][3].

[Implementation]

The main subtasks are
- redesigned as an L7 forward proxy
- Unix socket support is dropped: the e2e tests support Unix sockets for peer communication, but only a few e2e test cases use them, as the majority of e2e test cases use HTTP/HTTPS
- introduce a new environment variable `E2E_TEST_FORWARD_PROXY_IP`
- implement the L7 forward proxy by drastically reducing the existing proxy server code and redesigning it around traffic blocking

Known limitations are
- Doesn't support Unix sockets (the L7 HTTP transport proxy only supports HTTP, HTTPS, and SOCKS5)
- It's L7 so we need to send a perfectly crafted HTTP request
- Doesn't support reordering, dropping, etc. of packets on-the-fly

[Testing]
- `make gofail-enable && make build && make gofail-disable && go test -timeout 60s -run ^TestBlackholeByMockingPartitionLeader$ go.etcd.io/etcd/tests/v3/e2e -v -count=1`
- `make gofail-enable && make build && make gofail-disable && go test -timeout 60s -run ^TestBlackholeByMockingPartitionFollower$ go.etcd.io/etcd/tests/v3/e2e -v -count=1`
- `go test -timeout 30s -run ^TestServer_ go.etcd.io/etcd/pkg/v3/proxy -v -failfast`

[References]
[1] issue https://github.com/etcd-io/etcd/issues/17737
[2] PR (V1) https://github.com/henrybear327/etcd/tree/fix/e2e_blackhole
[3] PR (V2) https://github.com/etcd-io/etcd/pull/17891
[4] https://github.com/etcd-io/etcd/pull/17938#discussion_r1615622709
[5] https://github.com/etcd-io/etcd/pull/17985#discussion_r1598020110

Signed-off-by: Siyuan Zhang <sizhang@google.com>
Co-authored-by: Iván Valdés Castillo <iv@nvald.es>
Signed-off-by: Chun-Hung Tseng <henrybear327@gmail.com>
This commit is contained in:
Siyuan Zhang 2024-04-08 10:07:26 -07:00 committed by Chun-Hung Tseng
parent 5704c6148d
commit ac592a2f97
No known key found for this signature in database
GPG Key ID: EF93C20F55FB48BB
8 changed files with 883 additions and 828 deletions

View File

@ -18,12 +18,29 @@ import (
"context" "context"
"net" "net"
"net/http" "net/http"
"net/url"
"os"
"strings" "strings"
"time" "time"
) )
type unixTransport struct{ *http.Transport } type unixTransport struct{ *http.Transport }
var httpTransportProxyParsingFunc = determineHTTPTransportProxyParsingFunc
// determineHTTPTransportProxyParsingFunc picks the proxy resolver installed on
// the HTTP transports built by this package.
//
// http.ProxyFromEnvironment is documented to return a nil URL and nil error
// when the proxy URL is "localhost" (with or without a port number), which
// makes it unusable for a forward proxy running on the local machine. To work
// around that, when the E2E_TEST_FORWARD_PROXY_IP environment variable is set
// its value is parsed as the proxy URL for every request, regardless of the
// request itself. When the variable is absent the standard
// http.ProxyFromEnvironment behaviour is kept.
func determineHTTPTransportProxyParsingFunc() func(req *http.Request) (*url.URL, error) {
	rawProxy, ok := os.LookupEnv("E2E_TEST_FORWARD_PROXY_IP")
	if !ok {
		return http.ProxyFromEnvironment
	}

	// The value is captured once here; it is re-parsed on each call.
	return func(*http.Request) (*url.URL, error) {
		return url.Parse(rawProxy)
	}
}
func NewTransport(info TLSInfo, dialtimeoutd time.Duration) (*http.Transport, error) { func NewTransport(info TLSInfo, dialtimeoutd time.Duration) (*http.Transport, error) {
cfg, err := info.ClientConfig() cfg, err := info.ClientConfig()
if err != nil { if err != nil {
@ -39,7 +56,7 @@ func NewTransport(info TLSInfo, dialtimeoutd time.Duration) (*http.Transport, er
} }
t := &http.Transport{ t := &http.Transport{
Proxy: http.ProxyFromEnvironment, Proxy: httpTransportProxyParsingFunc(),
DialContext: (&net.Dialer{ DialContext: (&net.Dialer{
Timeout: dialtimeoutd, Timeout: dialtimeoutd,
LocalAddr: ipAddr, LocalAddr: ipAddr,
@ -60,7 +77,7 @@ func NewTransport(info TLSInfo, dialtimeoutd time.Duration) (*http.Transport, er
return dialer.DialContext(ctx, "unix", addr) return dialer.DialContext(ctx, "unix", addr)
} }
tu := &http.Transport{ tu := &http.Transport{
Proxy: http.ProxyFromEnvironment, Proxy: httpTransportProxyParsingFunc(),
DialContext: dialContext, DialContext: dialContext,
TLSHandshakeTimeout: 10 * time.Second, TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: cfg, TLSClientConfig: cfg,

View File

@ -19,6 +19,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log"
"math/bits"
mrand "math/rand" mrand "math/rand"
"net" "net"
"net/http" "net/http"
@ -44,11 +46,17 @@ var (
// latency spikes and packet drop or corruption. The proxy overhead is very // latency spikes and packet drop or corruption. The proxy overhead is very
// small overhead (<500μs per request). Please run tests to compute actual // small overhead (<500μs per request). Please run tests to compute actual
// overhead. // overhead.
//
// Note that the current implementation is a forward proxy; unix sockets
// are not supported because the forwarding is done at L7, which requires
// a properly constructed HTTP header and body.
//
// Also, because we are forced to use TLS to communicate with the proxy server
// and using well-formed header to talk to the destination server,
// we can't do random modification on the data on-the-fly anymore.
type Server interface { type Server interface {
// From returns proxy source address in "scheme://host:port" format. // Listen returns proxy listen address in "scheme://host:port" format.
From() string Listen() string
// To returns proxy destination address in "scheme://host:port" format.
To() string
// Ready returns when proxy is ready to serve. // Ready returns when proxy is ready to serve.
Ready() <-chan struct{} Ready() <-chan struct{}
@ -115,6 +123,16 @@ type Server interface {
// UnblackholeRx removes blackhole operation on "receiving". // UnblackholeRx removes blackhole operation on "receiving".
UnblackholeRx() UnblackholeRx()
// BlackholePeerTx drops all outgoing traffic of a peer.
BlackholePeerTx(peer url.URL)
// UnblackholePeerTx removes blackhole operation on "sending".
UnblackholePeerTx(peer url.URL)
// BlackholePeerRx drops all incoming traffic of a peer.
BlackholePeerRx(peer url.URL)
// UnblackholePeerRx removes blackhole operation on "receiving".
UnblackholePeerRx(peer url.URL)
// PauseTx stops "forwarding" packets; "outgoing" traffic blocks. // PauseTx stops "forwarding" packets; "outgoing" traffic blocks.
PauseTx() PauseTx()
// UnpauseTx removes "forwarding" pause operation. // UnpauseTx removes "forwarding" pause operation.
@ -124,29 +142,29 @@ type Server interface {
PauseRx() PauseRx()
// UnpauseRx removes "receiving" pause operation. // UnpauseRx removes "receiving" pause operation.
UnpauseRx() UnpauseRx()
// ResetListener closes and restarts listener.
ResetListener() error
} }
// ServerConfig defines proxy server configuration. // ServerConfig defines proxy server configuration.
type ServerConfig struct { type ServerConfig struct {
Logger *zap.Logger Logger *zap.Logger
From url.URL Listen url.URL
To url.URL
TLSInfo transport.TLSInfo TLSInfo transport.TLSInfo
DialTimeout time.Duration DialTimeout time.Duration
BufferSize int BufferSize int
RetryInterval time.Duration RetryInterval time.Duration
} }
// Per-peer blackhole state, stored by port number in server.blackholePeerMap.
// The Blackhole*/ioCopy code combines these values with bitwise OR and tests
// them with bitwise AND, so Tx (1) and Rx (2) act as independent flag bits and
// None (0) means "nothing blocked".
// NOTE(review): iota only happens to produce distinct bits for the first two
// non-zero values; a constant appended after Rx would be 3 and overlap both
// flags, so any new flag needs an explicit value.
const (
// no direction is blackholed
blackholePeerTypeNone uint8 = iota
// flag bit: traffic sent by the peer is dropped
blackholePeerTypeTx
// flag bit: traffic received by the peer is dropped
blackholePeerTypeRx
)
type server struct { type server struct {
lg *zap.Logger lg *zap.Logger
from url.URL listen url.URL
fromPort int listenPort int
to url.URL
toPort int
tlsInfo transport.TLSInfo tlsInfo transport.TLSInfo
dialTimeout time.Duration dialTimeout time.Duration
@ -160,9 +178,10 @@ type server struct {
closeOnce sync.Once closeOnce sync.Once
closeWg sync.WaitGroup closeWg sync.WaitGroup
closeHijackedConn sync.WaitGroup
listenerMu sync.RWMutex listenerMu sync.RWMutex
listener net.Listener listener *customListener
pauseAcceptMu sync.Mutex pauseAcceptMu sync.Mutex
pauseAcceptc chan struct{} pauseAcceptc chan struct{}
@ -187,6 +206,11 @@ type server struct {
latencyRxMu sync.RWMutex latencyRxMu sync.RWMutex
latencyRx time.Duration latencyRx time.Duration
blackholePeerMap map[int]uint8 // port number, blackhole type
blackholePeerMapMu sync.RWMutex
httpServer *http.Server
} }
// NewServer returns a proxy implementation with no iptables/tc dependencies. // NewServer returns a proxy implementation with no iptables/tc dependencies.
@ -195,8 +219,7 @@ func NewServer(cfg ServerConfig) Server {
s := &server{ s := &server{
lg: cfg.Logger, lg: cfg.Logger,
from: cfg.From, listen: cfg.Listen,
to: cfg.To,
tlsInfo: cfg.TLSInfo, tlsInfo: cfg.TLSInfo,
dialTimeout: cfg.DialTimeout, dialTimeout: cfg.DialTimeout,
@ -211,17 +234,12 @@ func NewServer(cfg ServerConfig) Server {
pauseAcceptc: make(chan struct{}), pauseAcceptc: make(chan struct{}),
pauseTxc: make(chan struct{}), pauseTxc: make(chan struct{}),
pauseRxc: make(chan struct{}), pauseRxc: make(chan struct{}),
blackholePeerMap: make(map[int]uint8),
} }
_, fromPort, err := net.SplitHostPort(cfg.From.Host) var err error
if err == nil { var fromPort string
s.fromPort, _ = strconv.Atoi(fromPort)
}
var toPort string
_, toPort, err = net.SplitHostPort(cfg.To.Host)
if err == nil {
s.toPort, _ = strconv.Atoi(toPort)
}
if s.dialTimeout == 0 { if s.dialTimeout == 0 {
s.dialTimeout = defaultDialTimeout s.dialTimeout = defaultDialTimeout
@ -237,183 +255,276 @@ func NewServer(cfg ServerConfig) Server {
close(s.pauseTxc) close(s.pauseTxc)
close(s.pauseRxc) close(s.pauseRxc)
if strings.HasPrefix(s.from.Scheme, "http") { // L7 is http (scheme), L4 is tcp (network listener)
s.from.Scheme = "tcp" addr := ""
if strings.HasPrefix(s.listen.Scheme, "http") {
s.listen.Scheme = "tcp"
if _, fromPort, err = net.SplitHostPort(cfg.Listen.Host); err != nil {
s.errc <- err
s.Close()
return nil
} }
if strings.HasPrefix(s.to.Scheme, "http") { if s.listenPort, err = strconv.Atoi(fromPort); err != nil {
s.to.Scheme = "tcp" s.errc <- err
s.Close()
return nil
} }
addr := fmt.Sprintf(":%d", s.fromPort) addr = fmt.Sprintf(":%d", s.listenPort)
if s.fromPort == 0 { // unix } else {
addr = s.from.Host panic(fmt.Sprintf("%s is not supported", s.listen.Scheme))
} }
s.closeWg.Add(1)
var ln net.Listener var ln net.Listener
if !s.tlsInfo.Empty() { if !s.tlsInfo.Empty() {
ln, err = transport.NewListener(addr, s.from.Scheme, &s.tlsInfo) ln, err = transport.NewListener(addr, s.listen.Scheme, &s.tlsInfo)
} else { } else {
ln, err = net.Listen(s.from.Scheme, addr) ln, err = net.Listen(s.listen.Scheme, addr)
} }
if err != nil { if err != nil {
s.errc <- err s.errc <- err
s.Close() s.Close()
return s return nil
}
s.listener = ln
s.closeWg.Add(1)
go s.listenAndServe()
s.lg.Info("started proxying", zap.String("from", s.From()), zap.String("to", s.To()))
return s
} }
func (s *server) From() string { s.listener = &customListener{
return fmt.Sprintf("%s://%s", s.from.Scheme, s.from.Host) s: s,
l: &ln,
} }
func (s *server) To() string { go func() {
return fmt.Sprintf("%s://%s", s.to.Scheme, s.to.Host)
}
// TODO: implement packet reordering from multiple TCP connections
// buffer packets per connection for awhile, reorder before transmit
// - https://github.com/etcd-io/etcd/issues/5614
// - https://github.com/etcd-io/etcd/pull/6918#issuecomment-264093034
func (s *server) listenAndServe() {
defer s.closeWg.Done() defer s.closeWg.Done()
ctx := context.Background() s.httpServer = &http.Server{
s.lg.Info("proxy is listening on", zap.String("from", s.From())) Handler: &serverHandler{s: s},
close(s.readyc)
for {
s.pauseAcceptMu.Lock()
pausec := s.pauseAcceptc
s.pauseAcceptMu.Unlock()
select {
case <-pausec:
case <-s.donec:
return
} }
s.latencyAcceptMu.RLock() s.lg.Info("proxy is listening on", zap.String("listen on", s.Listen()))
lat := s.latencyAccept close(s.readyc)
s.latencyAcceptMu.RUnlock() if err := s.httpServer.Serve(s.listener); err != http.ErrServerClosed {
// always returns error. ErrServerClosed on graceful close
panic(fmt.Sprintf("startHTTPServer Serve(): %v", err))
}
}()
s.lg.Info("started proxying", zap.String("listen on", s.Listen()))
return s
}
// customListener wraps the real network listener so that the L4 features
// (pausing accepts, injecting accept latency) keep working even though the
// proxy itself is now an L7 HTTP server: http.Server calls Accept on this
// type, and Accept applies the pause/latency state held by the owning server
// before delegating to the wrapped listener.
type customListener struct {
// owning proxy server; supplies the pause/latency settings, donec, errc,
// logger and the listenerMu that guards access to l
s *server
// wrapped listener; dereferenced under s.listenerMu by Accept, Close and Addr
l *net.Listener
}
func (c *customListener) Accept() (net.Conn, error) {
// we implement the L4 features here (pause / latency accept)
c.s.pauseAcceptMu.Lock()
pausec := c.s.pauseAcceptc
c.s.pauseAcceptMu.Unlock()
select {
case <-pausec:
case <-c.s.donec:
return nil, fmt.Errorf("listener is closed")
}
c.s.latencyAcceptMu.RLock()
lat := c.s.latencyAccept
c.s.lg.Info(
"get accept latency",
zap.Duration("latency", lat),
)
c.s.latencyAcceptMu.RUnlock()
if lat > 0 { if lat > 0 {
select { select {
case <-time.After(lat): case <-time.After(lat):
case <-s.donec: case <-c.s.donec:
return return nil, fmt.Errorf("listener is closed")
} }
} }
s.listenerMu.RLock() c.s.listenerMu.RLock()
ln := s.listener conn, err := (*c.l).Accept()
s.listenerMu.RUnlock() c.s.listenerMu.RUnlock()
in, err := ln.Accept()
if err != nil { if err != nil {
select { select {
case s.errc <- err: case c.s.errc <- err:
select { select {
case <-s.donec: case <-c.s.donec:
return return nil, err
default: default:
} }
case <-s.donec: case <-c.s.donec:
return return nil, err
} }
s.lg.Debug("listener accept error", zap.Error(err)) c.s.lg.Debug("listener accept error", zap.Error(err))
if strings.HasSuffix(err.Error(), "use of closed network connection") { if strings.HasSuffix(err.Error(), "use of closed network connection") {
select { select {
case <-time.After(s.retryInterval): case <-time.After(c.s.retryInterval):
case <-s.donec: case <-c.s.donec:
return return nil, err
}
c.s.lg.Debug("listener is closed")
}
} }
s.lg.Debug("listener is closed; retry listening on", zap.String("from", s.From()))
if err = s.ResetListener(); err != nil { return conn, err
}
// Close shuts down the wrapped listener while holding the server's listener
// read lock.
// Any blocked Accept operations will be unblocked and return errors.
func (c *customListener) Close() error {
	mu := &c.s.listenerMu
	mu.RLock()
	defer mu.RUnlock()

	inner := *c.l
	return inner.Close()
}
// Addr reports the network address of the wrapped listener. The lookup is
// done under the server's listener read lock.
func (c *customListener) Addr() net.Addr {
	mu := &c.s.listenerMu
	mu.RLock()
	defer mu.RUnlock()

	inner := *c.l
	return inner.Addr()
}
// serverHandler is the http.Handler installed on the proxy's http.Server. Its
// ServeHTTP method hijacks each incoming connection and relays bytes between
// the client and the requested destination using the owning server's state.
type serverHandler struct {
// proxy server this handler belongs to
s *server
}
func (sh *serverHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
hijacker, _ := resp.(http.Hijacker)
in, _, err := hijacker.Hijack()
if err != nil {
select { select {
case s.errc <- err: case sh.s.errc <- err:
select { select {
case <-s.donec: case <-sh.s.donec:
return return
default: default:
} }
case <-s.donec: case <-sh.s.donec:
return return
} }
s.lg.Warn("failed to reset listener", zap.Error(err)) sh.s.lg.Debug("ServeHTTP hijack error", zap.Error(err))
} panic(err)
} }
continue targetScheme := "tcp"
targetHost := req.URL.Host
ctx := context.Background()
/*
If the traffic to the destination is HTTPS, a CONNECT request will be sent
first (containing the intended destination HOST).
If the traffic to the destination is HTTP, no CONNECT request will be sent
first. Only normal HTTP request is sent, with the HOST set to the final destination.
This will be troublesome since we need to manually forward the request to the
destination, and we can't do byte stream manipulation.
Thus, we need to send the traffic to destination with HTTPS, allowing us to
handle byte streams.
*/
if req.Method == "CONNECT" {
// for CONNECT, we need to send 200 response back first
in.Write([]byte("HTTP/1.0 200 Connection established\r\n\r\n"))
} }
var out net.Conn var out net.Conn
if !s.tlsInfo.Empty() { if !sh.s.tlsInfo.Empty() {
var tp *http.Transport var tp *http.Transport
tp, err = transport.NewTransport(s.tlsInfo, s.dialTimeout) tp, err = transport.NewTransport(sh.s.tlsInfo, sh.s.dialTimeout)
if err != nil { if err != nil {
select { select {
case s.errc <- err: case sh.s.errc <- err:
select { select {
case <-s.donec: case <-sh.s.donec:
return return
default: default:
} }
case <-s.donec: case <-sh.s.donec:
return return
} }
continue sh.s.lg.Debug("failed to get new Transport", zap.Error(err))
return
} }
out, err = tp.DialContext(ctx, s.to.Scheme, s.to.Host) out, err = tp.DialContext(ctx, targetScheme, targetHost)
} else { } else {
out, err = net.Dial(s.to.Scheme, s.to.Host) out, err = net.Dial(targetScheme, targetHost)
} }
if err != nil { if err != nil {
select { select {
case s.errc <- err: case sh.s.errc <- err:
select { select {
case <-s.donec: case <-sh.s.donec:
return return
default: default:
} }
case <-s.donec: case <-sh.s.donec:
return return
} }
s.lg.Debug("failed to dial", zap.Error(err)) sh.s.lg.Debug("failed to dial", zap.Error(err))
continue return
} }
s.closeWg.Add(2) var dstPort int
dstPort, err = getPort(out.RemoteAddr())
if err != nil {
select {
case sh.s.errc <- err:
select {
case <-sh.s.donec:
return
default:
}
case <-sh.s.donec:
return
}
sh.s.lg.Debug("failed to parse port in transmit", zap.Error(err))
return
}
sh.s.closeHijackedConn.Add(2)
go func() { go func() {
defer s.closeWg.Done() defer sh.s.closeHijackedConn.Done()
// read incoming bytes from listener, dispatch to outgoing connection // read incoming bytes from listener, dispatch to outgoing connection
s.transmit(out, in) sh.s.transmit(out, in, dstPort)
out.Close() out.Close()
in.Close() in.Close()
}() }()
go func() { go func() {
defer s.closeWg.Done() defer sh.s.closeHijackedConn.Done()
// read response from outgoing connection, write back to listener // read response from outgoing connection, write back to listener
s.receive(in, out) sh.s.receive(in, out, dstPort)
in.Close() in.Close()
out.Close() out.Close()
}() }()
} }
func (s *server) Listen() string {
return fmt.Sprintf("%s://%s", s.listen.Scheme, s.listen.Host)
} }
func (s *server) transmit(dst io.Writer, src io.Reader) { func getPort(addr net.Addr) (int, error) {
s.ioCopy(dst, src, proxyTx) switch addr := addr.(type) {
case *net.TCPAddr:
return addr.Port, nil
case *net.UDPAddr:
return addr.Port, nil
default:
return 0, fmt.Errorf("unsupported address type: %T", addr)
}
} }
func (s *server) receive(dst io.Writer, src io.Reader) { func (s *server) transmit(dst, src net.Conn, port int) {
s.ioCopy(dst, src, proxyRx) s.ioCopy(dst, src, proxyTx, port)
}
func (s *server) receive(dst, src net.Conn, port int) {
s.ioCopy(dst, src, proxyRx, port)
} }
type proxyType uint8 type proxyType uint8
@ -423,7 +534,7 @@ const (
proxyRx proxyRx
) )
func (s *server) ioCopy(dst io.Writer, src io.Reader, ptype proxyType) { func (s *server) ioCopy(dst, src net.Conn, ptype proxyType, peerPort int) {
buf := make([]byte, s.bufferSize) buf := make([]byte, s.bufferSize)
for { for {
nr1, err := src.Read(buf) nr1, err := src.Read(buf)
@ -464,12 +575,30 @@ func (s *server) ioCopy(dst io.Writer, src io.Reader, ptype proxyType) {
data = s.modifyTx(data) data = s.modifyTx(data)
} }
s.modifyTxMu.RUnlock() s.modifyTxMu.RUnlock()
s.blackholePeerMapMu.RLock()
// Tx from other peers is Rx for the target peer
if val, exist := s.blackholePeerMap[peerPort]; exist {
if (val & blackholePeerTypeRx) > 0 {
data = nil
}
}
s.blackholePeerMapMu.RUnlock()
case proxyRx: case proxyRx:
s.modifyRxMu.RLock() s.modifyRxMu.RLock()
if s.modifyRx != nil { if s.modifyRx != nil {
data = s.modifyRx(data) data = s.modifyRx(data)
} }
s.modifyRxMu.RUnlock() s.modifyRxMu.RUnlock()
s.blackholePeerMapMu.RLock()
// Rx from other peers is Tx for the target peer
if val, exist := s.blackholePeerMap[peerPort]; exist {
if (val & blackholePeerTypeTx) > 0 {
data = nil
}
}
s.blackholePeerMapMu.RUnlock()
default: default:
panic("unknown proxy type") panic("unknown proxy type")
} }
@ -480,16 +609,16 @@ func (s *server) ioCopy(dst io.Writer, src io.Reader, ptype proxyType) {
"modified tx", "modified tx",
zap.String("data-received", humanize.Bytes(uint64(nr1))), zap.String("data-received", humanize.Bytes(uint64(nr1))),
zap.String("data-modified", humanize.Bytes(uint64(nr2))), zap.String("data-modified", humanize.Bytes(uint64(nr2))),
zap.String("from", s.From()), zap.String("proxy listening on", s.Listen()),
zap.String("to", s.To()), zap.Int("to peer port", peerPort),
) )
case proxyRx: case proxyRx:
s.lg.Debug( s.lg.Debug(
"modified rx", "modified rx",
zap.String("data-received", humanize.Bytes(uint64(nr1))), zap.String("data-received", humanize.Bytes(uint64(nr1))),
zap.String("data-modified", humanize.Bytes(uint64(nr2))), zap.String("data-modified", humanize.Bytes(uint64(nr2))),
zap.String("from", s.To()), zap.String("proxy listening on", s.Listen()),
zap.String("to", s.From()), zap.Int("to peer port", peerPort),
) )
default: default:
panic("unknown proxy type") panic("unknown proxy type")
@ -607,15 +736,15 @@ func (s *server) ioCopy(dst io.Writer, src io.Reader, ptype proxyType) {
s.lg.Debug( s.lg.Debug(
"transmitted", "transmitted",
zap.String("data-size", humanize.Bytes(uint64(nr1))), zap.String("data-size", humanize.Bytes(uint64(nr1))),
zap.String("from", s.From()), zap.String("proxy listening on", s.Listen()),
zap.String("to", s.To()), zap.Int("to peer port", peerPort),
) )
case proxyRx: case proxyRx:
s.lg.Debug( s.lg.Debug(
"received", "received",
zap.String("data-size", humanize.Bytes(uint64(nr1))), zap.String("data-size", humanize.Bytes(uint64(nr1))),
zap.String("from", s.To()), zap.String("proxy listening on", s.Listen()),
zap.String("to", s.From()), zap.Int("to peer port", peerPort),
) )
default: default:
panic("unknown proxy type") panic("unknown proxy type")
@ -629,19 +758,28 @@ func (s *server) Error() <-chan error { return s.errc }
func (s *server) Close() (err error) { func (s *server) Close() (err error) {
s.closeOnce.Do(func() { s.closeOnce.Do(func() {
close(s.donec) close(s.donec)
s.listenerMu.Lock()
if s.listener != nil { // we shutdown the server
err = s.listener.Close() log.Println("we shutdown the server")
s.lg.Info( if err = s.httpServer.Shutdown(context.TODO()); err != nil {
"closed proxy listener", return
zap.String("from", s.From()),
zap.String("to", s.To()),
)
} }
s.httpServer = nil
log.Println("waiting for listenerMu")
// listener was closed by the Shutdown() call
s.listenerMu.Lock()
s.listener = nil
s.lg.Sync() s.lg.Sync()
s.listenerMu.Unlock() s.listenerMu.Unlock()
// the hijacked connections aren't tracked by the server so we need to wait for them
log.Println("waiting for closeHijackedConn")
s.closeHijackedConn.Wait()
}) })
s.closeWg.Wait()
// s.closeWg.Wait()
return err return err
} }
@ -652,8 +790,7 @@ func (s *server) PauseAccept() {
s.lg.Info( s.lg.Info(
"paused accept", "paused accept",
zap.String("from", s.From()), zap.String("proxy listen on", s.Listen()),
zap.String("to", s.To()),
) )
} }
@ -671,8 +808,7 @@ func (s *server) UnpauseAccept() {
s.lg.Info( s.lg.Info(
"unpaused accept", "unpaused accept",
zap.String("from", s.From()), zap.String("proxy listen on", s.Listen()),
zap.String("to", s.To()),
) )
} }
@ -690,8 +826,7 @@ func (s *server) DelayAccept(latency, rv time.Duration) {
zap.Duration("latency", d), zap.Duration("latency", d),
zap.Duration("given-latency", latency), zap.Duration("given-latency", latency),
zap.Duration("given-latency-random-variable", rv), zap.Duration("given-latency-random-variable", rv),
zap.String("from", s.From()), zap.String("proxy listening on", s.Listen()),
zap.String("to", s.To()),
) )
} }
@ -704,8 +839,7 @@ func (s *server) UndelayAccept() {
s.lg.Info( s.lg.Info(
"removed accept latency", "removed accept latency",
zap.Duration("latency", d), zap.Duration("latency", d),
zap.String("from", s.From()), zap.String("proxy listening on", s.Listen()),
zap.String("to", s.To()),
) )
} }
@ -720,6 +854,7 @@ func (s *server) DelayTx(latency, rv time.Duration) {
if latency <= 0 { if latency <= 0 {
return return
} }
d := computeLatency(latency, rv) d := computeLatency(latency, rv)
s.latencyTxMu.Lock() s.latencyTxMu.Lock()
s.latencyTx = d s.latencyTx = d
@ -730,8 +865,7 @@ func (s *server) DelayTx(latency, rv time.Duration) {
zap.Duration("latency", d), zap.Duration("latency", d),
zap.Duration("given-latency", latency), zap.Duration("given-latency", latency),
zap.Duration("given-latency-random-variable", rv), zap.Duration("given-latency-random-variable", rv),
zap.String("from", s.From()), zap.String("proxy listening on", s.Listen()),
zap.String("to", s.To()),
) )
} }
@ -744,8 +878,7 @@ func (s *server) UndelayTx() {
s.lg.Info( s.lg.Info(
"removed transmit latency", "removed transmit latency",
zap.Duration("latency", d), zap.Duration("latency", d),
zap.String("from", s.From()), zap.String("proxy listening on", s.Listen()),
zap.String("to", s.To()),
) )
} }
@ -770,8 +903,7 @@ func (s *server) DelayRx(latency, rv time.Duration) {
zap.Duration("latency", d), zap.Duration("latency", d),
zap.Duration("given-latency", latency), zap.Duration("given-latency", latency),
zap.Duration("given-latency-random-variable", rv), zap.Duration("given-latency-random-variable", rv),
zap.String("from", s.To()), zap.String("proxy listening on", s.Listen()),
zap.String("to", s.From()),
) )
} }
@ -784,8 +916,7 @@ func (s *server) UndelayRx() {
s.lg.Info( s.lg.Info(
"removed receive latency", "removed receive latency",
zap.Duration("latency", d), zap.Duration("latency", d),
zap.String("from", s.To()), zap.String("proxy listening on", s.Listen()),
zap.String("to", s.From()),
) )
} }
@ -821,8 +952,7 @@ func (s *server) ModifyTx(f func([]byte) []byte) {
s.lg.Info( s.lg.Info(
"modifying tx", "modifying tx",
zap.String("from", s.From()), zap.String("proxy listen on", s.Listen()),
zap.String("to", s.To()),
) )
} }
@ -833,8 +963,7 @@ func (s *server) UnmodifyTx() {
s.lg.Info( s.lg.Info(
"unmodifyed tx", "unmodifyed tx",
zap.String("from", s.From()), zap.String("proxy listen on", s.Listen()),
zap.String("to", s.To()),
) )
} }
@ -844,8 +973,7 @@ func (s *server) ModifyRx(f func([]byte) []byte) {
s.modifyRxMu.Unlock() s.modifyRxMu.Unlock()
s.lg.Info( s.lg.Info(
"modifying rx", "modifying rx",
zap.String("from", s.To()), zap.String("proxy listen on", s.Listen()),
zap.String("to", s.From()),
) )
} }
@ -856,8 +984,7 @@ func (s *server) UnmodifyRx() {
s.lg.Info( s.lg.Info(
"unmodifyed rx", "unmodifyed rx",
zap.String("from", s.To()), zap.String("proxy listen on", s.Listen()),
zap.String("to", s.From()),
) )
} }
@ -865,8 +992,7 @@ func (s *server) BlackholeTx() {
s.ModifyTx(func([]byte) []byte { return nil }) s.ModifyTx(func([]byte) []byte { return nil })
s.lg.Info( s.lg.Info(
"blackholed tx", "blackholed tx",
zap.String("from", s.From()), zap.String("proxy listening on", s.Listen()),
zap.String("to", s.To()),
) )
} }
@ -874,8 +1000,7 @@ func (s *server) UnblackholeTx() {
s.UnmodifyTx() s.UnmodifyTx()
s.lg.Info( s.lg.Info(
"unblackholed tx", "unblackholed tx",
zap.String("from", s.From()), zap.String("proxy listening on", s.Listen()),
zap.String("to", s.To()),
) )
} }
@ -883,8 +1008,7 @@ func (s *server) BlackholeRx() {
s.ModifyRx(func([]byte) []byte { return nil }) s.ModifyRx(func([]byte) []byte { return nil })
s.lg.Info( s.lg.Info(
"blackholed rx", "blackholed rx",
zap.String("from", s.To()), zap.String("proxy listening on", s.Listen()),
zap.String("to", s.From()),
) )
} }
@ -892,11 +1016,70 @@ func (s *server) UnblackholeRx() {
s.UnmodifyRx() s.UnmodifyRx()
s.lg.Info( s.lg.Info(
"unblackholed rx", "unblackholed rx",
zap.String("from", s.To()), zap.String("proxy listening on", s.Listen()),
zap.String("to", s.From()),
) )
} }
// BlackholePeerTx marks the peer identified by the port of the given URL as
// Tx-blackholed. Any flag already recorded for that port is preserved.
//
// It panics when the URL carries no parseable numeric port.
func (s *server) BlackholePeerTx(peer url.URL) {
	s.blackholePeerMapMu.Lock()
	defer s.blackholePeerMapMu.Unlock()

	peerPort, convErr := strconv.Atoi(peer.Port())
	if convErr != nil {
		panic("port parsing failed")
	}

	// A port that is not in the map yet reads as 0, so OR-ing the flag in
	// covers both the "new entry" and the "existing entry" cases.
	s.blackholePeerMap[peerPort] |= blackholePeerTypeTx
}
// UnblackholePeerTx removes the Tx blackhole flag from the peer identified by
// the port of the given URL. Other flags recorded for that port (such as an
// Rx blackhole) are left untouched, and a port that was never blackholed is
// not added to the map.
//
// It panics when the URL carries no parseable numeric port.
func (s *server) UnblackholePeerTx(peer url.URL) {
	s.blackholePeerMapMu.Lock()
	defer s.blackholePeerMapMu.Unlock()

	port, err := strconv.Atoi(peer.Port())
	if err != nil {
		panic("port parsing failed")
	}

	if val, exist := s.blackholePeerMap[port]; exist {
		// Clear only the Tx bit with AND NOT. The previous mask,
		// bits.Reverse8(blackholePeerTypeTx), reverses the bit order (0x01 ->
		// 0x80) instead of complementing it, which wiped the Rx bit as well.
		s.blackholePeerMap[port] = val &^ blackholePeerTypeTx
	}
}
// BlackholePeerRx marks the peer identified by the port of the given URL as
// Rx-blackholed. Any flag already recorded for that port is preserved.
//
// It panics when the URL carries no parseable numeric port.
func (s *server) BlackholePeerRx(peer url.URL) {
	s.blackholePeerMapMu.Lock()
	defer s.blackholePeerMapMu.Unlock()

	port, err := strconv.Atoi(peer.Port())
	if err != nil {
		panic("port parsing failed")
	}

	// A port that is not in the map yet reads as 0, so OR-ing the flag in
	// handles new and existing entries alike. The previous code stored the Tx
	// flag for a new entry, blackholing the wrong direction.
	s.blackholePeerMap[port] |= blackholePeerTypeRx
}
// UnblackholePeerRx removes the Rx blackhole flag from the peer identified by
// the port of the given URL. Other flags recorded for that port (such as a
// Tx blackhole) are left untouched, and a port that was never blackholed is
// not added to the map.
//
// It panics when the URL carries no parseable numeric port.
func (s *server) UnblackholePeerRx(peer url.URL) {
	s.blackholePeerMapMu.Lock()
	defer s.blackholePeerMapMu.Unlock()

	port, err := strconv.Atoi(peer.Port())
	if err != nil {
		panic("port parsing failed")
	}

	if val, exist := s.blackholePeerMap[port]; exist {
		// Clear only the Rx bit with AND NOT. The previous mask,
		// bits.Reverse8(blackholePeerTypeRx), reverses the bit order (0x02 ->
		// 0x40) instead of complementing it, which wiped the Tx bit as well.
		s.blackholePeerMap[port] = val &^ blackholePeerTypeRx
	}
}
func (s *server) PauseTx() { func (s *server) PauseTx() {
s.pauseTxMu.Lock() s.pauseTxMu.Lock()
s.pauseTxc = make(chan struct{}) s.pauseTxc = make(chan struct{})
@ -904,8 +1087,7 @@ func (s *server) PauseTx() {
s.lg.Info( s.lg.Info(
"paused tx", "paused tx",
zap.String("from", s.From()), zap.String("proxy listen on", s.Listen()),
zap.String("to", s.To()),
) )
} }
@ -923,8 +1105,7 @@ func (s *server) UnpauseTx() {
s.lg.Info( s.lg.Info(
"unpaused tx", "unpaused tx",
zap.String("from", s.From()), zap.String("proxy listen on", s.Listen()),
zap.String("to", s.To()),
) )
} }
@ -935,8 +1116,7 @@ func (s *server) PauseRx() {
s.lg.Info( s.lg.Info(
"paused rx", "paused rx",
zap.String("from", s.To()), zap.String("proxy listen on", s.Listen()),
zap.String("to", s.From()),
) )
} }
@ -954,37 +1134,6 @@ func (s *server) UnpauseRx() {
s.lg.Info( s.lg.Info(
"unpaused rx", "unpaused rx",
zap.String("from", s.To()), zap.String("proxy listen on", s.Listen()),
zap.String("to", s.From()),
) )
} }
func (s *server) ResetListener() error {
s.listenerMu.Lock()
defer s.listenerMu.Unlock()
if err := s.listener.Close(); err != nil {
// already closed
if !strings.HasSuffix(err.Error(), "use of closed network connection") {
return err
}
}
var ln net.Listener
var err error
if !s.tlsInfo.Empty() {
ln, err = transport.NewListener(s.from.Host, s.from.Scheme, &s.tlsInfo)
} else {
ln, err = net.Listen(s.from.Scheme, s.from.Host)
}
if err != nil {
return err
}
s.listener = ln
s.lg.Info(
"reset listener on",
zap.String("from", s.From()),
)
return nil
}

View File

@ -17,127 +17,110 @@ package proxy
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/tls"
"fmt" "fmt"
"io" "io"
"log" "log"
"math/rand"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"os"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
"go.uber.org/zap" "go.uber.org/zap"
"go.uber.org/zap/zaptest" "go.uber.org/zap/zaptest"
"go.etcd.io/etcd/client/pkg/v3/transport" "go.etcd.io/etcd/client/pkg/v3/transport"
) )
func TestServer_Unix_Insecure(t *testing.T) { testServer(t, "unix", false, false) } /* Helper functions */
func TestServer_TCP_Insecure(t *testing.T) { testServer(t, "tcp", false, false) } type dummyServerHandler struct {
func TestServer_Unix_Secure(t *testing.T) { testServer(t, "unix", true, false) } t *testing.T
func TestServer_TCP_Secure(t *testing.T) { testServer(t, "tcp", true, false) } output chan<- []byte
func TestServer_Unix_Insecure_DelayTx(t *testing.T) { testServer(t, "unix", false, true) } }
func TestServer_TCP_Insecure_DelayTx(t *testing.T) { testServer(t, "tcp", false, true) }
func TestServer_Unix_Secure_DelayTx(t *testing.T) { testServer(t, "unix", true, true) }
func TestServer_TCP_Secure_DelayTx(t *testing.T) { testServer(t, "tcp", true, true) }
func testServer(t *testing.T, scheme string, secure bool, delayTx bool) { // reads the request body and write back to the response object
// ServeHTTP acknowledges the request with 200 OK, reads the entire request
// body and hands it to sh.output so the test can inspect what arrived.
//
// A body read failure is reported with t.Error instead of t.Fatal: handlers
// run on a goroutine owned by the HTTP server, and FailNow (which Fatal calls)
// must only be invoked from the goroutine running the test. Nothing is sent
// on sh.output in that case.
func (sh *dummyServerHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
	defer req.Body.Close()
	resp.WriteHeader(http.StatusOK)

	data, err := io.ReadAll(req.Body)
	if err != nil {
		sh.t.Error(err)
		return
	}
	sh.output <- data
}
func prepare(t *testing.T, serverIsClosed bool) (chan []byte, chan struct{}, chan []byte, Server, *http.Server, func(data []byte)) {
lg := zaptest.NewLogger(t) lg := zaptest.NewLogger(t)
srcAddr, dstAddr := newUnixAddr(), newUnixAddr() scheme := "tcp"
if scheme == "tcp" { L7Scheme := "http"
// we always send the traffic to destination with HTTPS
// this will force the CONNECT header to be sent first
tlsInfo := createTLSInfo(lg)
ln1, ln2 := listen(t, "tcp", "localhost:0", transport.TLSInfo{}), listen(t, "tcp", "localhost:0", transport.TLSInfo{}) ln1, ln2 := listen(t, "tcp", "localhost:0", transport.TLSInfo{}), listen(t, "tcp", "localhost:0", transport.TLSInfo{})
srcAddr, dstAddr = ln1.Addr().String(), ln2.Addr().String() forwardProxyAddr, dstAddr := ln1.Addr().String(), ln2.Addr().String()
ln1.Close() ln1.Close()
ln2.Close() ln2.Close()
} else {
defer func() {
os.RemoveAll(srcAddr)
os.RemoveAll(dstAddr)
}()
}
tlsInfo := createTLSInfo(lg, secure)
ln := listen(t, scheme, dstAddr, tlsInfo)
defer ln.Close()
cfg := ServerConfig{
Logger: lg,
From: url.URL{Scheme: scheme, Host: srcAddr},
To: url.URL{Scheme: scheme, Host: dstAddr},
}
if secure {
cfg.TLSInfo = tlsInfo
}
p := NewServer(cfg)
waitForServer(t, p)
defer p.Close()
data1 := []byte("Hello World!")
donec, writec := make(chan struct{}), make(chan []byte)
go func() {
defer close(donec)
for data := range writec {
send(t, data, scheme, srcAddr, tlsInfo)
}
}()
recvc := make(chan []byte, 1) recvc := make(chan []byte, 1)
go func() { httpServer := &http.Server{
for i := 0; i < 2; i++ { Handler: &dummyServerHandler{
recvc <- receive(t, ln) t: t,
output: recvc,
},
} }
}() go startHTTPServer(scheme, dstAddr, tlsInfo, httpServer)
writec <- data1 // we connect to the proxy without TLS
now := time.Now() proxyURL := url.URL{Scheme: L7Scheme, Host: forwardProxyAddr}
if d := <-recvc; !bytes.Equal(data1, d) { cfg := ServerConfig{
close(writec) Logger: lg,
t.Fatalf("expected %q, got %q", string(data1), string(d)) Listen: proxyURL,
} }
took1 := time.Since(now) p := NewServer(cfg)
t.Logf("took %v with no latency", took1) waitForServer(t, p)
lat, rv := 50*time.Millisecond, 5*time.Millisecond // setup forward proxy
if delayTx { t.Setenv("E2E_TEST_FORWARD_PROXY_IP", proxyURL.String())
p.DelayTx(lat, rv) t.Logf("Proxy URL %s", proxyURL.String())
}
data2 := []byte("new data") donec, writec := make(chan struct{}), make(chan []byte)
writec <- data2
now = time.Now() var tp *http.Transport
if d := <-recvc; !bytes.Equal(data2, d) { var err error
close(writec) if !tlsInfo.Empty() {
t.Fatalf("expected %q, got %q", string(data2), string(d)) tp, err = transport.NewTransport(tlsInfo, 1*time.Second)
}
took2 := time.Since(now)
if delayTx {
t.Logf("took %v with latency %v+-%v", took2, lat, rv)
} else { } else {
t.Logf("took %v with no latency", took2) tp, err = transport.NewTransport(tlsInfo, 1*time.Second)
}
if err != nil {
t.Fatal(err)
}
tp.IdleConnTimeout = 100 * time.Microsecond
sendData := func(data []byte) {
send(tp, t, data, scheme, dstAddr, tlsInfo, serverIsClosed)
} }
if delayTx { return recvc, donec, writec, p, httpServer, sendData
p.UndelayTx()
if took2 < lat-rv {
close(writec)
t.Fatalf("expected took2 %v (with latency) > delay: %v", took2, lat-rv)
}
} }
func destroy(t *testing.T, writec chan []byte, donec chan struct{}, p Server, serverIsClosed bool, httpServer *http.Server) {
close(writec) close(writec)
if err := httpServer.Shutdown(context.Background()); err != nil {
t.Fatal(err)
}
select { select {
case <-donec: case <-donec:
case <-time.After(3 * time.Second): case <-time.After(3 * time.Second):
t.Fatal("took too long to write") t.Fatal("took too long to write")
} }
if !serverIsClosed {
select { select {
case <-p.Done(): case <-p.Done():
t.Fatal("unexpected done") t.Fatal("unexpected done")
@ -161,9 +144,9 @@ func testServer(t *testing.T, scheme string, secure bool, delayTx bool) {
t.Fatal("took too long to close") t.Fatal("took too long to close")
} }
} }
}
func createTLSInfo(lg *zap.Logger, secure bool) transport.TLSInfo { func createTLSInfo(lg *zap.Logger) transport.TLSInfo {
if secure {
return transport.TLSInfo{ return transport.TLSInfo{
KeyFile: "../../tests/fixtures/server.key.insecure", KeyFile: "../../tests/fixtures/server.key.insecure",
CertFile: "../../tests/fixtures/server.crt", CertFile: "../../tests/fixtures/server.crt",
@ -172,58 +155,175 @@ func createTLSInfo(lg *zap.Logger, secure bool) transport.TLSInfo {
Logger: lg, Logger: lg,
} }
} }
return transport.TLSInfo{Logger: lg}
func listen(t *testing.T, scheme, addr string, tlsInfo transport.TLSInfo) (ln net.Listener) {
var err error
if !tlsInfo.Empty() {
ln, err = transport.NewListener(addr, scheme, &tlsInfo)
} else {
ln, err = net.Listen(scheme, addr)
}
if err != nil {
t.Fatal(err)
}
return ln
} }
func TestServer_Unix_Insecure_DelayAccept(t *testing.T) { testServerDelayAccept(t, false) } func startHTTPServer(scheme, addr string, tlsInfo transport.TLSInfo, httpServer *http.Server) {
func TestServer_Unix_Secure_DelayAccept(t *testing.T) { testServerDelayAccept(t, true) } var err error
func testServerDelayAccept(t *testing.T, secure bool) { var ln net.Listener
lg := zaptest.NewLogger(t)
srcAddr, dstAddr := newUnixAddr(), newUnixAddr() ln, err = net.Listen(scheme, addr)
if err != nil {
log.Fatal(err)
}
log.Println("HTTP Server started on", addr)
if err := httpServer.ServeTLS(ln, tlsInfo.CertFile, tlsInfo.KeyFile); err != http.ErrServerClosed {
// always returns error. ErrServerClosed on graceful close
log.Fatalf(fmt.Sprintf("startHTTPServer ServeTLS(): %v", err))
}
}
func send(tp *http.Transport, t *testing.T, data []byte, scheme, addr string, tlsInfo transport.TLSInfo, serverIsClosed bool) {
defer func() { defer func() {
os.RemoveAll(srcAddr) tp.CloseIdleConnections()
os.RemoveAll(dstAddr)
}() }()
tlsInfo := createTLSInfo(lg, secure)
scheme := "unix"
ln := listen(t, scheme, dstAddr, tlsInfo)
defer ln.Close()
cfg := ServerConfig{ // If you call Dial(), you will get a Conn that you can write the byte stream directly
Logger: lg, // If you call RoundTrip(), you will get a connection managed for you, but you need to send valid HTTP request
From: url.URL{Scheme: scheme, Host: srcAddr}, dataReader := bytes.NewReader(data)
To: url.URL{Scheme: scheme, Host: dstAddr}, protocolScheme := scheme
if scheme == "tcp" {
if !tlsInfo.Empty() {
protocolScheme = "https"
} else {
panic("only https is supported")
} }
if secure { } else {
cfg.TLSInfo = tlsInfo panic("scheme not supported")
}
rawURL := url.URL{
Scheme: protocolScheme,
Host: addr,
} }
p := NewServer(cfg)
waitForServer(t, p) req, err := http.NewRequest("POST", rawURL.String(), dataReader)
if err != nil {
t.Fatal(err)
}
res, err := tp.RoundTrip(req)
if err != nil {
if strings.Contains(err.Error(), "TLS handshake timeout") {
t.Logf("TLS handshake timeout")
return
}
if serverIsClosed {
// when the proxy server is closed before sending, we will get this error message
if strings.Contains(err.Error(), "connect: connection refused") {
t.Logf("connect: connection refused")
return
}
}
panic(err)
}
defer func() {
if err := res.Body.Close(); err != nil {
panic(err)
}
}()
defer p.Close() if res.StatusCode != 200 {
t.Fatalf("status code not 200")
}
}
data := []byte("Hello World!") // Waits until a proxy is ready to serve.
// Aborts test on proxy start-up error.
//
// It blocks until either s.Ready() or s.Error() delivers; there is no timeout.
func waitForServer(t *testing.T, s Server) {
	// Attribute a start-up failure to the calling test, not to this helper.
	t.Helper()
	select {
	case <-s.Ready():
	case err := <-s.Error():
		t.Fatal(err)
	}
}
/* Unit tests */
// TestServer_TCP runs testServer with delayTx disabled.
func TestServer_TCP(t *testing.T) { testServer(t, false) }
// TestServer_TCP_DelayTx runs testServer with delayTx enabled.
func TestServer_TCP_DelayTx(t *testing.T) { testServer(t, true) }
func testServer(t *testing.T, delayTx bool) {
recvc, donec, writec, p, httpServer, sendData := prepare(t, false)
defer destroy(t, writec, donec, p, false, httpServer)
go func() {
defer close(donec)
for data := range writec {
sendData(data)
}
}()
data1 := []byte("Hello World!")
writec <- data1
now := time.Now() now := time.Now()
send(t, data, scheme, srcAddr, tlsInfo) if d := <-recvc; !bytes.Equal(data1, d) {
if d := receive(t, ln); !bytes.Equal(data, d) { t.Fatalf("expected %q, got %q", string(data1), string(d))
t.Fatalf("expected %q, got %q", string(data), string(d))
} }
took1 := time.Since(now) took1 := time.Since(now)
t.Logf("took %v with no latency", took1) t.Logf("took %v with no latency", took1)
lat, rv := 50*time.Millisecond, 5*time.Millisecond
if delayTx {
p.DelayTx(lat, rv)
}
data2 := []byte("new data")
writec <- data2
now = time.Now()
if d := <-recvc; !bytes.Equal(data2, d) {
t.Fatalf("expected %q, got %q", string(data2), string(d))
}
took2 := time.Since(now)
if delayTx {
t.Logf("took %v with latency %v+-%v", took2, lat, rv)
} else {
t.Logf("took %v with no latency", took2)
}
if delayTx {
p.UndelayTx()
if took2 < lat-rv {
close(writec)
t.Fatalf("expected took2 %v (with latency) > delay: %v", took2, lat-rv)
}
}
}
func TestServer_DelayAccept(t *testing.T) {
recvc, donec, writec, p, httpServer, sendData := prepare(t, false)
defer destroy(t, writec, donec, p, false, httpServer)
go func() {
defer close(donec)
for data := range writec {
sendData(data)
}
}()
data := []byte("Hello World!")
now := time.Now()
writec <- data
if d := <-recvc; !bytes.Equal(data, d) {
t.Fatalf("expected %q, got %q", string(data), string(d))
}
took1 := time.Since(now)
t.Logf("took %v with no latency", took1)
time.Sleep(1 * time.Second) // wait for the idle connection to timeout
lat, rv := 700*time.Millisecond, 10*time.Millisecond lat, rv := 700*time.Millisecond, 10*time.Millisecond
p.DelayAccept(lat, rv) p.DelayAccept(lat, rv)
defer p.UndelayAccept() defer p.UndelayAccept()
if err := p.ResetListener(); err != nil {
t.Fatal(err)
}
time.Sleep(200 * time.Millisecond)
now = time.Now() now = time.Now()
send(t, data, scheme, srcAddr, tlsInfo) writec <- data
if d := receive(t, ln); !bytes.Equal(data, d) { if d := <-recvc; !bytes.Equal(data, d) {
t.Fatalf("expected %q, got %q", string(data), string(d)) t.Fatalf("expected %q, got %q", string(data), string(d))
} }
took2 := time.Since(now) took2 := time.Since(now)
@ -235,36 +335,22 @@ func testServerDelayAccept(t *testing.T, secure bool) {
} }
func TestServer_PauseTx(t *testing.T) { func TestServer_PauseTx(t *testing.T) {
lg := zaptest.NewLogger(t) recvc, donec, writec, p, httpServer, sendData := prepare(t, false)
scheme := "unix" defer destroy(t, writec, donec, p, false, httpServer)
srcAddr, dstAddr := newUnixAddr(), newUnixAddr() // the sendData function must be in a goroutine
defer func() { // otherwise, the pauseTx will cause the sendData to block
os.RemoveAll(srcAddr) go func() {
os.RemoveAll(dstAddr) defer close(donec)
for data := range writec {
sendData(data)
}
}() }()
ln := listen(t, scheme, dstAddr, transport.TLSInfo{})
defer ln.Close()
p := NewServer(ServerConfig{ data := []byte("Hello World!")
Logger: lg,
From: url.URL{Scheme: scheme, Host: srcAddr},
To: url.URL{Scheme: scheme, Host: dstAddr},
})
waitForServer(t, p)
defer p.Close()
p.PauseTx() p.PauseTx()
data := []byte("Hello World!") writec <- data
send(t, data, scheme, srcAddr, transport.TLSInfo{})
recvc := make(chan []byte, 1)
go func() {
recvc <- receive(t, ln)
}()
select { select {
case d := <-recvc: case d := <-recvc:
t.Fatalf("received unexpected data %q during pause", string(d)) t.Fatalf("received unexpected data %q during pause", string(d))
@ -283,114 +369,32 @@ func TestServer_PauseTx(t *testing.T) {
} }
} }
// TestServer_ModifyTx_corrupt installs a ModifyTx hook that alters one byte
// and expects the destination to receive data that differs from what was
// sent; after UnmodifyTx the destination must receive the payload unchanged.
func TestServer_ModifyTx_corrupt(t *testing.T) {
	logger := zaptest.NewLogger(t)
	const scheme = "unix"
	src, dst := newUnixAddr(), newUnixAddr()
	defer func() {
		os.RemoveAll(src)
		os.RemoveAll(dst)
	}()

	dstLn := listen(t, scheme, dst, transport.TLSInfo{})
	defer dstLn.Close()

	cfg := ServerConfig{
		Logger: logger,
		From:   url.URL{Scheme: scheme, Host: src},
		To:     url.URL{Scheme: scheme, Host: dst},
	}
	p := NewServer(cfg)
	waitForServer(t, p)
	defer p.Close()

	// The hook increments the middle byte of the buffer it is handed.
	p.ModifyTx(func(b []byte) []byte {
		b[len(b)/2]++
		return b
	})

	payload := []byte("Hello World!")
	send(t, payload, scheme, src, transport.TLSInfo{})
	if got := receive(t, dstLn); bytes.Equal(got, payload) {
		t.Fatalf("expected corrupted data, got %q", string(got))
	}

	p.UnmodifyTx()
	send(t, payload, scheme, src, transport.TLSInfo{})
	if got := receive(t, dstLn); !bytes.Equal(got, payload) {
		t.Fatalf("expected uncorrupted data, got %q", string(got))
	}
}
// TestServer_ModifyTx_packet_loss installs a ModifyTx hook that truncates each
// buffer to its first half and expects the destination to receive data that
// differs from what was sent; after UnmodifyTx the payload must arrive intact.
func TestServer_ModifyTx_packet_loss(t *testing.T) {
	logger := zaptest.NewLogger(t)
	const scheme = "unix"
	src, dst := newUnixAddr(), newUnixAddr()
	defer func() {
		os.RemoveAll(src)
		os.RemoveAll(dst)
	}()

	dstLn := listen(t, scheme, dst, transport.TLSInfo{})
	defer dstLn.Close()

	cfg := ServerConfig{
		Logger: logger,
		From:   url.URL{Scheme: scheme, Host: src},
		To:     url.URL{Scheme: scheme, Host: dst},
	}
	p := NewServer(cfg)
	waitForServer(t, p)
	defer p.Close()

	// 50% packet loss: keep only the first half, with capacity clamped to it.
	p.ModifyTx(func(b []byte) []byte {
		keep := len(b) / 2
		return b[:keep:keep]
	})

	payload := []byte("Hello World!")
	send(t, payload, scheme, src, transport.TLSInfo{})
	if got := receive(t, dstLn); bytes.Equal(got, payload) {
		t.Fatalf("expected corrupted data, got %q", string(got))
	}

	p.UnmodifyTx()
	send(t, payload, scheme, src, transport.TLSInfo{})
	if got := receive(t, dstLn); !bytes.Equal(got, payload) {
		t.Fatalf("expected uncorrupted data, got %q", string(got))
	}
}
func TestServer_BlackholeTx(t *testing.T) { func TestServer_BlackholeTx(t *testing.T) {
lg := zaptest.NewLogger(t) recvc, donec, writec, p, httpServer, sendData := prepare(t, false)
scheme := "unix" defer destroy(t, writec, donec, p, false, httpServer)
srcAddr, dstAddr := newUnixAddr(), newUnixAddr() // the sendData function must be in a goroutine
defer func() { // otherwise, the pauseTx will cause the sendData to block
os.RemoveAll(srcAddr) go func() {
os.RemoveAll(dstAddr) defer close(donec)
for data := range writec {
sendData(data)
}
}() }()
ln := listen(t, scheme, dstAddr, transport.TLSInfo{})
defer ln.Close()
p := NewServer(ServerConfig{ // before enabling blacklhole
Logger: lg, data := []byte("Hello World!")
From: url.URL{Scheme: scheme, Host: srcAddr}, writec <- data
To: url.URL{Scheme: scheme, Host: dstAddr}, if d := <-recvc; !bytes.Equal(data, d) {
}) t.Fatalf("expected %q, got %q", string(data), string(d))
}
waitForServer(t, p)
defer p.Close()
// enable blackhole
// note that the transport is set to use 10s for TLSHandshakeTimeout, so
// this test will require at least 10s to execute, since send() is a
// blocking call thus we need to wait for ssl handshake to timeout
p.BlackholeTx() p.BlackholeTx()
data := []byte("Hello World!") writec <- data
send(t, data, scheme, srcAddr, transport.TLSInfo{})
recvc := make(chan []byte, 1)
go func() {
recvc <- receive(t, ln)
}()
select { select {
case d := <-recvc: case d := <-recvc:
t.Fatalf("unexpected data receive %q during blackhole", string(d)) t.Fatalf("unexpected data receive %q during blackhole", string(d))
@ -399,10 +403,12 @@ func TestServer_BlackholeTx(t *testing.T) {
p.UnblackholeTx() p.UnblackholeTx()
// disable blackhole
// TODO: figure out why HTTPS won't attempt to reconnect when the blackhole is disabled
// expect different data, old data dropped // expect different data, old data dropped
data[0]++ data[0]++
send(t, data, scheme, srcAddr, transport.TLSInfo{}) writec <- data
select { select {
case d := <-recvc: case d := <-recvc:
if !bytes.Equal(data, d) { if !bytes.Equal(data, d) {
@ -414,286 +420,31 @@ func TestServer_BlackholeTx(t *testing.T) {
} }
func TestServer_Shutdown(t *testing.T) { func TestServer_Shutdown(t *testing.T) {
lg := zaptest.NewLogger(t) recvc, donec, writec, p, httpServer, sendData := prepare(t, true)
scheme := "unix" defer destroy(t, writec, donec, p, true, httpServer)
srcAddr, dstAddr := newUnixAddr(), newUnixAddr() go func() {
defer func() { defer close(donec)
os.RemoveAll(srcAddr) for data := range writec {
os.RemoveAll(dstAddr) sendData(data)
}
}() }()
ln := listen(t, scheme, dstAddr, transport.TLSInfo{})
defer ln.Close()
p := NewServer(ServerConfig{
Logger: lg,
From: url.URL{Scheme: scheme, Host: srcAddr},
To: url.URL{Scheme: scheme, Host: dstAddr},
})
waitForServer(t, p)
defer p.Close()
s, _ := p.(*server) s, _ := p.(*server)
s.listener.Close() if err := s.Close(); err != nil {
t.Fatal(err)
}
p = nil
time.Sleep(200 * time.Millisecond) time.Sleep(200 * time.Millisecond)
data := []byte("Hello World!") data := []byte("Hello World!")
send(t, data, scheme, srcAddr, transport.TLSInfo{}) sendData(data)
if d := receive(t, ln); !bytes.Equal(d, data) {
t.Fatalf("expected %q, got %q", string(data), string(d))
}
}
// TestServer_ShutdownListener closes the destination listener while the proxy
// is running, re-creates it on the same address after a short pause, and then
// expects data sent through the proxy to arrive at the new listener.
func TestServer_ShutdownListener(t *testing.T) {
	logger := zaptest.NewLogger(t)
	const scheme = "unix"
	src, dst := newUnixAddr(), newUnixAddr()
	defer func() {
		os.RemoveAll(src)
		os.RemoveAll(dst)
	}()

	firstLn := listen(t, scheme, dst, transport.TLSInfo{})
	defer firstLn.Close()

	cfg := ServerConfig{
		Logger: logger,
		From:   url.URL{Scheme: scheme, Host: src},
		To:     url.URL{Scheme: scheme, Host: dst},
	}
	p := NewServer(cfg)
	waitForServer(t, p)
	defer p.Close()

	// shut down destination
	firstLn.Close()
	time.Sleep(200 * time.Millisecond)

	// Bring the destination back on the same address.
	secondLn := listen(t, scheme, dst, transport.TLSInfo{})
	defer secondLn.Close()

	payload := []byte("Hello World!")
	send(t, payload, scheme, src, transport.TLSInfo{})
	if got := receive(t, secondLn); !bytes.Equal(got, payload) {
		t.Fatalf("expected %q, got %q", string(payload), string(got))
	}
}
// TestServerHTTP_Insecure_DelayTx runs testServerHTTP with secure=false, delayTx=true.
func TestServerHTTP_Insecure_DelayTx(t *testing.T) { testServerHTTP(t, false, true) }
// TestServerHTTP_Secure_DelayTx runs testServerHTTP with secure=true, delayTx=true.
func TestServerHTTP_Secure_DelayTx(t *testing.T) { testServerHTTP(t, true, true) }
// TestServerHTTP_Insecure_DelayRx runs testServerHTTP with secure=false, delayTx=false.
func TestServerHTTP_Insecure_DelayRx(t *testing.T) { testServerHTTP(t, false, false) }
// TestServerHTTP_Secure_DelayRx runs testServerHTTP with secure=true, delayTx=false.
func TestServerHTTP_Secure_DelayRx(t *testing.T) { testServerHTTP(t, true, false) }
// testServerHTTP starts an HTTP (or HTTPS when secure) server with a /hello
// handler that echoes the request body as `%q(confirmed)`, puts a proxy in
// front of it, and posts the same payload twice through the proxy: once with
// no injected latency and once after enabling DelayTx (delayTx == true) or
// DelayRx (delayTx == false). Both responses must match the expected echo and
// the first round trip must not take longer than the delayed one.
func testServerHTTP(t *testing.T, secure, delayTx bool) {
	lg := zaptest.NewLogger(t)
	scheme := "tcp"

	// Grab two free localhost ports, then release them for the proxy and the
	// HTTP server to bind.
	ln1, ln2 := listen(t, scheme, "localhost:0", transport.TLSInfo{}), listen(t, scheme, "localhost:0", transport.TLSInfo{})
	srcAddr, dstAddr := ln1.Addr().String(), ln2.Addr().String()
	ln1.Close()
	ln2.Close()

	mux := http.NewServeMux()
	mux.HandleFunc("/hello", func(w http.ResponseWriter, req *http.Request) {
		d, err := io.ReadAll(req.Body)
		req.Body.Close()
		if err != nil {
			// Handlers run on a server goroutine, where t.Fatal/FailNow must
			// not be called; record the failure and answer with an error.
			t.Error(err)
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		if _, err = w.Write([]byte(fmt.Sprintf("%q(confirmed)", string(d)))); err != nil {
			t.Error(err)
		}
	})

	tlsInfo := createTLSInfo(lg, secure)
	// NOTE(review): the generated server config is discarded, so tlsConfig
	// stays nil; in secure mode the certificate files passed to
	// ListenAndServeTLS below are what the server is given. Confirm intended.
	var tlsConfig *tls.Config
	if secure {
		_, err := tlsInfo.ServerConfig()
		if err != nil {
			t.Fatal(err)
		}
	}
	srv := &http.Server{
		Addr:      dstAddr,
		Handler:   mux,
		TLSConfig: tlsConfig,
		ErrorLog:  log.New(io.Discard, "net/http", 0),
	}

	donec := make(chan struct{})
	defer func() {
		srv.Close()
		<-donec
	}()
	go func() {
		// Signal the deferred cleanup above once serving has stopped.
		defer close(donec)
		if !secure {
			srv.ListenAndServe()
		} else {
			srv.ListenAndServeTLS(tlsInfo.CertFile, tlsInfo.KeyFile)
		}
	}()
	time.Sleep(200 * time.Millisecond)

	cfg := ServerConfig{
		Logger: lg,
		From:   url.URL{Scheme: scheme, Host: srcAddr},
		To:     url.URL{Scheme: scheme, Host: dstAddr},
	}
	if secure {
		cfg.TLSInfo = tlsInfo
	}
	p := NewServer(cfg)

	waitForServer(t, p)

	defer func() {
		lg.Info("closing Proxy server...")
		p.Close()
		lg.Info("closed Proxy server.")
	}()

	data := "Hello World!"

	var resp *http.Response
	var err error
	now := time.Now()
	if secure {
		tp, terr := transport.NewTransport(tlsInfo, 3*time.Second)
		if !assert.NoError(t, terr) {
			t.FailNow()
		}
		cli := &http.Client{Transport: tp}
		resp, err = cli.Post("https://"+srcAddr+"/hello", "", strings.NewReader(data))
		defer cli.CloseIdleConnections()
		defer tp.CloseIdleConnections()
	} else {
		resp, err = http.Post("http://"+srcAddr+"/hello", "", strings.NewReader(data))
		defer http.DefaultClient.CloseIdleConnections()
	}
	// Stop here on failure: resp is nil when the request errored, so reading
	// resp.Body below would panic.
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	d, err := io.ReadAll(resp.Body)
	resp.Body.Close()
	if err != nil {
		t.Fatal(err)
	}
	took1 := time.Since(now)
	t.Logf("took %v with no latency", took1)

	rs1 := string(d)
	exp := fmt.Sprintf("%q(confirmed)", data)
	if rs1 != exp {
		t.Fatalf("got %q, expected %q", rs1, exp)
	}

	lat, rv := 100*time.Millisecond, 10*time.Millisecond
	if delayTx {
		p.DelayTx(lat, rv)
		defer p.UndelayTx()
	} else {
		p.DelayRx(lat, rv)
		defer p.UndelayRx()
	}

	now = time.Now()
	if secure {
		tp, terr := transport.NewTransport(tlsInfo, 3*time.Second)
		if terr != nil {
			t.Fatal(terr)
		}
		cli := &http.Client{Transport: tp}
		resp, err = cli.Post("https://"+srcAddr+"/hello", "", strings.NewReader(data))
		defer cli.CloseIdleConnections()
		defer tp.CloseIdleConnections()
	} else {
		resp, err = http.Post("http://"+srcAddr+"/hello", "", strings.NewReader(data))
		defer http.DefaultClient.CloseIdleConnections()
	}
	if err != nil {
		t.Fatal(err)
	}
	d, err = io.ReadAll(resp.Body)
	resp.Body.Close()
	if err != nil {
		t.Fatal(err)
	}
	took2 := time.Since(now)
	t.Logf("took %v with latency %v±%v", took2, lat, rv)

	rs2 := string(d)
	if rs2 != exp {
		t.Fatalf("got %q, expected %q", rs2, exp)
	}
	if took1 > took2 {
		t.Fatalf("expected took1 %v < took2 %v", took1, took2)
	}
}
// newUnixAddr builds a ".unix-conn" path from the current UnixNano timestamp
// and a random number below 35000 (both in upper-case hex), removes anything
// already present at that path, and returns it.
func newUnixAddr() string {
	addr := fmt.Sprintf("%X%X.unix-conn", time.Now().UnixNano(), rand.Intn(35000))
	// Best-effort cleanup of a stale file; the result is ignored.
	os.RemoveAll(addr)
	return addr
}
// listen opens a listener on addr for the given scheme and aborts the test if
// that fails. A non-empty tlsInfo yields a listener from transport.NewListener;
// otherwise a plain net.Listen listener is returned.
func listen(t *testing.T, scheme, addr string, tlsInfo transport.TLSInfo) (ln net.Listener) {
	var err error
	switch {
	case tlsInfo.Empty():
		ln, err = net.Listen(scheme, addr)
	default:
		ln, err = transport.NewListener(addr, scheme, &tlsInfo)
	}
	if err != nil {
		t.Fatal(err)
	}
	return ln
}
// send dials addr over scheme, writes data to the connection and closes it.
// With a non-empty tlsInfo the connection is dialed through a transport built
// by transport.NewTransport; otherwise net.Dial is used. Any failure while
// building the transport, dialing, writing or closing aborts the test.
func send(t *testing.T, data []byte, scheme, addr string, tlsInfo transport.TLSInfo) {
	dial := func() (net.Conn, error) {
		if tlsInfo.Empty() {
			return net.Dial(scheme, addr)
		}
		tp, terr := transport.NewTransport(tlsInfo, 3*time.Second)
		if terr != nil {
			t.Fatal(terr)
		}
		return tp.DialContext(context.Background(), scheme, addr)
	}

	conn, err := dial()
	if err != nil {
		t.Fatal(err)
	}
	if _, err = conn.Write(data); err != nil {
		t.Fatal(err)
	}
	if err = conn.Close(); err != nil {
		t.Fatal(err)
	}
}
// receive accepts connections from ln until one of them delivers at least one
// byte, reading each connection until it ends, and returns everything read.
// Accept or read failures abort the test.
//
// Every accepted connection is closed once it has been drained, so repeated
// calls do not leak connections.
func receive(t *testing.T, ln net.Listener) (data []byte) {
	buf := bytes.NewBuffer(make([]byte, 0, 1024))
	for {
		in, err := ln.Accept()
		if err != nil {
			t.Fatal(err)
		}
		n, err := buf.ReadFrom(in)
		// Best-effort close: the stream has already been read to its end,
		// so a close error carries nothing the test cares about.
		_ = in.Close()
		if err != nil {
			t.Fatal(err)
		}
		if n > 0 {
			break
		}
	}
	return buf.Bytes()
}
// Waits until a proxy is ready to serve.
// Aborts test on proxy start-up error.
func waitForServer(t *testing.T, s Server) {
select { select {
case <-s.Ready(): case d := <-recvc:
case err := <-s.Error(): if bytes.Equal(data, d) {
t.Fatal(err) t.Fatalf("expected nothing, got %q", string(d))
}
case <-time.After(2 * time.Second):
t.Log("nothing was received, proxy server seems to be closed so no traffic is forwarded")
} }
} }

105
tests/e2e/blackhole_test.go Normal file
View File

@ -0,0 +1,105 @@
// Copyright 2022 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build !cluster_proxy
package e2e
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.etcd.io/etcd/tests/v3/framework/e2e"
)
// TestBlackholeByMockingPartitionLeader runs the blackhole scenario on a
// 3-member cluster with partitionLeader=true, i.e. the current leader is the
// member that gets blackholed.
func TestBlackholeByMockingPartitionLeader(t *testing.T) {
blackholeTestByMockingPartition(t, 3, true)
}
// TestBlackholeByMockingPartitionFollower runs the blackhole scenario on a
// 3-member cluster with partitionLeader=false, i.e. the member after the
// leader (by index, wrapping around) is the one that gets blackholed.
func TestBlackholeByMockingPartitionFollower(t *testing.T) {
blackholeTestByMockingPartition(t, 3, false)
}
// blackholeTestByMockingPartition starts a clusterSize-member cluster with
// peer TLS and peer proxies enabled, blackholes one member (the leader when
// partitionLeader is true, otherwise the member at the next index), writes
// keys through a leader elected among the remaining members, and checks that
// the blackholed member misses those writes. It then lifts the blackhole and
// checks that the member reaches the same revision as the leader.
//
// Errors returned by BlackholePeer and UnblackholePeer fail the test at once
// instead of being dropped.
func blackholeTestByMockingPartition(t *testing.T, clusterSize int, partitionLeader bool) {
	e2e.BeforeTest(t)

	t.Logf("Create an etcd cluster with %d member\n", clusterSize)
	epc, err := e2e.NewEtcdProcessCluster(context.TODO(), t,
		e2e.WithClusterSize(clusterSize),
		e2e.WithSnapshotCount(10),
		e2e.WithSnapshotCatchUpEntries(10),
		e2e.WithIsPeerTLS(true),
		e2e.WithPeerProxy(true),
	)
	require.NoError(t, err, "failed to start etcd cluster: %v", err)
	defer func() {
		require.NoError(t, epc.Close(), "failed to close etcd cluster")
	}()

	leaderID := epc.WaitLeader(t)
	mockPartitionNodeIndex := leaderID
	if !partitionLeader {
		// Pick the member right after the leader, wrapping around.
		mockPartitionNodeIndex = (leaderID + 1) % (clusterSize)
	}
	partitionedMember := epc.Procs[mockPartitionNodeIndex]

	// Mock partition
	t.Logf("Blackholing traffic from and to member %q", partitionedMember.Config().Name)
	require.NoError(t, epc.BlackholePeer(partitionedMember), "failed to blackhole peer")

	t.Logf("Wait 1s for any open connections to expire")
	time.Sleep(1 * time.Second)

	t.Logf("Wait for new leader election with remaining members")
	leaderEPC := epc.Procs[waitLeader(t, epc, mockPartitionNodeIndex)]

	t.Log("Writing 20 keys to the cluster (more than SnapshotCount entries to trigger at least a snapshot.)")
	writeKVs(t, leaderEPC.Etcdctl(), 0, 20)
	e2e.AssertProcessLogs(t, leaderEPC, "saved snapshot")

	t.Log("Verifying the partitionedMember is missing new writes")
	assertRevision(t, leaderEPC, 21)
	assertRevision(t, partitionedMember, 1)

	// Wait for some time to restore the network
	time.Sleep(1 * time.Second)
	t.Logf("Unblackholing traffic from and to member %q", partitionedMember.Config().Name)
	require.NoError(t, epc.UnblackholePeer(partitionedMember), "failed to unblackhole peer")

	leaderEPC = epc.Procs[epc.WaitLeader(t)]
	// Give the previously isolated member a moment to catch up.
	time.Sleep(1 * time.Second)
	assertRevision(t, leaderEPC, 21)
	assertRevision(t, partitionedMember, 21)
}
// waitLeader waits up to 30 seconds for a leader among every process of epc
// except the one at index excludeNode, and returns the value reported by
// epc.WaitMembersForLeader.
//
// NOTE(review): callers index epc.Procs with the result. If
// WaitMembersForLeader returns a position within the filtered slice it was
// given rather than within epc.Procs, that index is shifted for members
// located after excludeNode — confirm against WaitMembersForLeader.
func waitLeader(t testing.TB, epc *e2e.EtcdProcessCluster, excludeNode int) int {
	var candidates []e2e.EtcdProcess
	for idx, proc := range epc.Procs {
		if idx != excludeNode {
			candidates = append(candidates, proc)
		}
	}

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	return epc.WaitMembersForLeader(ctx, t, candidates)
}
// assertRevision queries member's status through its etcdctl client and
// asserts that the header revision of the first response equals
// expectedRevision. A status error or an empty response list fails the test
// immediately rather than panicking on the index below.
func assertRevision(t testing.TB, member e2e.EtcdProcess, expectedRevision int64) {
	responses, err := member.Etcdctl().Status(context.TODO())
	require.NoError(t, err)
	require.NotEmpty(t, responses, "status returned no responses")
	assert.Equal(t, expectedRevision, responses[0].Header.Revision, "revision mismatch")
}

View File

@ -384,10 +384,10 @@ func triggerSlowApply(ctx context.Context, t *testing.T, clus *e2e.EtcdProcessCl
func blackhole(_ context.Context, t *testing.T, clus *e2e.EtcdProcessCluster, _ time.Duration) { func blackhole(_ context.Context, t *testing.T, clus *e2e.EtcdProcessCluster, _ time.Duration) {
member := clus.Procs[0] member := clus.Procs[0]
proxy := member.PeerProxy() forwardProxy := member.PeerForwardProxy()
t.Logf("Blackholing traffic from and to member %q", member.Config().Name) t.Logf("Blackholing traffic from and to member %q", member.Config().Name)
proxy.BlackholeTx() forwardProxy.BlackholeTx()
proxy.BlackholeRx() forwardProxy.BlackholeRx()
} }
func triggerRaftLoopDeadLock(ctx context.Context, t *testing.T, clus *e2e.EtcdProcessCluster, duration time.Duration) { func triggerRaftLoopDeadLock(ctx context.Context, t *testing.T, clus *e2e.EtcdProcessCluster, duration time.Duration) {

View File

@ -513,10 +513,10 @@ func (cfg *EtcdProcessClusterConfig) EtcdServerProcessConfig(tb testing.TB, i in
var curl string var curl string
port := cfg.BasePort + 5*i port := cfg.BasePort + 5*i
clientPort := port clientPort := port
peerPort := port + 1 peerPort := port + 1 // the port that the peer actually listens on
metricsPort := port + 2 metricsPort := port + 2
peer2Port := port + 3 clientHTTPPort := port + 3
clientHTTPPort := port + 4 forwardProxyPort := port + 4
if cfg.Client.ConnectionType == ClientTLSAndNonTLS { if cfg.Client.ConnectionType == ClientTLSAndNonTLS {
curl = clientURL(cfg.ClientScheme(), clientPort, ClientNonTLS) curl = clientURL(cfg.ClientScheme(), clientPort, ClientNonTLS)
@ -528,17 +528,23 @@ func (cfg *EtcdProcessClusterConfig) EtcdServerProcessConfig(tb testing.TB, i in
peerListenURL := url.URL{Scheme: cfg.PeerScheme(), Host: fmt.Sprintf("localhost:%d", peerPort)} peerListenURL := url.URL{Scheme: cfg.PeerScheme(), Host: fmt.Sprintf("localhost:%d", peerPort)}
peerAdvertiseURL := url.URL{Scheme: cfg.PeerScheme(), Host: fmt.Sprintf("localhost:%d", peerPort)} peerAdvertiseURL := url.URL{Scheme: cfg.PeerScheme(), Host: fmt.Sprintf("localhost:%d", peerPort)}
var proxyCfg *proxy.ServerConfig var forwardProxyCfg *proxy.ServerConfig
if cfg.PeerProxy { if cfg.PeerProxy {
if !cfg.IsPeerTLS { if !cfg.IsPeerTLS {
panic("Can't use peer proxy without peer TLS as it can result in malformed packets") panic("Can't use peer proxy without peer TLS as it can result in malformed packets")
} }
peerAdvertiseURL.Host = fmt.Sprintf("localhost:%d", peer2Port)
proxyCfg = &proxy.ServerConfig{ // setup forward proxy
forwardProxyURL := url.URL{Scheme: cfg.PeerScheme(), Host: fmt.Sprintf("localhost:%d", forwardProxyPort)}
forwardProxyCfg = &proxy.ServerConfig{
Logger: zap.NewNop(), Logger: zap.NewNop(),
To: peerListenURL, Listen: forwardProxyURL,
From: peerAdvertiseURL,
} }
if cfg.EnvVars == nil {
cfg.EnvVars = make(map[string]string)
}
cfg.EnvVars["E2E_TEST_FORWARD_PROXY_IP"] = fmt.Sprintf("http://127.0.0.1:%d", forwardProxyPort)
} }
name := fmt.Sprintf("%s-test-%d", testNameCleanRegex.ReplaceAllString(tb.Name(), ""), i) name := fmt.Sprintf("%s-test-%d", testNameCleanRegex.ReplaceAllString(tb.Name(), ""), i)
@ -660,7 +666,7 @@ func (cfg *EtcdProcessClusterConfig) EtcdServerProcessConfig(tb testing.TB, i in
InitialToken: cfg.ServerConfig.InitialClusterToken, InitialToken: cfg.ServerConfig.InitialClusterToken,
GoFailPort: gofailPort, GoFailPort: gofailPort,
GoFailClientTimeout: cfg.GoFailClientTimeout, GoFailClientTimeout: cfg.GoFailClientTimeout,
Proxy: proxyCfg, ForwardProxy: forwardProxyCfg,
LazyFSEnabled: cfg.LazyFSEnabled, LazyFSEnabled: cfg.LazyFSEnabled,
} }
} }
@ -910,6 +916,38 @@ func (epc *EtcdProcessCluster) Restart(ctx context.Context) error {
return epc.start(func(ep EtcdProcess) error { return ep.Restart(ctx) }) return epc.start(func(ep EtcdProcess) error { return ep.Restart(ctx) })
} }
// BlackholePeer blackholes Rx and Tx on blackholePeer's own forward proxy,
// then, on the forward proxy of every other process in the cluster (matched
// by config name), blackholes Rx and Tx for blackholePeer's peer URL.
// It always returns nil.
func (epc *EtcdProcessCluster) BlackholePeer(blackholePeer EtcdProcess) error {
	own := blackholePeer.PeerForwardProxy()
	own.BlackholeRx()
	own.BlackholeTx()

	for _, member := range epc.Procs {
		if member.Config().Name != blackholePeer.Config().Name {
			fp := member.PeerForwardProxy()
			fp.BlackholePeerRx(blackholePeer.Config().PeerURL)
			fp.BlackholePeerTx(blackholePeer.Config().PeerURL)
		}
	}
	return nil
}
// UnblackholePeer lifts the Rx and Tx blackhole on blackholePeer's own forward
// proxy, then, on the forward proxy of every other process in the cluster
// (matched by config name), lifts the Rx and Tx blackhole for blackholePeer's
// peer URL. It always returns nil.
func (epc *EtcdProcessCluster) UnblackholePeer(blackholePeer EtcdProcess) error {
	own := blackholePeer.PeerForwardProxy()
	own.UnblackholeRx()
	own.UnblackholeTx()

	for _, member := range epc.Procs {
		if member.Config().Name != blackholePeer.Config().Name {
			fp := member.PeerForwardProxy()
			fp.UnblackholePeerRx(blackholePeer.Config().PeerURL)
			fp.UnblackholePeerTx(blackholePeer.Config().PeerURL)
		}
	}
	return nil
}
func (epc *EtcdProcessCluster) start(f func(ep EtcdProcess) error) error { func (epc *EtcdProcessCluster) start(f func(ep EtcdProcess) error) error {
readyC := make(chan error, len(epc.Procs)) readyC := make(chan error, len(epc.Procs))
for i := range epc.Procs { for i := range epc.Procs {

View File

@ -55,7 +55,7 @@ type EtcdProcess interface {
Stop() error Stop() error
Close() error Close() error
Config() *EtcdServerProcessConfig Config() *EtcdServerProcessConfig
PeerProxy() proxy.Server PeerForwardProxy() proxy.Server
Failpoints() *BinaryFailpoints Failpoints() *BinaryFailpoints
LazyFS() *LazyFS LazyFS() *LazyFS
Logs() LogsExpect Logs() LogsExpect
@ -71,7 +71,7 @@ type LogsExpect interface {
type EtcdServerProcess struct { type EtcdServerProcess struct {
cfg *EtcdServerProcessConfig cfg *EtcdServerProcessConfig
proc *expect.ExpectProcess proc *expect.ExpectProcess
proxy proxy.Server forwardProxy proxy.Server
lazyfs *LazyFS lazyfs *LazyFS
failpoints *BinaryFailpoints failpoints *BinaryFailpoints
donec chan struct{} // closed when Interact() terminates donec chan struct{} // closed when Interact() terminates
@ -101,7 +101,7 @@ type EtcdServerProcessConfig struct {
GoFailClientTimeout time.Duration GoFailClientTimeout time.Duration
LazyFSEnabled bool LazyFSEnabled bool
Proxy *proxy.ServerConfig ForwardProxy *proxy.ServerConfig
} }
func NewEtcdServerProcess(t testing.TB, cfg *EtcdServerProcessConfig) (*EtcdServerProcess, error) { func NewEtcdServerProcess(t testing.TB, cfg *EtcdServerProcessConfig) (*EtcdServerProcess, error) {
@ -151,12 +151,13 @@ func (ep *EtcdServerProcess) Start(ctx context.Context) error {
if ep.proc != nil { if ep.proc != nil {
panic("already started") panic("already started")
} }
if ep.cfg.Proxy != nil && ep.proxy == nil {
ep.cfg.lg.Info("starting proxy...", zap.String("name", ep.cfg.Name), zap.String("from", ep.cfg.Proxy.From.String()), zap.String("to", ep.cfg.Proxy.To.String())) if ep.cfg.ForwardProxy != nil && ep.forwardProxy == nil {
ep.proxy = proxy.NewServer(*ep.cfg.Proxy) ep.cfg.lg.Info("starting forward proxy...", zap.String("name", ep.cfg.Name), zap.String("listen on", ep.cfg.ForwardProxy.Listen.String()))
ep.forwardProxy = proxy.NewServer(*ep.cfg.ForwardProxy)
select { select {
case <-ep.proxy.Ready(): case <-ep.forwardProxy.Ready():
case err := <-ep.proxy.Error(): case err := <-ep.forwardProxy.Error():
return err return err
} }
} }
@ -221,10 +222,10 @@ func (ep *EtcdServerProcess) Stop() (err error) {
} }
} }
ep.cfg.lg.Info("stopped server.", zap.String("name", ep.cfg.Name)) ep.cfg.lg.Info("stopped server.", zap.String("name", ep.cfg.Name))
if ep.proxy != nil { if ep.forwardProxy != nil {
ep.cfg.lg.Info("stopping proxy...", zap.String("name", ep.cfg.Name)) ep.cfg.lg.Info("stopping forward proxy...", zap.String("name", ep.cfg.Name))
err = ep.proxy.Close() err = ep.forwardProxy.Close()
ep.proxy = nil ep.forwardProxy = nil
if err != nil { if err != nil {
return err return err
} }
@ -330,8 +331,8 @@ func AssertProcessLogs(t *testing.T, ep EtcdProcess, expectLog string) {
} }
} }
func (ep *EtcdServerProcess) PeerProxy() proxy.Server { func (ep *EtcdServerProcess) PeerForwardProxy() proxy.Server {
return ep.proxy return ep.forwardProxy
} }
func (ep *EtcdServerProcess) LazyFS() *LazyFS { func (ep *EtcdServerProcess) LazyFS() *LazyFS {

View File

@ -63,23 +63,17 @@ func (tb triggerBlackhole) Available(config e2e.EtcdProcessClusterConfig, proces
if tb.waitTillSnapshot && (entriesToGuaranteeSnapshot(config) > 200 || !e2e.CouldSetSnapshotCatchupEntries(process.Config().ExecPath)) { if tb.waitTillSnapshot && (entriesToGuaranteeSnapshot(config) > 200 || !e2e.CouldSetSnapshotCatchupEntries(process.Config().ExecPath)) {
return false return false
} }
return config.ClusterSize > 1 && process.PeerProxy() != nil return config.ClusterSize > 1 && process.PeerForwardProxy() != nil
} }
func Blackhole(ctx context.Context, t *testing.T, member e2e.EtcdProcess, clus *e2e.EtcdProcessCluster, shouldWaitTillSnapshot bool) error { func Blackhole(ctx context.Context, t *testing.T, member e2e.EtcdProcess, clus *e2e.EtcdProcessCluster, shouldWaitTillSnapshot bool) error {
proxy := member.PeerProxy()
// Blackholing will cause peers to not be able to use streamWriters registered with member
// but peer traffic is still possible because member has 'pipeline' with peers
// TODO: find a way to stop all traffic
t.Logf("Blackholing traffic from and to member %q", member.Config().Name) t.Logf("Blackholing traffic from and to member %q", member.Config().Name)
proxy.BlackholeTx() clus.BlackholePeer(member)
proxy.BlackholeRx()
defer func() { defer func() {
t.Logf("Traffic restored from and to member %q", member.Config().Name) t.Logf("Traffic restored from and to member %q", member.Config().Name)
proxy.UnblackholeTx() clus.UnblackholePeer(member)
proxy.UnblackholeRx()
}() }()
if shouldWaitTillSnapshot { if shouldWaitTillSnapshot {
return waitTillSnapshot(ctx, t, clus, member) return waitTillSnapshot(ctx, t, clus, member)
} }
@ -164,15 +158,15 @@ type delayPeerNetworkFailpoint struct {
func (f delayPeerNetworkFailpoint) Inject(ctx context.Context, t *testing.T, lg *zap.Logger, clus *e2e.EtcdProcessCluster, baseTime time.Time, ids identity.Provider) ([]report.ClientReport, error) { func (f delayPeerNetworkFailpoint) Inject(ctx context.Context, t *testing.T, lg *zap.Logger, clus *e2e.EtcdProcessCluster, baseTime time.Time, ids identity.Provider) ([]report.ClientReport, error) {
member := clus.Procs[rand.Int()%len(clus.Procs)] member := clus.Procs[rand.Int()%len(clus.Procs)]
proxy := member.PeerProxy() forwardProxy := member.PeerForwardProxy()
proxy.DelayRx(f.baseLatency, f.randomizedLatency) forwardProxy.DelayRx(f.baseLatency, f.randomizedLatency)
proxy.DelayTx(f.baseLatency, f.randomizedLatency) forwardProxy.DelayTx(f.baseLatency, f.randomizedLatency)
lg.Info("Delaying traffic from and to member", zap.String("member", member.Config().Name), zap.Duration("baseLatency", f.baseLatency), zap.Duration("randomizedLatency", f.randomizedLatency)) lg.Info("Delaying traffic from and to member", zap.String("member", member.Config().Name), zap.Duration("baseLatency", f.baseLatency), zap.Duration("randomizedLatency", f.randomizedLatency))
time.Sleep(f.duration) time.Sleep(f.duration)
lg.Info("Traffic delay removed", zap.String("member", member.Config().Name)) lg.Info("Traffic delay removed", zap.String("member", member.Config().Name))
proxy.UndelayRx() forwardProxy.UndelayRx()
proxy.UndelayTx() forwardProxy.UndelayTx()
return nil, nil return nil, nil
} }
@ -181,7 +175,7 @@ func (f delayPeerNetworkFailpoint) Name() string {
} }
func (f delayPeerNetworkFailpoint) Available(config e2e.EtcdProcessClusterConfig, clus e2e.EtcdProcess, profile traffic.Profile) bool { func (f delayPeerNetworkFailpoint) Available(config e2e.EtcdProcessClusterConfig, clus e2e.EtcdProcess, profile traffic.Profile) bool {
return config.ClusterSize > 1 && clus.PeerProxy() != nil return config.ClusterSize > 1 && clus.PeerForwardProxy() != nil
} }
type dropPeerNetworkFailpoint struct { type dropPeerNetworkFailpoint struct {
@ -191,15 +185,15 @@ type dropPeerNetworkFailpoint struct {
func (f dropPeerNetworkFailpoint) Inject(ctx context.Context, t *testing.T, lg *zap.Logger, clus *e2e.EtcdProcessCluster, baseTime time.Time, ids identity.Provider) ([]report.ClientReport, error) { func (f dropPeerNetworkFailpoint) Inject(ctx context.Context, t *testing.T, lg *zap.Logger, clus *e2e.EtcdProcessCluster, baseTime time.Time, ids identity.Provider) ([]report.ClientReport, error) {
member := clus.Procs[rand.Int()%len(clus.Procs)] member := clus.Procs[rand.Int()%len(clus.Procs)]
proxy := member.PeerProxy() forwardProxy := member.PeerForwardProxy()
proxy.ModifyRx(f.modifyPacket) forwardProxy.ModifyRx(f.modifyPacket)
proxy.ModifyTx(f.modifyPacket) forwardProxy.ModifyTx(f.modifyPacket)
lg.Info("Dropping traffic from and to member", zap.String("member", member.Config().Name), zap.Int("probability", f.dropProbabilityPercent)) lg.Info("Dropping traffic from and to member", zap.String("member", member.Config().Name), zap.Int("probability", f.dropProbabilityPercent))
time.Sleep(f.duration) time.Sleep(f.duration)
lg.Info("Traffic drop removed", zap.String("member", member.Config().Name)) lg.Info("Traffic drop removed", zap.String("member", member.Config().Name))
proxy.UnmodifyRx() forwardProxy.UnmodifyRx()
proxy.UnmodifyTx() forwardProxy.UnmodifyTx()
return nil, nil return nil, nil
} }
@ -215,5 +209,5 @@ func (f dropPeerNetworkFailpoint) Name() string {
} }
func (f dropPeerNetworkFailpoint) Available(config e2e.EtcdProcessClusterConfig, clus e2e.EtcdProcess, profile traffic.Profile) bool { func (f dropPeerNetworkFailpoint) Available(config e2e.EtcdProcessClusterConfig, clus e2e.EtcdProcess, profile traffic.Profile) bool {
return config.ClusterSize > 1 && clus.PeerProxy() != nil return config.ClusterSize > 1 && clus.PeerForwardProxy() != nil
} }