Merge 8c7320ae818511c5d1b8f7b17b597aa935d328be into c86c93ca2951338115159dcdd20711603044e1f1

This commit is contained in:
Chun-Hung Tseng 2024-09-26 20:52:57 +00:00 committed by GitHub
commit 85dbfc322f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 795 additions and 1078 deletions

View File

@ -18,12 +18,29 @@ import (
"context" "context"
"net" "net"
"net/http" "net/http"
"net/url"
"os"
"strings" "strings"
"time" "time"
) )
type unixTransport struct{ *http.Transport } type unixTransport struct{ *http.Transport }
var httpTransportProxyParsingFunc = determineHTTPTransportProxyParsingFunc
func determineHTTPTransportProxyParsingFunc() func(req *http.Request) (*url.URL, error) {
// according to the comment of http.ProxyFromEnvironment: if the proxy URL is "localhost"
// (with or without a port number), then a nil URL and nil error will be returned.
// Thus, we workaround this limitation by manually setting an ENV named E2E_TEST_FORWARD_PROXY_IP
// and parse the URL (which is a localhost in our case)
if forwardProxy, exists := os.LookupEnv("E2E_TEST_FORWARD_PROXY_IP"); exists {
return func(req *http.Request) (*url.URL, error) {
return url.Parse(forwardProxy)
}
}
return http.ProxyFromEnvironment
}
func NewTransport(info TLSInfo, dialtimeoutd time.Duration) (*http.Transport, error) { func NewTransport(info TLSInfo, dialtimeoutd time.Duration) (*http.Transport, error) {
cfg, err := info.ClientConfig() cfg, err := info.ClientConfig()
if err != nil { if err != nil {
@ -39,7 +56,7 @@ func NewTransport(info TLSInfo, dialtimeoutd time.Duration) (*http.Transport, er
} }
t := &http.Transport{ t := &http.Transport{
Proxy: http.ProxyFromEnvironment, Proxy: httpTransportProxyParsingFunc(),
DialContext: (&net.Dialer{ DialContext: (&net.Dialer{
Timeout: dialtimeoutd, Timeout: dialtimeoutd,
LocalAddr: ipAddr, LocalAddr: ipAddr,
@ -60,7 +77,7 @@ func NewTransport(info TLSInfo, dialtimeoutd time.Duration) (*http.Transport, er
return dialer.DialContext(ctx, "unix", addr) return dialer.DialContext(ctx, "unix", addr)
} }
tu := &http.Transport{ tu := &http.Transport{
Proxy: http.ProxyFromEnvironment, Proxy: httpTransportProxyParsingFunc(),
DialContext: dialContext, DialContext: dialContext,
TLSHandshakeTimeout: 10 * time.Second, TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: cfg, TLSClientConfig: cfg,

File diff suppressed because it is too large Load Diff

View File

@ -17,620 +17,142 @@ package proxy
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/tls"
"fmt"
"io" "io"
"log" "log"
"math/rand"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"os"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
"go.uber.org/zap" "go.uber.org/zap"
"go.uber.org/zap/zaptest" "go.uber.org/zap/zaptest"
"go.etcd.io/etcd/client/pkg/v3/transport" "go.etcd.io/etcd/client/pkg/v3/transport"
) )
func TestServer_Unix_Insecure(t *testing.T) { testServer(t, "unix", false, false) } /* dummyServerHandler is a helper struct */
func TestServer_TCP_Insecure(t *testing.T) { testServer(t, "tcp", false, false) } type dummyServerHandler struct {
func TestServer_Unix_Secure(t *testing.T) { testServer(t, "unix", true, false) } t *testing.T
func TestServer_TCP_Secure(t *testing.T) { testServer(t, "tcp", true, false) } output chan<- []byte
func TestServer_Unix_Insecure_DelayTx(t *testing.T) { testServer(t, "unix", false, true) } }
func TestServer_TCP_Insecure_DelayTx(t *testing.T) { testServer(t, "tcp", false, true) }
func TestServer_Unix_Secure_DelayTx(t *testing.T) { testServer(t, "unix", true, true) }
func TestServer_TCP_Secure_DelayTx(t *testing.T) { testServer(t, "tcp", true, true) }
func testServer(t *testing.T, scheme string, secure bool, delayTx bool) { // ServeHTTP read the request body and write back to the response object
lg := zaptest.NewLogger(t) func (sh *dummyServerHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
srcAddr, dstAddr := newUnixAddr(), newUnixAddr() defer req.Body.Close()
if scheme == "tcp" { resp.WriteHeader(200)
ln1, ln2 := listen(t, "tcp", "localhost:0", transport.TLSInfo{}), listen(t, "tcp", "localhost:0", transport.TLSInfo{})
srcAddr, dstAddr = ln1.Addr().String(), ln2.Addr().String() if data, err := io.ReadAll(req.Body); err != nil {
ln1.Close() sh.t.Fatal(err)
ln2.Close()
} else { } else {
defer func() { sh.output <- data
os.RemoveAll(srcAddr)
os.RemoveAll(dstAddr)
}()
} }
tlsInfo := createTLSInfo(lg, secure) }
ln := listen(t, scheme, dstAddr, tlsInfo)
defer ln.Close()
cfg := ServerConfig{ func prepare(t *testing.T, serverIsClosed bool) (chan []byte, chan struct{}, Server, *http.Server, func(data []byte)) {
Logger: lg, lg := zaptest.NewLogger(t)
From: url.URL{Scheme: scheme, Host: srcAddr}, scheme := "tcp"
To: url.URL{Scheme: scheme, Host: dstAddr}, L7Scheme := "http"
}
if secure {
cfg.TLSInfo = tlsInfo
}
p := NewServer(cfg)
waitForServer(t, p) // we always send the traffic to destination with HTTPS
// this will force the CONNECT header to be sent first
tlsInfo := createTLSInfo(lg)
defer p.Close() ln1, ln2 := listen(t, "tcp", "localhost:0", transport.TLSInfo{}), listen(t, "tcp", "localhost:0", transport.TLSInfo{})
forwardProxyAddr, dstAddr := ln1.Addr().String(), ln2.Addr().String()
data1 := []byte("Hello World!") ln1.Close()
donec, writec := make(chan struct{}), make(chan []byte) ln2.Close()
go func() {
defer close(donec)
for data := range writec {
send(t, data, scheme, srcAddr, tlsInfo)
}
}()
recvc := make(chan []byte, 1) recvc := make(chan []byte, 1)
go func() { httpServer := &http.Server{
for i := 0; i < 2; i++ { Handler: &dummyServerHandler{
recvc <- receive(t, ln) t: t,
} output: recvc,
}() },
writec <- data1
now := time.Now()
if d := <-recvc; !bytes.Equal(data1, d) {
close(writec)
t.Fatalf("expected %q, got %q", string(data1), string(d))
} }
took1 := time.Since(now) go startHTTPServer(scheme, dstAddr, tlsInfo, httpServer)
t.Logf("took %v with no latency", took1)
lat, rv := 50*time.Millisecond, 5*time.Millisecond // we connect to the proxy without TLS
if delayTx { proxyURL := url.URL{Scheme: L7Scheme, Host: forwardProxyAddr}
p.DelayTx(lat, rv) cfg := ServerConfig{
Logger: lg,
Listen: proxyURL,
} }
proxyServer := NewServer(cfg)
waitForServer(t, proxyServer)
data2 := []byte("new data") // setup forward proxy
writec <- data2 t.Setenv("E2E_TEST_FORWARD_PROXY_IP", proxyURL.String())
now = time.Now() t.Logf("Proxy URL %s", proxyURL.String())
if d := <-recvc; !bytes.Equal(data2, d) {
close(writec) donec := make(chan struct{})
t.Fatalf("expected %q, got %q", string(data2), string(d))
} var tp *http.Transport
took2 := time.Since(now) var err error
if delayTx { if !tlsInfo.Empty() {
t.Logf("took %v with latency %v+-%v", took2, lat, rv) tp, err = transport.NewTransport(tlsInfo, 1*time.Second)
} else { } else {
t.Logf("took %v with no latency", took2) tp, err = transport.NewTransport(tlsInfo, 1*time.Second)
}
if err != nil {
t.Fatal(err)
}
tp.IdleConnTimeout = 100 * time.Microsecond
sendData := func(data []byte) {
send(tp, t, data, scheme, dstAddr, tlsInfo, serverIsClosed)
} }
if delayTx { return recvc, donec, proxyServer, httpServer, sendData
p.UndelayTx() }
if took2 < lat-rv {
close(writec) func destroy(t *testing.T, donec chan struct{}, proxyServer Server, serverIsClosed bool, httpServer *http.Server) {
t.Fatalf("expected took2 %v (with latency) > delay: %v", took2, lat-rv) if err := httpServer.Shutdown(context.Background()); err != nil {
} t.Fatal(err)
} }
close(writec)
select { select {
case <-donec: case <-donec:
case <-time.After(3 * time.Second): case <-time.After(3 * time.Second):
t.Fatal("took too long to write") t.Fatal("took too long to write")
} }
select { if !serverIsClosed {
case <-p.Done(): select {
t.Fatal("unexpected done") case <-proxyServer.Done():
case err := <-p.Error(): t.Fatal("unexpected done")
t.Fatal(err) case err := <-proxyServer.Error():
default: if !strings.HasSuffix(err.Error(), "use of closed network connection") {
} t.Fatal(err)
}
default:
}
if err := p.Close(); err != nil { if err := proxyServer.Close(); err != nil {
t.Fatal(err)
}
select {
case <-p.Done():
case err := <-p.Error():
if !strings.HasPrefix(err.Error(), "accept ") &&
!strings.HasSuffix(err.Error(), "use of closed network connection") {
t.Fatal(err) t.Fatal(err)
} }
case <-time.After(3 * time.Second):
t.Fatal("took too long to close")
}
}
func createTLSInfo(lg *zap.Logger, secure bool) transport.TLSInfo { select {
if secure { case <-proxyServer.Done():
return transport.TLSInfo{ case err := <-proxyServer.Error():
KeyFile: "../../tests/fixtures/server.key.insecure", if !strings.HasSuffix(err.Error(), "use of closed network connection") {
CertFile: "../../tests/fixtures/server.crt", t.Fatal(err)
TrustedCAFile: "../../tests/fixtures/ca.crt", }
ClientCertAuth: true, case <-time.After(3 * time.Second):
Logger: lg, t.Fatal("took too long to close")
} }
} }
return transport.TLSInfo{Logger: lg}
} }
func TestServer_Unix_Insecure_DelayAccept(t *testing.T) { testServerDelayAccept(t, false) } func createTLSInfo(lg *zap.Logger) transport.TLSInfo {
func TestServer_Unix_Secure_DelayAccept(t *testing.T) { testServerDelayAccept(t, true) } return transport.TLSInfo{
func testServerDelayAccept(t *testing.T, secure bool) { KeyFile: "../../tests/fixtures/server.key.insecure",
lg := zaptest.NewLogger(t) CertFile: "../../tests/fixtures/server.crt",
srcAddr, dstAddr := newUnixAddr(), newUnixAddr() TrustedCAFile: "../../tests/fixtures/ca.crt",
defer func() { ClientCertAuth: true,
os.RemoveAll(srcAddr) Logger: lg,
os.RemoveAll(dstAddr)
}()
tlsInfo := createTLSInfo(lg, secure)
scheme := "unix"
ln := listen(t, scheme, dstAddr, tlsInfo)
defer ln.Close()
cfg := ServerConfig{
Logger: lg,
From: url.URL{Scheme: scheme, Host: srcAddr},
To: url.URL{Scheme: scheme, Host: dstAddr},
} }
if secure {
cfg.TLSInfo = tlsInfo
}
p := NewServer(cfg)
waitForServer(t, p)
defer p.Close()
data := []byte("Hello World!")
now := time.Now()
send(t, data, scheme, srcAddr, tlsInfo)
if d := receive(t, ln); !bytes.Equal(data, d) {
t.Fatalf("expected %q, got %q", string(data), string(d))
}
took1 := time.Since(now)
t.Logf("took %v with no latency", took1)
lat, rv := 700*time.Millisecond, 10*time.Millisecond
p.DelayAccept(lat, rv)
defer p.UndelayAccept()
if err := p.ResetListener(); err != nil {
t.Fatal(err)
}
time.Sleep(200 * time.Millisecond)
now = time.Now()
send(t, data, scheme, srcAddr, tlsInfo)
if d := receive(t, ln); !bytes.Equal(data, d) {
t.Fatalf("expected %q, got %q", string(data), string(d))
}
took2 := time.Since(now)
t.Logf("took %v with latency %v±%v", took2, lat, rv)
if took1 >= took2 {
t.Fatalf("expected took1 %v < took2 %v", took1, took2)
}
}
func TestServer_PauseTx(t *testing.T) {
lg := zaptest.NewLogger(t)
scheme := "unix"
srcAddr, dstAddr := newUnixAddr(), newUnixAddr()
defer func() {
os.RemoveAll(srcAddr)
os.RemoveAll(dstAddr)
}()
ln := listen(t, scheme, dstAddr, transport.TLSInfo{})
defer ln.Close()
p := NewServer(ServerConfig{
Logger: lg,
From: url.URL{Scheme: scheme, Host: srcAddr},
To: url.URL{Scheme: scheme, Host: dstAddr},
})
waitForServer(t, p)
defer p.Close()
p.PauseTx()
data := []byte("Hello World!")
send(t, data, scheme, srcAddr, transport.TLSInfo{})
recvc := make(chan []byte, 1)
go func() {
recvc <- receive(t, ln)
}()
select {
case d := <-recvc:
t.Fatalf("received unexpected data %q during pause", string(d))
case <-time.After(200 * time.Millisecond):
}
p.UnpauseTx()
select {
case d := <-recvc:
if !bytes.Equal(data, d) {
t.Fatalf("expected %q, got %q", string(data), string(d))
}
case <-time.After(2 * time.Second):
t.Fatal("took too long to receive after unpause")
}
}
func TestServer_ModifyTx_corrupt(t *testing.T) {
lg := zaptest.NewLogger(t)
scheme := "unix"
srcAddr, dstAddr := newUnixAddr(), newUnixAddr()
defer func() {
os.RemoveAll(srcAddr)
os.RemoveAll(dstAddr)
}()
ln := listen(t, scheme, dstAddr, transport.TLSInfo{})
defer ln.Close()
p := NewServer(ServerConfig{
Logger: lg,
From: url.URL{Scheme: scheme, Host: srcAddr},
To: url.URL{Scheme: scheme, Host: dstAddr},
})
waitForServer(t, p)
defer p.Close()
p.ModifyTx(func(d []byte) []byte {
d[len(d)/2]++
return d
})
data := []byte("Hello World!")
send(t, data, scheme, srcAddr, transport.TLSInfo{})
if d := receive(t, ln); bytes.Equal(d, data) {
t.Fatalf("expected corrupted data, got %q", string(d))
}
p.UnmodifyTx()
send(t, data, scheme, srcAddr, transport.TLSInfo{})
if d := receive(t, ln); !bytes.Equal(d, data) {
t.Fatalf("expected uncorrupted data, got %q", string(d))
}
}
func TestServer_ModifyTx_packet_loss(t *testing.T) {
lg := zaptest.NewLogger(t)
scheme := "unix"
srcAddr, dstAddr := newUnixAddr(), newUnixAddr()
defer func() {
os.RemoveAll(srcAddr)
os.RemoveAll(dstAddr)
}()
ln := listen(t, scheme, dstAddr, transport.TLSInfo{})
defer ln.Close()
p := NewServer(ServerConfig{
Logger: lg,
From: url.URL{Scheme: scheme, Host: srcAddr},
To: url.URL{Scheme: scheme, Host: dstAddr},
})
waitForServer(t, p)
defer p.Close()
// 50% packet loss
p.ModifyTx(func(d []byte) []byte {
half := len(d) / 2
return d[:half:half]
})
data := []byte("Hello World!")
send(t, data, scheme, srcAddr, transport.TLSInfo{})
if d := receive(t, ln); bytes.Equal(d, data) {
t.Fatalf("expected corrupted data, got %q", string(d))
}
p.UnmodifyTx()
send(t, data, scheme, srcAddr, transport.TLSInfo{})
if d := receive(t, ln); !bytes.Equal(d, data) {
t.Fatalf("expected uncorrupted data, got %q", string(d))
}
}
func TestServer_BlackholeTx(t *testing.T) {
lg := zaptest.NewLogger(t)
scheme := "unix"
srcAddr, dstAddr := newUnixAddr(), newUnixAddr()
defer func() {
os.RemoveAll(srcAddr)
os.RemoveAll(dstAddr)
}()
ln := listen(t, scheme, dstAddr, transport.TLSInfo{})
defer ln.Close()
p := NewServer(ServerConfig{
Logger: lg,
From: url.URL{Scheme: scheme, Host: srcAddr},
To: url.URL{Scheme: scheme, Host: dstAddr},
})
waitForServer(t, p)
defer p.Close()
p.BlackholeTx()
data := []byte("Hello World!")
send(t, data, scheme, srcAddr, transport.TLSInfo{})
recvc := make(chan []byte, 1)
go func() {
recvc <- receive(t, ln)
}()
select {
case d := <-recvc:
t.Fatalf("unexpected data receive %q during blackhole", string(d))
case <-time.After(200 * time.Millisecond):
}
p.UnblackholeTx()
// expect different data, old data dropped
data[0]++
send(t, data, scheme, srcAddr, transport.TLSInfo{})
select {
case d := <-recvc:
if !bytes.Equal(data, d) {
t.Fatalf("expected %q, got %q", string(data), string(d))
}
case <-time.After(2 * time.Second):
t.Fatal("took too long to receive after unblackhole")
}
}
func TestServer_Shutdown(t *testing.T) {
lg := zaptest.NewLogger(t)
scheme := "unix"
srcAddr, dstAddr := newUnixAddr(), newUnixAddr()
defer func() {
os.RemoveAll(srcAddr)
os.RemoveAll(dstAddr)
}()
ln := listen(t, scheme, dstAddr, transport.TLSInfo{})
defer ln.Close()
p := NewServer(ServerConfig{
Logger: lg,
From: url.URL{Scheme: scheme, Host: srcAddr},
To: url.URL{Scheme: scheme, Host: dstAddr},
})
waitForServer(t, p)
defer p.Close()
s, _ := p.(*server)
s.listener.Close()
time.Sleep(200 * time.Millisecond)
data := []byte("Hello World!")
send(t, data, scheme, srcAddr, transport.TLSInfo{})
if d := receive(t, ln); !bytes.Equal(d, data) {
t.Fatalf("expected %q, got %q", string(data), string(d))
}
}
func TestServer_ShutdownListener(t *testing.T) {
lg := zaptest.NewLogger(t)
scheme := "unix"
srcAddr, dstAddr := newUnixAddr(), newUnixAddr()
defer func() {
os.RemoveAll(srcAddr)
os.RemoveAll(dstAddr)
}()
ln := listen(t, scheme, dstAddr, transport.TLSInfo{})
defer ln.Close()
p := NewServer(ServerConfig{
Logger: lg,
From: url.URL{Scheme: scheme, Host: srcAddr},
To: url.URL{Scheme: scheme, Host: dstAddr},
})
waitForServer(t, p)
defer p.Close()
// shut down destination
ln.Close()
time.Sleep(200 * time.Millisecond)
ln = listen(t, scheme, dstAddr, transport.TLSInfo{})
defer ln.Close()
data := []byte("Hello World!")
send(t, data, scheme, srcAddr, transport.TLSInfo{})
if d := receive(t, ln); !bytes.Equal(d, data) {
t.Fatalf("expected %q, got %q", string(data), string(d))
}
}
func TestServerHTTP_Insecure_DelayTx(t *testing.T) { testServerHTTP(t, false, true) }
func TestServerHTTP_Secure_DelayTx(t *testing.T) { testServerHTTP(t, true, true) }
func TestServerHTTP_Insecure_DelayRx(t *testing.T) { testServerHTTP(t, false, false) }
func TestServerHTTP_Secure_DelayRx(t *testing.T) { testServerHTTP(t, true, false) }
func testServerHTTP(t *testing.T, secure, delayTx bool) {
lg := zaptest.NewLogger(t)
scheme := "tcp"
ln1, ln2 := listen(t, scheme, "localhost:0", transport.TLSInfo{}), listen(t, scheme, "localhost:0", transport.TLSInfo{})
srcAddr, dstAddr := ln1.Addr().String(), ln2.Addr().String()
ln1.Close()
ln2.Close()
mux := http.NewServeMux()
mux.HandleFunc("/hello", func(w http.ResponseWriter, req *http.Request) {
d, err := io.ReadAll(req.Body)
req.Body.Close()
if err != nil {
t.Fatal(err)
}
if _, err = w.Write([]byte(fmt.Sprintf("%q(confirmed)", string(d)))); err != nil {
t.Fatal(err)
}
})
tlsInfo := createTLSInfo(lg, secure)
var tlsConfig *tls.Config
if secure {
_, err := tlsInfo.ServerConfig()
if err != nil {
t.Fatal(err)
}
}
srv := &http.Server{
Addr: dstAddr,
Handler: mux,
TLSConfig: tlsConfig,
ErrorLog: log.New(io.Discard, "net/http", 0),
}
donec := make(chan struct{})
defer func() {
srv.Close()
<-donec
}()
go func() {
if !secure {
srv.ListenAndServe()
} else {
srv.ListenAndServeTLS(tlsInfo.CertFile, tlsInfo.KeyFile)
}
defer close(donec)
}()
time.Sleep(200 * time.Millisecond)
cfg := ServerConfig{
Logger: lg,
From: url.URL{Scheme: scheme, Host: srcAddr},
To: url.URL{Scheme: scheme, Host: dstAddr},
}
if secure {
cfg.TLSInfo = tlsInfo
}
p := NewServer(cfg)
waitForServer(t, p)
defer func() {
lg.Info("closing Proxy server...")
p.Close()
lg.Info("closed Proxy server.")
}()
data := "Hello World!"
var resp *http.Response
var err error
now := time.Now()
if secure {
tp, terr := transport.NewTransport(tlsInfo, 3*time.Second)
assert.NoError(t, terr)
cli := &http.Client{Transport: tp}
resp, err = cli.Post("https://"+srcAddr+"/hello", "", strings.NewReader(data))
defer cli.CloseIdleConnections()
defer tp.CloseIdleConnections()
} else {
resp, err = http.Post("http://"+srcAddr+"/hello", "", strings.NewReader(data))
defer http.DefaultClient.CloseIdleConnections()
}
assert.NoError(t, err)
d, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
resp.Body.Close()
took1 := time.Since(now)
t.Logf("took %v with no latency", took1)
rs1 := string(d)
exp := fmt.Sprintf("%q(confirmed)", data)
if rs1 != exp {
t.Fatalf("got %q, expected %q", rs1, exp)
}
lat, rv := 100*time.Millisecond, 10*time.Millisecond
if delayTx {
p.DelayTx(lat, rv)
defer p.UndelayTx()
} else {
p.DelayRx(lat, rv)
defer p.UndelayRx()
}
now = time.Now()
if secure {
tp, terr := transport.NewTransport(tlsInfo, 3*time.Second)
if terr != nil {
t.Fatal(terr)
}
cli := &http.Client{Transport: tp}
resp, err = cli.Post("https://"+srcAddr+"/hello", "", strings.NewReader(data))
defer cli.CloseIdleConnections()
defer tp.CloseIdleConnections()
} else {
resp, err = http.Post("http://"+srcAddr+"/hello", "", strings.NewReader(data))
defer http.DefaultClient.CloseIdleConnections()
}
if err != nil {
t.Fatal(err)
}
d, err = io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
resp.Body.Close()
took2 := time.Since(now)
t.Logf("took %v with latency %v±%v", took2, lat, rv)
rs2 := string(d)
if rs2 != exp {
t.Fatalf("got %q, expected %q", rs2, exp)
}
if took1 > took2 {
t.Fatalf("expected took1 %v < took2 %v", took1, took2)
}
}
func newUnixAddr() string {
now := time.Now().UnixNano()
addr := fmt.Sprintf("%X%X.unix-conn", now, rand.Intn(35000))
os.RemoveAll(addr)
return addr
} }
func listen(t *testing.T, scheme, addr string, tlsInfo transport.TLSInfo) (ln net.Listener) { func listen(t *testing.T, scheme, addr string, tlsInfo transport.TLSInfo) (ln net.Listener) {
@ -646,46 +168,73 @@ func listen(t *testing.T, scheme, addr string, tlsInfo transport.TLSInfo) (ln ne
return ln return ln
} }
func send(t *testing.T, data []byte, scheme, addr string, tlsInfo transport.TLSInfo) { func startHTTPServer(scheme, addr string, tlsInfo transport.TLSInfo, httpServer *http.Server) {
var out net.Conn
var err error var err error
if !tlsInfo.Empty() { var ln net.Listener
tp, terr := transport.NewTransport(tlsInfo, 3*time.Second)
if terr != nil { ln, err = net.Listen(scheme, addr)
t.Fatal(terr)
}
out, err = tp.DialContext(context.Background(), scheme, addr)
} else {
out, err = net.Dial(scheme, addr)
}
if err != nil { if err != nil {
t.Fatal(err) log.Fatal(err)
} }
if _, err = out.Write(data); err != nil {
t.Fatal(err) log.Println("HTTP Server started on", addr)
} if err := httpServer.ServeTLS(ln, tlsInfo.CertFile, tlsInfo.KeyFile); err != http.ErrServerClosed {
if err = out.Close(); err != nil { // always returns error. ErrServerClosed on graceful close
t.Fatal(err) log.Fatalf("startHTTPServer ServeTLS(): %v", err)
} }
} }
func receive(t *testing.T, ln net.Listener) (data []byte) { func send(tp *http.Transport, t *testing.T, data []byte, scheme, addr string, tlsInfo transport.TLSInfo, serverIsClosed bool) {
buf := bytes.NewBuffer(make([]byte, 0, 1024)) defer func() {
for { tp.CloseIdleConnections()
in, err := ln.Accept() }()
if err != nil {
t.Fatal(err) // If you call Dial(), you will get a Conn that you can write the byte stream directly
} // If you call RoundTrip(), you will get a connection managed for you, but you need to send valid HTTP request
var n int64 dataReader := bytes.NewReader(data)
n, err = buf.ReadFrom(in) protocolScheme := scheme
if err != nil { if scheme == "tcp" {
t.Fatal(err) if !tlsInfo.Empty() {
} protocolScheme = "https"
if n > 0 { } else {
break panic("only https is supported")
} }
} else {
panic("scheme not supported")
}
rawURL := url.URL{
Scheme: protocolScheme,
Host: addr,
}
req, err := http.NewRequest("POST", rawURL.String(), dataReader)
if err != nil {
t.Fatal(err)
}
res, err := tp.RoundTrip(req)
if err != nil {
if strings.Contains(err.Error(), "TLS handshake timeout") {
t.Logf("TLS handshake timeout")
return
}
if serverIsClosed {
// when the proxy server is closed before sending, we will get this error message
if strings.Contains(err.Error(), "connect: connection refused") {
t.Logf("connect: connection refused")
return
}
}
panic(err)
}
defer func() {
if err := res.Body.Close(); err != nil {
panic(err)
}
}()
if res.StatusCode != 200 {
t.Fatalf("status code not 200")
} }
return buf.Bytes()
} }
// Waits until a proxy is ready to serve. // Waits until a proxy is ready to serve.
@ -697,3 +246,122 @@ func waitForServer(t *testing.T, s Server) {
t.Fatal(err) t.Fatal(err)
} }
} }
func TestServer_TCP(t *testing.T) { testServer(t, false, false) }
func TestServer_TCP_DelayTx(t *testing.T) { testServer(t, true, false) }
func TestServer_TCP_DelayRx(t *testing.T) { testServer(t, false, true) }
func testServer(t *testing.T, delayTx bool, delayRx bool) {
recvc, donec, proxyServer, httpServer, sendData := prepare(t, false)
defer destroy(t, donec, proxyServer, false, httpServer)
defer close(donec)
data1 := []byte("Hello World!")
sendData(data1)
now := time.Now()
if d := <-recvc; !bytes.Equal(data1, d) {
t.Fatalf("expected %q, got %q", string(data1), string(d))
}
took1 := time.Since(now)
t.Logf("took %v with no latency", took1)
lat, rv := 50*time.Millisecond, 5*time.Millisecond
if delayTx {
proxyServer.DelayTx(lat, rv)
}
if delayRx {
proxyServer.DelayRx(lat, rv)
}
data2 := []byte("new data")
now = time.Now()
sendData(data2)
if d := <-recvc; !bytes.Equal(data2, d) {
t.Fatalf("expected %q, got %q", string(data2), string(d))
}
took2 := time.Since(now)
if delayTx {
t.Logf("took %v with latency %v+-%v", took2, lat, rv)
} else {
t.Logf("took %v with no latency", took2)
}
if delayTx {
proxyServer.UndelayTx()
if took2 < lat-rv {
t.Fatalf("[delayTx] expected took2 %v (with latency) > delay: %v", took2, lat-rv)
}
}
if delayRx {
proxyServer.UndelayRx()
if took2 < lat-rv {
t.Fatalf("[delayRx] expected took2 %v (with latency) > delay: %v", took2, lat-rv)
}
}
}
func TestServer_BlackholeTx(t *testing.T) {
recvc, donec, proxyServer, httpServer, sendData := prepare(t, false)
defer destroy(t, donec, proxyServer, false, httpServer)
defer close(donec)
// before enabling blacklhole
data := []byte("Hello World!")
sendData(data)
if d := <-recvc; !bytes.Equal(data, d) {
t.Fatalf("expected %q, got %q", string(data), string(d))
}
// enable blackhole
// note that the transport is set to use 10s for TLSHandshakeTimeout, so
// this test will require at least 10s to execute, since send() is a
// blocking call thus we need to wait for ssl handshake to timeout
proxyServer.BlackholeTx()
sendData(data)
select {
case d := <-recvc:
t.Fatalf("unexpected data receive %q during blackhole", string(d))
case <-time.After(200 * time.Millisecond):
}
proxyServer.UnblackholeTx()
// disable blackhole
// TODO: figure out why HTTPS won't attempt to reconnect when the blackhole is disabled
// expect different data, old data dropped
data[0]++
sendData(data)
select {
case d := <-recvc:
if !bytes.Equal(data, d) {
t.Fatalf("expected %q, got %q", string(data), string(d))
}
case <-time.After(2 * time.Second):
t.Fatal("took too long to receive after unblackhole")
}
}
func TestServer_Shutdown(t *testing.T) {
recvc, donec, proxyServer, httpServer, sendData := prepare(t, true)
defer destroy(t, donec, proxyServer, true, httpServer)
defer close(donec)
s, _ := proxyServer.(*server)
if err := s.Close(); err != nil {
t.Fatal(err)
}
time.Sleep(200 * time.Millisecond)
data := []byte("Hello World!")
sendData(data)
select {
case d := <-recvc:
if bytes.Equal(data, d) {
t.Fatalf("expected nothing, got %q", string(d))
}
case <-time.After(2 * time.Second):
t.Log("nothing was received, proxy server seems to be closed so no traffic is forwarded")
}
}

105
tests/e2e/blackhole_test.go Normal file
View File

@ -0,0 +1,105 @@
// Copyright 2024 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build !cluster_proxy
package e2e
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.etcd.io/etcd/tests/v3/framework/e2e"
)
func TestBlackholeByMockingPartitionLeader(t *testing.T) {
blackholeTestByMockingPartition(t, 3, true)
}
func TestBlackholeByMockingPartitionFollower(t *testing.T) {
blackholeTestByMockingPartition(t, 3, false)
}
func blackholeTestByMockingPartition(t *testing.T, clusterSize int, partitionLeader bool) {
e2e.BeforeTest(t)
t.Logf("Create an etcd cluster with %d member\n", clusterSize)
epc, err := e2e.NewEtcdProcessCluster(context.TODO(), t,
e2e.WithClusterSize(clusterSize),
e2e.WithSnapshotCount(10),
e2e.WithSnapshotCatchUpEntries(10),
e2e.WithIsPeerTLS(true),
e2e.WithPeerProxy(true),
)
require.NoError(t, err, "failed to start etcd cluster: %v", err)
defer func() {
require.NoError(t, epc.Close(), "failed to close etcd cluster")
}()
leaderID := epc.WaitLeader(t)
mockPartitionNodeIndex := leaderID
if !partitionLeader {
mockPartitionNodeIndex = (leaderID + 1) % (clusterSize)
}
partitionedMember := epc.Procs[mockPartitionNodeIndex]
// Mock partition
t.Logf("Blackholing traffic from and to member %q", partitionedMember.Config().Name)
epc.BlackholePeer(partitionedMember)
t.Logf("Wait 1s for any open connections to expire")
time.Sleep(1 * time.Second)
t.Logf("Wait for new leader election with remaining members")
leaderEPC := epc.Procs[waitLeader(t, epc, mockPartitionNodeIndex)]
t.Log("Writing 20 keys to the cluster (more than SnapshotCount entries to trigger at least a snapshot.)")
writeKVs(t, leaderEPC.Etcdctl(), 0, 20)
e2e.AssertProcessLogs(t, leaderEPC, "saved snapshot")
t.Log("Verifying the partitionedMember is missing new writes")
assertRevision(t, leaderEPC, 21)
assertRevision(t, partitionedMember, 1)
// Wait for some time to restore the network
time.Sleep(1 * time.Second)
t.Logf("Unblackholing traffic from and to member %q", partitionedMember.Config().Name)
epc.UnblackholePeer(partitionedMember)
leaderEPC = epc.Procs[epc.WaitLeader(t)]
time.Sleep(1 * time.Second)
assertRevision(t, leaderEPC, 21)
assertRevision(t, partitionedMember, 21)
}
func waitLeader(t testing.TB, epc *e2e.EtcdProcessCluster, excludeNode int) int {
var membs []e2e.EtcdProcess
for i := 0; i < len(epc.Procs); i++ {
if i == excludeNode {
continue
}
membs = append(membs, epc.Procs[i])
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
return epc.WaitMembersForLeader(ctx, t, membs)
}
func assertRevision(t testing.TB, member e2e.EtcdProcess, expectedRevision int64) {
responses, err := member.Etcdctl().Status(context.TODO())
require.NoError(t, err)
assert.Equal(t, expectedRevision, responses[0].Header.Revision, "revision mismatch")
}

View File

@ -384,10 +384,10 @@ func triggerSlowApply(ctx context.Context, t *testing.T, clus *e2e.EtcdProcessCl
func blackhole(_ context.Context, t *testing.T, clus *e2e.EtcdProcessCluster, _ time.Duration) { func blackhole(_ context.Context, t *testing.T, clus *e2e.EtcdProcessCluster, _ time.Duration) {
member := clus.Procs[0] member := clus.Procs[0]
proxy := member.PeerProxy() forwardProxy := member.PeerForwardProxy()
t.Logf("Blackholing traffic from and to member %q", member.Config().Name) t.Logf("Blackholing traffic from and to member %q", member.Config().Name)
proxy.BlackholeTx() forwardProxy.BlackholeTx()
proxy.BlackholeRx() forwardProxy.BlackholeRx()
} }
func triggerRaftLoopDeadLock(ctx context.Context, t *testing.T, clus *e2e.EtcdProcessCluster, duration time.Duration) { func triggerRaftLoopDeadLock(ctx context.Context, t *testing.T, clus *e2e.EtcdProcessCluster, duration time.Duration) {

View File

@ -513,10 +513,10 @@ func (cfg *EtcdProcessClusterConfig) EtcdServerProcessConfig(tb testing.TB, i in
var curl string var curl string
port := cfg.BasePort + 5*i port := cfg.BasePort + 5*i
clientPort := port clientPort := port
peerPort := port + 1 peerPort := port + 1 // the port that the peer actually listens on
metricsPort := port + 2 metricsPort := port + 2
peer2Port := port + 3 clientHTTPPort := port + 3
clientHTTPPort := port + 4 forwardProxyPort := port + 4
if cfg.Client.ConnectionType == ClientTLSAndNonTLS { if cfg.Client.ConnectionType == ClientTLSAndNonTLS {
curl = clientURL(cfg.ClientScheme(), clientPort, ClientNonTLS) curl = clientURL(cfg.ClientScheme(), clientPort, ClientNonTLS)
@ -528,17 +528,23 @@ func (cfg *EtcdProcessClusterConfig) EtcdServerProcessConfig(tb testing.TB, i in
peerListenURL := url.URL{Scheme: cfg.PeerScheme(), Host: fmt.Sprintf("localhost:%d", peerPort)} peerListenURL := url.URL{Scheme: cfg.PeerScheme(), Host: fmt.Sprintf("localhost:%d", peerPort)}
peerAdvertiseURL := url.URL{Scheme: cfg.PeerScheme(), Host: fmt.Sprintf("localhost:%d", peerPort)} peerAdvertiseURL := url.URL{Scheme: cfg.PeerScheme(), Host: fmt.Sprintf("localhost:%d", peerPort)}
var proxyCfg *proxy.ServerConfig var forwardProxyCfg *proxy.ServerConfig
if cfg.PeerProxy { if cfg.PeerProxy {
if !cfg.IsPeerTLS { if !cfg.IsPeerTLS {
panic("Can't use peer proxy without peer TLS as it can result in malformed packets") panic("Can't use peer proxy without peer TLS as it can result in malformed packets")
} }
peerAdvertiseURL.Host = fmt.Sprintf("localhost:%d", peer2Port)
proxyCfg = &proxy.ServerConfig{ // setup forward proxy
forwardProxyURL := url.URL{Scheme: cfg.PeerScheme(), Host: fmt.Sprintf("localhost:%d", forwardProxyPort)}
forwardProxyCfg = &proxy.ServerConfig{
Logger: zap.NewNop(), Logger: zap.NewNop(),
To: peerListenURL, Listen: forwardProxyURL,
From: peerAdvertiseURL,
} }
if cfg.EnvVars == nil {
cfg.EnvVars = make(map[string]string)
}
cfg.EnvVars["E2E_TEST_FORWARD_PROXY_IP"] = fmt.Sprintf("http://127.0.0.1:%d", forwardProxyPort)
} }
name := fmt.Sprintf("%s-test-%d", testNameCleanRegex.ReplaceAllString(tb.Name(), ""), i) name := fmt.Sprintf("%s-test-%d", testNameCleanRegex.ReplaceAllString(tb.Name(), ""), i)
@ -660,7 +666,7 @@ func (cfg *EtcdProcessClusterConfig) EtcdServerProcessConfig(tb testing.TB, i in
InitialToken: cfg.ServerConfig.InitialClusterToken, InitialToken: cfg.ServerConfig.InitialClusterToken,
GoFailPort: gofailPort, GoFailPort: gofailPort,
GoFailClientTimeout: cfg.GoFailClientTimeout, GoFailClientTimeout: cfg.GoFailClientTimeout,
Proxy: proxyCfg, ForwardProxy: forwardProxyCfg,
LazyFSEnabled: cfg.LazyFSEnabled, LazyFSEnabled: cfg.LazyFSEnabled,
} }
} }
@ -910,6 +916,38 @@ func (epc *EtcdProcessCluster) Restart(ctx context.Context) error {
return epc.start(func(ep EtcdProcess) error { return ep.Restart(ctx) }) return epc.start(func(ep EtcdProcess) error { return ep.Restart(ctx) })
} }
func (epc *EtcdProcessCluster) BlackholePeer(blackholePeer EtcdProcess) error {
blackholePeer.PeerForwardProxy().BlackholeRx()
blackholePeer.PeerForwardProxy().BlackholeTx()
for _, peer := range epc.Procs {
if peer.Config().Name == blackholePeer.Config().Name {
continue
}
peer.PeerForwardProxy().BlackholePeerRx(blackholePeer.Config().PeerURL)
peer.PeerForwardProxy().BlackholePeerTx(blackholePeer.Config().PeerURL)
}
return nil
}
func (epc *EtcdProcessCluster) UnblackholePeer(blackholePeer EtcdProcess) error {
blackholePeer.PeerForwardProxy().UnblackholeRx()
blackholePeer.PeerForwardProxy().UnblackholeTx()
for _, peer := range epc.Procs {
if peer.Config().Name == blackholePeer.Config().Name {
continue
}
peer.PeerForwardProxy().UnblackholePeerRx(blackholePeer.Config().PeerURL)
peer.PeerForwardProxy().UnblackholePeerTx(blackholePeer.Config().PeerURL)
}
return nil
}
func (epc *EtcdProcessCluster) start(f func(ep EtcdProcess) error) error { func (epc *EtcdProcessCluster) start(f func(ep EtcdProcess) error) error {
readyC := make(chan error, len(epc.Procs)) readyC := make(chan error, len(epc.Procs))
for i := range epc.Procs { for i := range epc.Procs {

View File

@ -55,7 +55,7 @@ type EtcdProcess interface {
Stop() error Stop() error
Close() error Close() error
Config() *EtcdServerProcessConfig Config() *EtcdServerProcessConfig
PeerProxy() proxy.Server PeerForwardProxy() proxy.Server
Failpoints() *BinaryFailpoints Failpoints() *BinaryFailpoints
LazyFS() *LazyFS LazyFS() *LazyFS
Logs() LogsExpect Logs() LogsExpect
@ -69,12 +69,12 @@ type LogsExpect interface {
} }
type EtcdServerProcess struct { type EtcdServerProcess struct {
cfg *EtcdServerProcessConfig cfg *EtcdServerProcessConfig
proc *expect.ExpectProcess proc *expect.ExpectProcess
proxy proxy.Server forwardProxy proxy.Server
lazyfs *LazyFS lazyfs *LazyFS
failpoints *BinaryFailpoints failpoints *BinaryFailpoints
donec chan struct{} // closed when Interact() terminates donec chan struct{} // closed when Interact() terminates
} }
type EtcdServerProcessConfig struct { type EtcdServerProcessConfig struct {
@ -101,7 +101,7 @@ type EtcdServerProcessConfig struct {
GoFailClientTimeout time.Duration GoFailClientTimeout time.Duration
LazyFSEnabled bool LazyFSEnabled bool
Proxy *proxy.ServerConfig ForwardProxy *proxy.ServerConfig
} }
func NewEtcdServerProcess(t testing.TB, cfg *EtcdServerProcessConfig) (*EtcdServerProcess, error) { func NewEtcdServerProcess(t testing.TB, cfg *EtcdServerProcessConfig) (*EtcdServerProcess, error) {
@ -151,12 +151,13 @@ func (ep *EtcdServerProcess) Start(ctx context.Context) error {
if ep.proc != nil { if ep.proc != nil {
panic("already started") panic("already started")
} }
if ep.cfg.Proxy != nil && ep.proxy == nil {
ep.cfg.lg.Info("starting proxy...", zap.String("name", ep.cfg.Name), zap.String("from", ep.cfg.Proxy.From.String()), zap.String("to", ep.cfg.Proxy.To.String())) if ep.cfg.ForwardProxy != nil && ep.forwardProxy == nil {
ep.proxy = proxy.NewServer(*ep.cfg.Proxy) ep.cfg.lg.Info("starting forward proxy...", zap.String("name", ep.cfg.Name), zap.String("listen on", ep.cfg.ForwardProxy.Listen.String()))
ep.forwardProxy = proxy.NewServer(*ep.cfg.ForwardProxy)
select { select {
case <-ep.proxy.Ready(): case <-ep.forwardProxy.Ready():
case err := <-ep.proxy.Error(): case err := <-ep.forwardProxy.Error():
return err return err
} }
} }
@ -221,10 +222,10 @@ func (ep *EtcdServerProcess) Stop() (err error) {
} }
} }
ep.cfg.lg.Info("stopped server.", zap.String("name", ep.cfg.Name)) ep.cfg.lg.Info("stopped server.", zap.String("name", ep.cfg.Name))
if ep.proxy != nil { if ep.forwardProxy != nil {
ep.cfg.lg.Info("stopping proxy...", zap.String("name", ep.cfg.Name)) ep.cfg.lg.Info("stopping forward proxy...", zap.String("name", ep.cfg.Name))
err = ep.proxy.Close() err = ep.forwardProxy.Close()
ep.proxy = nil ep.forwardProxy = nil
if err != nil { if err != nil {
return err return err
} }
@ -330,8 +331,8 @@ func AssertProcessLogs(t *testing.T, ep EtcdProcess, expectLog string) {
} }
} }
func (ep *EtcdServerProcess) PeerProxy() proxy.Server { func (ep *EtcdServerProcess) PeerForwardProxy() proxy.Server {
return ep.proxy return ep.forwardProxy
} }
func (ep *EtcdServerProcess) LazyFS() *LazyFS { func (ep *EtcdServerProcess) LazyFS() *LazyFS {

View File

@ -63,23 +63,17 @@ func (tb triggerBlackhole) Available(config e2e.EtcdProcessClusterConfig, proces
if tb.waitTillSnapshot && (entriesToGuaranteeSnapshot(config) > 200 || !e2e.CouldSetSnapshotCatchupEntries(process.Config().ExecPath)) { if tb.waitTillSnapshot && (entriesToGuaranteeSnapshot(config) > 200 || !e2e.CouldSetSnapshotCatchupEntries(process.Config().ExecPath)) {
return false return false
} }
return config.ClusterSize > 1 && process.PeerProxy() != nil return config.ClusterSize > 1 && process.PeerForwardProxy() != nil
} }
func Blackhole(ctx context.Context, t *testing.T, member e2e.EtcdProcess, clus *e2e.EtcdProcessCluster, shouldWaitTillSnapshot bool) error { func Blackhole(ctx context.Context, t *testing.T, member e2e.EtcdProcess, clus *e2e.EtcdProcessCluster, shouldWaitTillSnapshot bool) error {
proxy := member.PeerProxy()
// Blackholing will cause peers to not be able to use streamWriters registered with member
// but peer traffic is still possible because member has 'pipeline' with peers
// TODO: find a way to stop all traffic
t.Logf("Blackholing traffic from and to member %q", member.Config().Name) t.Logf("Blackholing traffic from and to member %q", member.Config().Name)
proxy.BlackholeTx() clus.BlackholePeer(member)
proxy.BlackholeRx()
defer func() { defer func() {
t.Logf("Traffic restored from and to member %q", member.Config().Name) t.Logf("Traffic restored from and to member %q", member.Config().Name)
proxy.UnblackholeTx() clus.UnblackholePeer(member)
proxy.UnblackholeRx()
}() }()
if shouldWaitTillSnapshot { if shouldWaitTillSnapshot {
return waitTillSnapshot(ctx, t, clus, member) return waitTillSnapshot(ctx, t, clus, member)
} }
@ -164,15 +158,15 @@ type delayPeerNetworkFailpoint struct {
func (f delayPeerNetworkFailpoint) Inject(ctx context.Context, t *testing.T, lg *zap.Logger, clus *e2e.EtcdProcessCluster, baseTime time.Time, ids identity.Provider) ([]report.ClientReport, error) { func (f delayPeerNetworkFailpoint) Inject(ctx context.Context, t *testing.T, lg *zap.Logger, clus *e2e.EtcdProcessCluster, baseTime time.Time, ids identity.Provider) ([]report.ClientReport, error) {
member := clus.Procs[rand.Int()%len(clus.Procs)] member := clus.Procs[rand.Int()%len(clus.Procs)]
proxy := member.PeerProxy() forwardProxy := member.PeerForwardProxy()
proxy.DelayRx(f.baseLatency, f.randomizedLatency) forwardProxy.DelayRx(f.baseLatency, f.randomizedLatency)
proxy.DelayTx(f.baseLatency, f.randomizedLatency) forwardProxy.DelayTx(f.baseLatency, f.randomizedLatency)
lg.Info("Delaying traffic from and to member", zap.String("member", member.Config().Name), zap.Duration("baseLatency", f.baseLatency), zap.Duration("randomizedLatency", f.randomizedLatency)) lg.Info("Delaying traffic from and to member", zap.String("member", member.Config().Name), zap.Duration("baseLatency", f.baseLatency), zap.Duration("randomizedLatency", f.randomizedLatency))
time.Sleep(f.duration) time.Sleep(f.duration)
lg.Info("Traffic delay removed", zap.String("member", member.Config().Name)) lg.Info("Traffic delay removed", zap.String("member", member.Config().Name))
proxy.UndelayRx() forwardProxy.UndelayRx()
proxy.UndelayTx() forwardProxy.UndelayTx()
return nil, nil return nil, nil
} }
@ -181,7 +175,7 @@ func (f delayPeerNetworkFailpoint) Name() string {
} }
func (f delayPeerNetworkFailpoint) Available(config e2e.EtcdProcessClusterConfig, clus e2e.EtcdProcess, profile traffic.Profile) bool { func (f delayPeerNetworkFailpoint) Available(config e2e.EtcdProcessClusterConfig, clus e2e.EtcdProcess, profile traffic.Profile) bool {
return config.ClusterSize > 1 && clus.PeerProxy() != nil return config.ClusterSize > 1 && clus.PeerForwardProxy() != nil
} }
type dropPeerNetworkFailpoint struct { type dropPeerNetworkFailpoint struct {
@ -191,15 +185,15 @@ type dropPeerNetworkFailpoint struct {
func (f dropPeerNetworkFailpoint) Inject(ctx context.Context, t *testing.T, lg *zap.Logger, clus *e2e.EtcdProcessCluster, baseTime time.Time, ids identity.Provider) ([]report.ClientReport, error) { func (f dropPeerNetworkFailpoint) Inject(ctx context.Context, t *testing.T, lg *zap.Logger, clus *e2e.EtcdProcessCluster, baseTime time.Time, ids identity.Provider) ([]report.ClientReport, error) {
member := clus.Procs[rand.Int()%len(clus.Procs)] member := clus.Procs[rand.Int()%len(clus.Procs)]
proxy := member.PeerProxy() forwardProxy := member.PeerForwardProxy()
proxy.ModifyRx(f.modifyPacket) forwardProxy.ModifyRx(f.modifyPacket)
proxy.ModifyTx(f.modifyPacket) forwardProxy.ModifyTx(f.modifyPacket)
lg.Info("Dropping traffic from and to member", zap.String("member", member.Config().Name), zap.Int("probability", f.dropProbabilityPercent)) lg.Info("Dropping traffic from and to member", zap.String("member", member.Config().Name), zap.Int("probability", f.dropProbabilityPercent))
time.Sleep(f.duration) time.Sleep(f.duration)
lg.Info("Traffic drop removed", zap.String("member", member.Config().Name)) lg.Info("Traffic drop removed", zap.String("member", member.Config().Name))
proxy.UnmodifyRx() forwardProxy.UnmodifyRx()
proxy.UnmodifyTx() forwardProxy.UnmodifyTx()
return nil, nil return nil, nil
} }
@ -215,5 +209,5 @@ func (f dropPeerNetworkFailpoint) Name() string {
} }
func (f dropPeerNetworkFailpoint) Available(config e2e.EtcdProcessClusterConfig, clus e2e.EtcdProcess, profile traffic.Profile) bool { func (f dropPeerNetworkFailpoint) Available(config e2e.EtcdProcessClusterConfig, clus e2e.EtcdProcess, profile traffic.Profile) bool {
return config.ClusterSize > 1 && clus.PeerProxy() != nil return config.ClusterSize > 1 && clus.PeerForwardProxy() != nil
} }