Merge 8c7320ae818511c5d1b8f7b17b597aa935d328be into c86c93ca2951338115159dcdd20711603044e1f1

This commit is contained in:
Chun-Hung Tseng 2024-09-26 20:52:57 +00:00 committed by GitHub
commit 85dbfc322f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 795 additions and 1078 deletions

View File

@ -18,12 +18,29 @@ import (
"context"
"net"
"net/http"
"net/url"
"os"
"strings"
"time"
)
type unixTransport struct{ *http.Transport }
var httpTransportProxyParsingFunc = determineHTTPTransportProxyParsingFunc
func determineHTTPTransportProxyParsingFunc() func(req *http.Request) (*url.URL, error) {
// according to the comment of http.ProxyFromEnvironment: if the proxy URL is "localhost"
// (with or without a port number), then a nil URL and nil error will be returned.
// Thus, we workaround this limitation by manually setting an ENV named E2E_TEST_FORWARD_PROXY_IP
// and parse the URL (which is a localhost in our case)
if forwardProxy, exists := os.LookupEnv("E2E_TEST_FORWARD_PROXY_IP"); exists {
return func(req *http.Request) (*url.URL, error) {
return url.Parse(forwardProxy)
}
}
return http.ProxyFromEnvironment
}
func NewTransport(info TLSInfo, dialtimeoutd time.Duration) (*http.Transport, error) {
cfg, err := info.ClientConfig()
if err != nil {
@ -39,7 +56,7 @@ func NewTransport(info TLSInfo, dialtimeoutd time.Duration) (*http.Transport, er
}
t := &http.Transport{
Proxy: http.ProxyFromEnvironment,
Proxy: httpTransportProxyParsingFunc(),
DialContext: (&net.Dialer{
Timeout: dialtimeoutd,
LocalAddr: ipAddr,
@ -60,7 +77,7 @@ func NewTransport(info TLSInfo, dialtimeoutd time.Duration) (*http.Transport, er
return dialer.DialContext(ctx, "unix", addr)
}
tu := &http.Transport{
Proxy: http.ProxyFromEnvironment,
Proxy: httpTransportProxyParsingFunc(),
DialContext: dialContext,
TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: cfg,

File diff suppressed because it is too large Load Diff

View File

@ -17,620 +17,142 @@ package proxy
import (
"bytes"
"context"
"crypto/tls"
"fmt"
"io"
"log"
"math/rand"
"net"
"net/http"
"net/url"
"os"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"go.uber.org/zap/zaptest"
"go.etcd.io/etcd/client/pkg/v3/transport"
)
func TestServer_Unix_Insecure(t *testing.T) { testServer(t, "unix", false, false) }
func TestServer_TCP_Insecure(t *testing.T) { testServer(t, "tcp", false, false) }
func TestServer_Unix_Secure(t *testing.T) { testServer(t, "unix", true, false) }
func TestServer_TCP_Secure(t *testing.T) { testServer(t, "tcp", true, false) }
func TestServer_Unix_Insecure_DelayTx(t *testing.T) { testServer(t, "unix", false, true) }
func TestServer_TCP_Insecure_DelayTx(t *testing.T) { testServer(t, "tcp", false, true) }
func TestServer_Unix_Secure_DelayTx(t *testing.T) { testServer(t, "unix", true, true) }
func TestServer_TCP_Secure_DelayTx(t *testing.T) { testServer(t, "tcp", true, true) }
/* dummyServerHandler is a helper struct */
type dummyServerHandler struct {
t *testing.T
output chan<- []byte
}
func testServer(t *testing.T, scheme string, secure bool, delayTx bool) {
lg := zaptest.NewLogger(t)
srcAddr, dstAddr := newUnixAddr(), newUnixAddr()
if scheme == "tcp" {
ln1, ln2 := listen(t, "tcp", "localhost:0", transport.TLSInfo{}), listen(t, "tcp", "localhost:0", transport.TLSInfo{})
srcAddr, dstAddr = ln1.Addr().String(), ln2.Addr().String()
ln1.Close()
ln2.Close()
// ServeHTTP read the request body and write back to the response object
func (sh *dummyServerHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
defer req.Body.Close()
resp.WriteHeader(200)
if data, err := io.ReadAll(req.Body); err != nil {
sh.t.Fatal(err)
} else {
defer func() {
os.RemoveAll(srcAddr)
os.RemoveAll(dstAddr)
}()
sh.output <- data
}
tlsInfo := createTLSInfo(lg, secure)
ln := listen(t, scheme, dstAddr, tlsInfo)
defer ln.Close()
}
cfg := ServerConfig{
Logger: lg,
From: url.URL{Scheme: scheme, Host: srcAddr},
To: url.URL{Scheme: scheme, Host: dstAddr},
}
if secure {
cfg.TLSInfo = tlsInfo
}
p := NewServer(cfg)
func prepare(t *testing.T, serverIsClosed bool) (chan []byte, chan struct{}, Server, *http.Server, func(data []byte)) {
lg := zaptest.NewLogger(t)
scheme := "tcp"
L7Scheme := "http"
waitForServer(t, p)
// we always send the traffic to destination with HTTPS
// this will force the CONNECT header to be sent first
tlsInfo := createTLSInfo(lg)
defer p.Close()
data1 := []byte("Hello World!")
donec, writec := make(chan struct{}), make(chan []byte)
go func() {
defer close(donec)
for data := range writec {
send(t, data, scheme, srcAddr, tlsInfo)
}
}()
ln1, ln2 := listen(t, "tcp", "localhost:0", transport.TLSInfo{}), listen(t, "tcp", "localhost:0", transport.TLSInfo{})
forwardProxyAddr, dstAddr := ln1.Addr().String(), ln2.Addr().String()
ln1.Close()
ln2.Close()
recvc := make(chan []byte, 1)
go func() {
for i := 0; i < 2; i++ {
recvc <- receive(t, ln)
}
}()
writec <- data1
now := time.Now()
if d := <-recvc; !bytes.Equal(data1, d) {
close(writec)
t.Fatalf("expected %q, got %q", string(data1), string(d))
httpServer := &http.Server{
Handler: &dummyServerHandler{
t: t,
output: recvc,
},
}
took1 := time.Since(now)
t.Logf("took %v with no latency", took1)
go startHTTPServer(scheme, dstAddr, tlsInfo, httpServer)
lat, rv := 50*time.Millisecond, 5*time.Millisecond
if delayTx {
p.DelayTx(lat, rv)
// we connect to the proxy without TLS
proxyURL := url.URL{Scheme: L7Scheme, Host: forwardProxyAddr}
cfg := ServerConfig{
Logger: lg,
Listen: proxyURL,
}
proxyServer := NewServer(cfg)
waitForServer(t, proxyServer)
data2 := []byte("new data")
writec <- data2
now = time.Now()
if d := <-recvc; !bytes.Equal(data2, d) {
close(writec)
t.Fatalf("expected %q, got %q", string(data2), string(d))
}
took2 := time.Since(now)
if delayTx {
t.Logf("took %v with latency %v+-%v", took2, lat, rv)
// setup forward proxy
t.Setenv("E2E_TEST_FORWARD_PROXY_IP", proxyURL.String())
t.Logf("Proxy URL %s", proxyURL.String())
donec := make(chan struct{})
var tp *http.Transport
var err error
if !tlsInfo.Empty() {
tp, err = transport.NewTransport(tlsInfo, 1*time.Second)
} else {
t.Logf("took %v with no latency", took2)
tp, err = transport.NewTransport(tlsInfo, 1*time.Second)
}
if err != nil {
t.Fatal(err)
}
tp.IdleConnTimeout = 100 * time.Microsecond
sendData := func(data []byte) {
send(tp, t, data, scheme, dstAddr, tlsInfo, serverIsClosed)
}
if delayTx {
p.UndelayTx()
if took2 < lat-rv {
close(writec)
t.Fatalf("expected took2 %v (with latency) > delay: %v", took2, lat-rv)
}
return recvc, donec, proxyServer, httpServer, sendData
}
func destroy(t *testing.T, donec chan struct{}, proxyServer Server, serverIsClosed bool, httpServer *http.Server) {
if err := httpServer.Shutdown(context.Background()); err != nil {
t.Fatal(err)
}
close(writec)
select {
case <-donec:
case <-time.After(3 * time.Second):
t.Fatal("took too long to write")
}
select {
case <-p.Done():
t.Fatal("unexpected done")
case err := <-p.Error():
t.Fatal(err)
default:
}
if !serverIsClosed {
select {
case <-proxyServer.Done():
t.Fatal("unexpected done")
case err := <-proxyServer.Error():
if !strings.HasSuffix(err.Error(), "use of closed network connection") {
t.Fatal(err)
}
default:
}
if err := p.Close(); err != nil {
t.Fatal(err)
}
select {
case <-p.Done():
case err := <-p.Error():
if !strings.HasPrefix(err.Error(), "accept ") &&
!strings.HasSuffix(err.Error(), "use of closed network connection") {
if err := proxyServer.Close(); err != nil {
t.Fatal(err)
}
case <-time.After(3 * time.Second):
t.Fatal("took too long to close")
}
}
func createTLSInfo(lg *zap.Logger, secure bool) transport.TLSInfo {
if secure {
return transport.TLSInfo{
KeyFile: "../../tests/fixtures/server.key.insecure",
CertFile: "../../tests/fixtures/server.crt",
TrustedCAFile: "../../tests/fixtures/ca.crt",
ClientCertAuth: true,
Logger: lg,
select {
case <-proxyServer.Done():
case err := <-proxyServer.Error():
if !strings.HasSuffix(err.Error(), "use of closed network connection") {
t.Fatal(err)
}
case <-time.After(3 * time.Second):
t.Fatal("took too long to close")
}
}
return transport.TLSInfo{Logger: lg}
}
func TestServer_Unix_Insecure_DelayAccept(t *testing.T) { testServerDelayAccept(t, false) }
func TestServer_Unix_Secure_DelayAccept(t *testing.T) { testServerDelayAccept(t, true) }
func testServerDelayAccept(t *testing.T, secure bool) {
lg := zaptest.NewLogger(t)
srcAddr, dstAddr := newUnixAddr(), newUnixAddr()
defer func() {
os.RemoveAll(srcAddr)
os.RemoveAll(dstAddr)
}()
tlsInfo := createTLSInfo(lg, secure)
scheme := "unix"
ln := listen(t, scheme, dstAddr, tlsInfo)
defer ln.Close()
cfg := ServerConfig{
Logger: lg,
From: url.URL{Scheme: scheme, Host: srcAddr},
To: url.URL{Scheme: scheme, Host: dstAddr},
func createTLSInfo(lg *zap.Logger) transport.TLSInfo {
return transport.TLSInfo{
KeyFile: "../../tests/fixtures/server.key.insecure",
CertFile: "../../tests/fixtures/server.crt",
TrustedCAFile: "../../tests/fixtures/ca.crt",
ClientCertAuth: true,
Logger: lg,
}
if secure {
cfg.TLSInfo = tlsInfo
}
p := NewServer(cfg)
waitForServer(t, p)
defer p.Close()
data := []byte("Hello World!")
now := time.Now()
send(t, data, scheme, srcAddr, tlsInfo)
if d := receive(t, ln); !bytes.Equal(data, d) {
t.Fatalf("expected %q, got %q", string(data), string(d))
}
took1 := time.Since(now)
t.Logf("took %v with no latency", took1)
lat, rv := 700*time.Millisecond, 10*time.Millisecond
p.DelayAccept(lat, rv)
defer p.UndelayAccept()
if err := p.ResetListener(); err != nil {
t.Fatal(err)
}
time.Sleep(200 * time.Millisecond)
now = time.Now()
send(t, data, scheme, srcAddr, tlsInfo)
if d := receive(t, ln); !bytes.Equal(data, d) {
t.Fatalf("expected %q, got %q", string(data), string(d))
}
took2 := time.Since(now)
t.Logf("took %v with latency %v±%v", took2, lat, rv)
if took1 >= took2 {
t.Fatalf("expected took1 %v < took2 %v", took1, took2)
}
}
func TestServer_PauseTx(t *testing.T) {
lg := zaptest.NewLogger(t)
scheme := "unix"
srcAddr, dstAddr := newUnixAddr(), newUnixAddr()
defer func() {
os.RemoveAll(srcAddr)
os.RemoveAll(dstAddr)
}()
ln := listen(t, scheme, dstAddr, transport.TLSInfo{})
defer ln.Close()
p := NewServer(ServerConfig{
Logger: lg,
From: url.URL{Scheme: scheme, Host: srcAddr},
To: url.URL{Scheme: scheme, Host: dstAddr},
})
waitForServer(t, p)
defer p.Close()
p.PauseTx()
data := []byte("Hello World!")
send(t, data, scheme, srcAddr, transport.TLSInfo{})
recvc := make(chan []byte, 1)
go func() {
recvc <- receive(t, ln)
}()
select {
case d := <-recvc:
t.Fatalf("received unexpected data %q during pause", string(d))
case <-time.After(200 * time.Millisecond):
}
p.UnpauseTx()
select {
case d := <-recvc:
if !bytes.Equal(data, d) {
t.Fatalf("expected %q, got %q", string(data), string(d))
}
case <-time.After(2 * time.Second):
t.Fatal("took too long to receive after unpause")
}
}
func TestServer_ModifyTx_corrupt(t *testing.T) {
lg := zaptest.NewLogger(t)
scheme := "unix"
srcAddr, dstAddr := newUnixAddr(), newUnixAddr()
defer func() {
os.RemoveAll(srcAddr)
os.RemoveAll(dstAddr)
}()
ln := listen(t, scheme, dstAddr, transport.TLSInfo{})
defer ln.Close()
p := NewServer(ServerConfig{
Logger: lg,
From: url.URL{Scheme: scheme, Host: srcAddr},
To: url.URL{Scheme: scheme, Host: dstAddr},
})
waitForServer(t, p)
defer p.Close()
p.ModifyTx(func(d []byte) []byte {
d[len(d)/2]++
return d
})
data := []byte("Hello World!")
send(t, data, scheme, srcAddr, transport.TLSInfo{})
if d := receive(t, ln); bytes.Equal(d, data) {
t.Fatalf("expected corrupted data, got %q", string(d))
}
p.UnmodifyTx()
send(t, data, scheme, srcAddr, transport.TLSInfo{})
if d := receive(t, ln); !bytes.Equal(d, data) {
t.Fatalf("expected uncorrupted data, got %q", string(d))
}
}
func TestServer_ModifyTx_packet_loss(t *testing.T) {
lg := zaptest.NewLogger(t)
scheme := "unix"
srcAddr, dstAddr := newUnixAddr(), newUnixAddr()
defer func() {
os.RemoveAll(srcAddr)
os.RemoveAll(dstAddr)
}()
ln := listen(t, scheme, dstAddr, transport.TLSInfo{})
defer ln.Close()
p := NewServer(ServerConfig{
Logger: lg,
From: url.URL{Scheme: scheme, Host: srcAddr},
To: url.URL{Scheme: scheme, Host: dstAddr},
})
waitForServer(t, p)
defer p.Close()
// 50% packet loss
p.ModifyTx(func(d []byte) []byte {
half := len(d) / 2
return d[:half:half]
})
data := []byte("Hello World!")
send(t, data, scheme, srcAddr, transport.TLSInfo{})
if d := receive(t, ln); bytes.Equal(d, data) {
t.Fatalf("expected corrupted data, got %q", string(d))
}
p.UnmodifyTx()
send(t, data, scheme, srcAddr, transport.TLSInfo{})
if d := receive(t, ln); !bytes.Equal(d, data) {
t.Fatalf("expected uncorrupted data, got %q", string(d))
}
}
func TestServer_BlackholeTx(t *testing.T) {
lg := zaptest.NewLogger(t)
scheme := "unix"
srcAddr, dstAddr := newUnixAddr(), newUnixAddr()
defer func() {
os.RemoveAll(srcAddr)
os.RemoveAll(dstAddr)
}()
ln := listen(t, scheme, dstAddr, transport.TLSInfo{})
defer ln.Close()
p := NewServer(ServerConfig{
Logger: lg,
From: url.URL{Scheme: scheme, Host: srcAddr},
To: url.URL{Scheme: scheme, Host: dstAddr},
})
waitForServer(t, p)
defer p.Close()
p.BlackholeTx()
data := []byte("Hello World!")
send(t, data, scheme, srcAddr, transport.TLSInfo{})
recvc := make(chan []byte, 1)
go func() {
recvc <- receive(t, ln)
}()
select {
case d := <-recvc:
t.Fatalf("unexpected data receive %q during blackhole", string(d))
case <-time.After(200 * time.Millisecond):
}
p.UnblackholeTx()
// expect different data, old data dropped
data[0]++
send(t, data, scheme, srcAddr, transport.TLSInfo{})
select {
case d := <-recvc:
if !bytes.Equal(data, d) {
t.Fatalf("expected %q, got %q", string(data), string(d))
}
case <-time.After(2 * time.Second):
t.Fatal("took too long to receive after unblackhole")
}
}
func TestServer_Shutdown(t *testing.T) {
lg := zaptest.NewLogger(t)
scheme := "unix"
srcAddr, dstAddr := newUnixAddr(), newUnixAddr()
defer func() {
os.RemoveAll(srcAddr)
os.RemoveAll(dstAddr)
}()
ln := listen(t, scheme, dstAddr, transport.TLSInfo{})
defer ln.Close()
p := NewServer(ServerConfig{
Logger: lg,
From: url.URL{Scheme: scheme, Host: srcAddr},
To: url.URL{Scheme: scheme, Host: dstAddr},
})
waitForServer(t, p)
defer p.Close()
s, _ := p.(*server)
s.listener.Close()
time.Sleep(200 * time.Millisecond)
data := []byte("Hello World!")
send(t, data, scheme, srcAddr, transport.TLSInfo{})
if d := receive(t, ln); !bytes.Equal(d, data) {
t.Fatalf("expected %q, got %q", string(data), string(d))
}
}
func TestServer_ShutdownListener(t *testing.T) {
lg := zaptest.NewLogger(t)
scheme := "unix"
srcAddr, dstAddr := newUnixAddr(), newUnixAddr()
defer func() {
os.RemoveAll(srcAddr)
os.RemoveAll(dstAddr)
}()
ln := listen(t, scheme, dstAddr, transport.TLSInfo{})
defer ln.Close()
p := NewServer(ServerConfig{
Logger: lg,
From: url.URL{Scheme: scheme, Host: srcAddr},
To: url.URL{Scheme: scheme, Host: dstAddr},
})
waitForServer(t, p)
defer p.Close()
// shut down destination
ln.Close()
time.Sleep(200 * time.Millisecond)
ln = listen(t, scheme, dstAddr, transport.TLSInfo{})
defer ln.Close()
data := []byte("Hello World!")
send(t, data, scheme, srcAddr, transport.TLSInfo{})
if d := receive(t, ln); !bytes.Equal(d, data) {
t.Fatalf("expected %q, got %q", string(data), string(d))
}
}
func TestServerHTTP_Insecure_DelayTx(t *testing.T) { testServerHTTP(t, false, true) }
func TestServerHTTP_Secure_DelayTx(t *testing.T) { testServerHTTP(t, true, true) }
func TestServerHTTP_Insecure_DelayRx(t *testing.T) { testServerHTTP(t, false, false) }
func TestServerHTTP_Secure_DelayRx(t *testing.T) { testServerHTTP(t, true, false) }
func testServerHTTP(t *testing.T, secure, delayTx bool) {
lg := zaptest.NewLogger(t)
scheme := "tcp"
ln1, ln2 := listen(t, scheme, "localhost:0", transport.TLSInfo{}), listen(t, scheme, "localhost:0", transport.TLSInfo{})
srcAddr, dstAddr := ln1.Addr().String(), ln2.Addr().String()
ln1.Close()
ln2.Close()
mux := http.NewServeMux()
mux.HandleFunc("/hello", func(w http.ResponseWriter, req *http.Request) {
d, err := io.ReadAll(req.Body)
req.Body.Close()
if err != nil {
t.Fatal(err)
}
if _, err = w.Write([]byte(fmt.Sprintf("%q(confirmed)", string(d)))); err != nil {
t.Fatal(err)
}
})
tlsInfo := createTLSInfo(lg, secure)
var tlsConfig *tls.Config
if secure {
_, err := tlsInfo.ServerConfig()
if err != nil {
t.Fatal(err)
}
}
srv := &http.Server{
Addr: dstAddr,
Handler: mux,
TLSConfig: tlsConfig,
ErrorLog: log.New(io.Discard, "net/http", 0),
}
donec := make(chan struct{})
defer func() {
srv.Close()
<-donec
}()
go func() {
if !secure {
srv.ListenAndServe()
} else {
srv.ListenAndServeTLS(tlsInfo.CertFile, tlsInfo.KeyFile)
}
defer close(donec)
}()
time.Sleep(200 * time.Millisecond)
cfg := ServerConfig{
Logger: lg,
From: url.URL{Scheme: scheme, Host: srcAddr},
To: url.URL{Scheme: scheme, Host: dstAddr},
}
if secure {
cfg.TLSInfo = tlsInfo
}
p := NewServer(cfg)
waitForServer(t, p)
defer func() {
lg.Info("closing Proxy server...")
p.Close()
lg.Info("closed Proxy server.")
}()
data := "Hello World!"
var resp *http.Response
var err error
now := time.Now()
if secure {
tp, terr := transport.NewTransport(tlsInfo, 3*time.Second)
assert.NoError(t, terr)
cli := &http.Client{Transport: tp}
resp, err = cli.Post("https://"+srcAddr+"/hello", "", strings.NewReader(data))
defer cli.CloseIdleConnections()
defer tp.CloseIdleConnections()
} else {
resp, err = http.Post("http://"+srcAddr+"/hello", "", strings.NewReader(data))
defer http.DefaultClient.CloseIdleConnections()
}
assert.NoError(t, err)
d, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
resp.Body.Close()
took1 := time.Since(now)
t.Logf("took %v with no latency", took1)
rs1 := string(d)
exp := fmt.Sprintf("%q(confirmed)", data)
if rs1 != exp {
t.Fatalf("got %q, expected %q", rs1, exp)
}
lat, rv := 100*time.Millisecond, 10*time.Millisecond
if delayTx {
p.DelayTx(lat, rv)
defer p.UndelayTx()
} else {
p.DelayRx(lat, rv)
defer p.UndelayRx()
}
now = time.Now()
if secure {
tp, terr := transport.NewTransport(tlsInfo, 3*time.Second)
if terr != nil {
t.Fatal(terr)
}
cli := &http.Client{Transport: tp}
resp, err = cli.Post("https://"+srcAddr+"/hello", "", strings.NewReader(data))
defer cli.CloseIdleConnections()
defer tp.CloseIdleConnections()
} else {
resp, err = http.Post("http://"+srcAddr+"/hello", "", strings.NewReader(data))
defer http.DefaultClient.CloseIdleConnections()
}
if err != nil {
t.Fatal(err)
}
d, err = io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
resp.Body.Close()
took2 := time.Since(now)
t.Logf("took %v with latency %v±%v", took2, lat, rv)
rs2 := string(d)
if rs2 != exp {
t.Fatalf("got %q, expected %q", rs2, exp)
}
if took1 > took2 {
t.Fatalf("expected took1 %v < took2 %v", took1, took2)
}
}
func newUnixAddr() string {
now := time.Now().UnixNano()
addr := fmt.Sprintf("%X%X.unix-conn", now, rand.Intn(35000))
os.RemoveAll(addr)
return addr
}
func listen(t *testing.T, scheme, addr string, tlsInfo transport.TLSInfo) (ln net.Listener) {
@ -646,46 +168,73 @@ func listen(t *testing.T, scheme, addr string, tlsInfo transport.TLSInfo) (ln ne
return ln
}
func send(t *testing.T, data []byte, scheme, addr string, tlsInfo transport.TLSInfo) {
var out net.Conn
func startHTTPServer(scheme, addr string, tlsInfo transport.TLSInfo, httpServer *http.Server) {
var err error
if !tlsInfo.Empty() {
tp, terr := transport.NewTransport(tlsInfo, 3*time.Second)
if terr != nil {
t.Fatal(terr)
}
out, err = tp.DialContext(context.Background(), scheme, addr)
} else {
out, err = net.Dial(scheme, addr)
}
var ln net.Listener
ln, err = net.Listen(scheme, addr)
if err != nil {
t.Fatal(err)
log.Fatal(err)
}
if _, err = out.Write(data); err != nil {
t.Fatal(err)
}
if err = out.Close(); err != nil {
t.Fatal(err)
log.Println("HTTP Server started on", addr)
if err := httpServer.ServeTLS(ln, tlsInfo.CertFile, tlsInfo.KeyFile); err != http.ErrServerClosed {
// always returns error. ErrServerClosed on graceful close
log.Fatalf("startHTTPServer ServeTLS(): %v", err)
}
}
func receive(t *testing.T, ln net.Listener) (data []byte) {
buf := bytes.NewBuffer(make([]byte, 0, 1024))
for {
in, err := ln.Accept()
if err != nil {
t.Fatal(err)
}
var n int64
n, err = buf.ReadFrom(in)
if err != nil {
t.Fatal(err)
}
if n > 0 {
break
func send(tp *http.Transport, t *testing.T, data []byte, scheme, addr string, tlsInfo transport.TLSInfo, serverIsClosed bool) {
defer func() {
tp.CloseIdleConnections()
}()
// If you call Dial(), you will get a Conn that you can write the byte stream directly
// If you call RoundTrip(), you will get a connection managed for you, but you need to send valid HTTP request
dataReader := bytes.NewReader(data)
protocolScheme := scheme
if scheme == "tcp" {
if !tlsInfo.Empty() {
protocolScheme = "https"
} else {
panic("only https is supported")
}
} else {
panic("scheme not supported")
}
rawURL := url.URL{
Scheme: protocolScheme,
Host: addr,
}
req, err := http.NewRequest("POST", rawURL.String(), dataReader)
if err != nil {
t.Fatal(err)
}
res, err := tp.RoundTrip(req)
if err != nil {
if strings.Contains(err.Error(), "TLS handshake timeout") {
t.Logf("TLS handshake timeout")
return
}
if serverIsClosed {
// when the proxy server is closed before sending, we will get this error message
if strings.Contains(err.Error(), "connect: connection refused") {
t.Logf("connect: connection refused")
return
}
}
panic(err)
}
defer func() {
if err := res.Body.Close(); err != nil {
panic(err)
}
}()
if res.StatusCode != 200 {
t.Fatalf("status code not 200")
}
return buf.Bytes()
}
// Waits until a proxy is ready to serve.
@ -697,3 +246,122 @@ func waitForServer(t *testing.T, s Server) {
t.Fatal(err)
}
}
func TestServer_TCP(t *testing.T) { testServer(t, false, false) }
func TestServer_TCP_DelayTx(t *testing.T) { testServer(t, true, false) }
func TestServer_TCP_DelayRx(t *testing.T) { testServer(t, false, true) }
func testServer(t *testing.T, delayTx bool, delayRx bool) {
recvc, donec, proxyServer, httpServer, sendData := prepare(t, false)
defer destroy(t, donec, proxyServer, false, httpServer)
defer close(donec)
data1 := []byte("Hello World!")
sendData(data1)
now := time.Now()
if d := <-recvc; !bytes.Equal(data1, d) {
t.Fatalf("expected %q, got %q", string(data1), string(d))
}
took1 := time.Since(now)
t.Logf("took %v with no latency", took1)
lat, rv := 50*time.Millisecond, 5*time.Millisecond
if delayTx {
proxyServer.DelayTx(lat, rv)
}
if delayRx {
proxyServer.DelayRx(lat, rv)
}
data2 := []byte("new data")
now = time.Now()
sendData(data2)
if d := <-recvc; !bytes.Equal(data2, d) {
t.Fatalf("expected %q, got %q", string(data2), string(d))
}
took2 := time.Since(now)
if delayTx {
t.Logf("took %v with latency %v+-%v", took2, lat, rv)
} else {
t.Logf("took %v with no latency", took2)
}
if delayTx {
proxyServer.UndelayTx()
if took2 < lat-rv {
t.Fatalf("[delayTx] expected took2 %v (with latency) > delay: %v", took2, lat-rv)
}
}
if delayRx {
proxyServer.UndelayRx()
if took2 < lat-rv {
t.Fatalf("[delayRx] expected took2 %v (with latency) > delay: %v", took2, lat-rv)
}
}
}
func TestServer_BlackholeTx(t *testing.T) {
recvc, donec, proxyServer, httpServer, sendData := prepare(t, false)
defer destroy(t, donec, proxyServer, false, httpServer)
defer close(donec)
// before enabling blacklhole
data := []byte("Hello World!")
sendData(data)
if d := <-recvc; !bytes.Equal(data, d) {
t.Fatalf("expected %q, got %q", string(data), string(d))
}
// enable blackhole
// note that the transport is set to use 10s for TLSHandshakeTimeout, so
// this test will require at least 10s to execute, since send() is a
// blocking call thus we need to wait for ssl handshake to timeout
proxyServer.BlackholeTx()
sendData(data)
select {
case d := <-recvc:
t.Fatalf("unexpected data receive %q during blackhole", string(d))
case <-time.After(200 * time.Millisecond):
}
proxyServer.UnblackholeTx()
// disable blackhole
// TODO: figure out why HTTPS won't attempt to reconnect when the blackhole is disabled
// expect different data, old data dropped
data[0]++
sendData(data)
select {
case d := <-recvc:
if !bytes.Equal(data, d) {
t.Fatalf("expected %q, got %q", string(data), string(d))
}
case <-time.After(2 * time.Second):
t.Fatal("took too long to receive after unblackhole")
}
}
func TestServer_Shutdown(t *testing.T) {
recvc, donec, proxyServer, httpServer, sendData := prepare(t, true)
defer destroy(t, donec, proxyServer, true, httpServer)
defer close(donec)
s, _ := proxyServer.(*server)
if err := s.Close(); err != nil {
t.Fatal(err)
}
time.Sleep(200 * time.Millisecond)
data := []byte("Hello World!")
sendData(data)
select {
case d := <-recvc:
if bytes.Equal(data, d) {
t.Fatalf("expected nothing, got %q", string(d))
}
case <-time.After(2 * time.Second):
t.Log("nothing was received, proxy server seems to be closed so no traffic is forwarded")
}
}

105
tests/e2e/blackhole_test.go Normal file
View File

@ -0,0 +1,105 @@
// Copyright 2024 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build !cluster_proxy
package e2e
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.etcd.io/etcd/tests/v3/framework/e2e"
)
func TestBlackholeByMockingPartitionLeader(t *testing.T) {
blackholeTestByMockingPartition(t, 3, true)
}
func TestBlackholeByMockingPartitionFollower(t *testing.T) {
blackholeTestByMockingPartition(t, 3, false)
}
func blackholeTestByMockingPartition(t *testing.T, clusterSize int, partitionLeader bool) {
e2e.BeforeTest(t)
t.Logf("Create an etcd cluster with %d member\n", clusterSize)
epc, err := e2e.NewEtcdProcessCluster(context.TODO(), t,
e2e.WithClusterSize(clusterSize),
e2e.WithSnapshotCount(10),
e2e.WithSnapshotCatchUpEntries(10),
e2e.WithIsPeerTLS(true),
e2e.WithPeerProxy(true),
)
require.NoError(t, err, "failed to start etcd cluster: %v", err)
defer func() {
require.NoError(t, epc.Close(), "failed to close etcd cluster")
}()
leaderID := epc.WaitLeader(t)
mockPartitionNodeIndex := leaderID
if !partitionLeader {
mockPartitionNodeIndex = (leaderID + 1) % (clusterSize)
}
partitionedMember := epc.Procs[mockPartitionNodeIndex]
// Mock partition
t.Logf("Blackholing traffic from and to member %q", partitionedMember.Config().Name)
epc.BlackholePeer(partitionedMember)
t.Logf("Wait 1s for any open connections to expire")
time.Sleep(1 * time.Second)
t.Logf("Wait for new leader election with remaining members")
leaderEPC := epc.Procs[waitLeader(t, epc, mockPartitionNodeIndex)]
t.Log("Writing 20 keys to the cluster (more than SnapshotCount entries to trigger at least a snapshot.)")
writeKVs(t, leaderEPC.Etcdctl(), 0, 20)
e2e.AssertProcessLogs(t, leaderEPC, "saved snapshot")
t.Log("Verifying the partitionedMember is missing new writes")
assertRevision(t, leaderEPC, 21)
assertRevision(t, partitionedMember, 1)
// Wait for some time to restore the network
time.Sleep(1 * time.Second)
t.Logf("Unblackholing traffic from and to member %q", partitionedMember.Config().Name)
epc.UnblackholePeer(partitionedMember)
leaderEPC = epc.Procs[epc.WaitLeader(t)]
time.Sleep(1 * time.Second)
assertRevision(t, leaderEPC, 21)
assertRevision(t, partitionedMember, 21)
}
func waitLeader(t testing.TB, epc *e2e.EtcdProcessCluster, excludeNode int) int {
var membs []e2e.EtcdProcess
for i := 0; i < len(epc.Procs); i++ {
if i == excludeNode {
continue
}
membs = append(membs, epc.Procs[i])
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
return epc.WaitMembersForLeader(ctx, t, membs)
}
func assertRevision(t testing.TB, member e2e.EtcdProcess, expectedRevision int64) {
responses, err := member.Etcdctl().Status(context.TODO())
require.NoError(t, err)
assert.Equal(t, expectedRevision, responses[0].Header.Revision, "revision mismatch")
}

View File

@ -384,10 +384,10 @@ func triggerSlowApply(ctx context.Context, t *testing.T, clus *e2e.EtcdProcessCl
func blackhole(_ context.Context, t *testing.T, clus *e2e.EtcdProcessCluster, _ time.Duration) {
member := clus.Procs[0]
proxy := member.PeerProxy()
forwardProxy := member.PeerForwardProxy()
t.Logf("Blackholing traffic from and to member %q", member.Config().Name)
proxy.BlackholeTx()
proxy.BlackholeRx()
forwardProxy.BlackholeTx()
forwardProxy.BlackholeRx()
}
func triggerRaftLoopDeadLock(ctx context.Context, t *testing.T, clus *e2e.EtcdProcessCluster, duration time.Duration) {

View File

@ -513,10 +513,10 @@ func (cfg *EtcdProcessClusterConfig) EtcdServerProcessConfig(tb testing.TB, i in
var curl string
port := cfg.BasePort + 5*i
clientPort := port
peerPort := port + 1
peerPort := port + 1 // the port that the peer actually listens on
metricsPort := port + 2
peer2Port := port + 3
clientHTTPPort := port + 4
clientHTTPPort := port + 3
forwardProxyPort := port + 4
if cfg.Client.ConnectionType == ClientTLSAndNonTLS {
curl = clientURL(cfg.ClientScheme(), clientPort, ClientNonTLS)
@ -528,17 +528,23 @@ func (cfg *EtcdProcessClusterConfig) EtcdServerProcessConfig(tb testing.TB, i in
peerListenURL := url.URL{Scheme: cfg.PeerScheme(), Host: fmt.Sprintf("localhost:%d", peerPort)}
peerAdvertiseURL := url.URL{Scheme: cfg.PeerScheme(), Host: fmt.Sprintf("localhost:%d", peerPort)}
var proxyCfg *proxy.ServerConfig
var forwardProxyCfg *proxy.ServerConfig
if cfg.PeerProxy {
if !cfg.IsPeerTLS {
panic("Can't use peer proxy without peer TLS as it can result in malformed packets")
}
peerAdvertiseURL.Host = fmt.Sprintf("localhost:%d", peer2Port)
proxyCfg = &proxy.ServerConfig{
// setup forward proxy
forwardProxyURL := url.URL{Scheme: cfg.PeerScheme(), Host: fmt.Sprintf("localhost:%d", forwardProxyPort)}
forwardProxyCfg = &proxy.ServerConfig{
Logger: zap.NewNop(),
To: peerListenURL,
From: peerAdvertiseURL,
Listen: forwardProxyURL,
}
if cfg.EnvVars == nil {
cfg.EnvVars = make(map[string]string)
}
cfg.EnvVars["E2E_TEST_FORWARD_PROXY_IP"] = fmt.Sprintf("http://127.0.0.1:%d", forwardProxyPort)
}
name := fmt.Sprintf("%s-test-%d", testNameCleanRegex.ReplaceAllString(tb.Name(), ""), i)
@ -660,7 +666,7 @@ func (cfg *EtcdProcessClusterConfig) EtcdServerProcessConfig(tb testing.TB, i in
InitialToken: cfg.ServerConfig.InitialClusterToken,
GoFailPort: gofailPort,
GoFailClientTimeout: cfg.GoFailClientTimeout,
Proxy: proxyCfg,
ForwardProxy: forwardProxyCfg,
LazyFSEnabled: cfg.LazyFSEnabled,
}
}
@ -910,6 +916,38 @@ func (epc *EtcdProcessCluster) Restart(ctx context.Context) error {
return epc.start(func(ep EtcdProcess) error { return ep.Restart(ctx) })
}
func (epc *EtcdProcessCluster) BlackholePeer(blackholePeer EtcdProcess) error {
blackholePeer.PeerForwardProxy().BlackholeRx()
blackholePeer.PeerForwardProxy().BlackholeTx()
for _, peer := range epc.Procs {
if peer.Config().Name == blackholePeer.Config().Name {
continue
}
peer.PeerForwardProxy().BlackholePeerRx(blackholePeer.Config().PeerURL)
peer.PeerForwardProxy().BlackholePeerTx(blackholePeer.Config().PeerURL)
}
return nil
}
func (epc *EtcdProcessCluster) UnblackholePeer(blackholePeer EtcdProcess) error {
blackholePeer.PeerForwardProxy().UnblackholeRx()
blackholePeer.PeerForwardProxy().UnblackholeTx()
for _, peer := range epc.Procs {
if peer.Config().Name == blackholePeer.Config().Name {
continue
}
peer.PeerForwardProxy().UnblackholePeerRx(blackholePeer.Config().PeerURL)
peer.PeerForwardProxy().UnblackholePeerTx(blackholePeer.Config().PeerURL)
}
return nil
}
func (epc *EtcdProcessCluster) start(f func(ep EtcdProcess) error) error {
readyC := make(chan error, len(epc.Procs))
for i := range epc.Procs {

View File

@ -55,7 +55,7 @@ type EtcdProcess interface {
Stop() error
Close() error
Config() *EtcdServerProcessConfig
PeerProxy() proxy.Server
PeerForwardProxy() proxy.Server
Failpoints() *BinaryFailpoints
LazyFS() *LazyFS
Logs() LogsExpect
@ -69,12 +69,12 @@ type LogsExpect interface {
}
type EtcdServerProcess struct {
cfg *EtcdServerProcessConfig
proc *expect.ExpectProcess
proxy proxy.Server
lazyfs *LazyFS
failpoints *BinaryFailpoints
donec chan struct{} // closed when Interact() terminates
cfg *EtcdServerProcessConfig
proc *expect.ExpectProcess
forwardProxy proxy.Server
lazyfs *LazyFS
failpoints *BinaryFailpoints
donec chan struct{} // closed when Interact() terminates
}
type EtcdServerProcessConfig struct {
@ -101,7 +101,7 @@ type EtcdServerProcessConfig struct {
GoFailClientTimeout time.Duration
LazyFSEnabled bool
Proxy *proxy.ServerConfig
ForwardProxy *proxy.ServerConfig
}
func NewEtcdServerProcess(t testing.TB, cfg *EtcdServerProcessConfig) (*EtcdServerProcess, error) {
@ -151,12 +151,13 @@ func (ep *EtcdServerProcess) Start(ctx context.Context) error {
if ep.proc != nil {
panic("already started")
}
if ep.cfg.Proxy != nil && ep.proxy == nil {
ep.cfg.lg.Info("starting proxy...", zap.String("name", ep.cfg.Name), zap.String("from", ep.cfg.Proxy.From.String()), zap.String("to", ep.cfg.Proxy.To.String()))
ep.proxy = proxy.NewServer(*ep.cfg.Proxy)
if ep.cfg.ForwardProxy != nil && ep.forwardProxy == nil {
ep.cfg.lg.Info("starting forward proxy...", zap.String("name", ep.cfg.Name), zap.String("listen on", ep.cfg.ForwardProxy.Listen.String()))
ep.forwardProxy = proxy.NewServer(*ep.cfg.ForwardProxy)
select {
case <-ep.proxy.Ready():
case err := <-ep.proxy.Error():
case <-ep.forwardProxy.Ready():
case err := <-ep.forwardProxy.Error():
return err
}
}
@ -221,10 +222,10 @@ func (ep *EtcdServerProcess) Stop() (err error) {
}
}
ep.cfg.lg.Info("stopped server.", zap.String("name", ep.cfg.Name))
if ep.proxy != nil {
ep.cfg.lg.Info("stopping proxy...", zap.String("name", ep.cfg.Name))
err = ep.proxy.Close()
ep.proxy = nil
if ep.forwardProxy != nil {
ep.cfg.lg.Info("stopping forward proxy...", zap.String("name", ep.cfg.Name))
err = ep.forwardProxy.Close()
ep.forwardProxy = nil
if err != nil {
return err
}
@ -330,8 +331,8 @@ func AssertProcessLogs(t *testing.T, ep EtcdProcess, expectLog string) {
}
}
func (ep *EtcdServerProcess) PeerProxy() proxy.Server {
return ep.proxy
func (ep *EtcdServerProcess) PeerForwardProxy() proxy.Server {
return ep.forwardProxy
}
func (ep *EtcdServerProcess) LazyFS() *LazyFS {

View File

@ -63,23 +63,17 @@ func (tb triggerBlackhole) Available(config e2e.EtcdProcessClusterConfig, proces
if tb.waitTillSnapshot && (entriesToGuaranteeSnapshot(config) > 200 || !e2e.CouldSetSnapshotCatchupEntries(process.Config().ExecPath)) {
return false
}
return config.ClusterSize > 1 && process.PeerProxy() != nil
return config.ClusterSize > 1 && process.PeerForwardProxy() != nil
}
func Blackhole(ctx context.Context, t *testing.T, member e2e.EtcdProcess, clus *e2e.EtcdProcessCluster, shouldWaitTillSnapshot bool) error {
proxy := member.PeerProxy()
// Blackholing will cause peers to not be able to use streamWriters registered with member
// but peer traffic is still possible because member has 'pipeline' with peers
// TODO: find a way to stop all traffic
t.Logf("Blackholing traffic from and to member %q", member.Config().Name)
proxy.BlackholeTx()
proxy.BlackholeRx()
clus.BlackholePeer(member)
defer func() {
t.Logf("Traffic restored from and to member %q", member.Config().Name)
proxy.UnblackholeTx()
proxy.UnblackholeRx()
clus.UnblackholePeer(member)
}()
if shouldWaitTillSnapshot {
return waitTillSnapshot(ctx, t, clus, member)
}
@ -164,15 +158,15 @@ type delayPeerNetworkFailpoint struct {
func (f delayPeerNetworkFailpoint) Inject(ctx context.Context, t *testing.T, lg *zap.Logger, clus *e2e.EtcdProcessCluster, baseTime time.Time, ids identity.Provider) ([]report.ClientReport, error) {
member := clus.Procs[rand.Int()%len(clus.Procs)]
proxy := member.PeerProxy()
forwardProxy := member.PeerForwardProxy()
proxy.DelayRx(f.baseLatency, f.randomizedLatency)
proxy.DelayTx(f.baseLatency, f.randomizedLatency)
forwardProxy.DelayRx(f.baseLatency, f.randomizedLatency)
forwardProxy.DelayTx(f.baseLatency, f.randomizedLatency)
lg.Info("Delaying traffic from and to member", zap.String("member", member.Config().Name), zap.Duration("baseLatency", f.baseLatency), zap.Duration("randomizedLatency", f.randomizedLatency))
time.Sleep(f.duration)
lg.Info("Traffic delay removed", zap.String("member", member.Config().Name))
proxy.UndelayRx()
proxy.UndelayTx()
forwardProxy.UndelayRx()
forwardProxy.UndelayTx()
return nil, nil
}
@ -181,7 +175,7 @@ func (f delayPeerNetworkFailpoint) Name() string {
}
func (f delayPeerNetworkFailpoint) Available(config e2e.EtcdProcessClusterConfig, clus e2e.EtcdProcess, profile traffic.Profile) bool {
return config.ClusterSize > 1 && clus.PeerProxy() != nil
return config.ClusterSize > 1 && clus.PeerForwardProxy() != nil
}
type dropPeerNetworkFailpoint struct {
@ -191,15 +185,15 @@ type dropPeerNetworkFailpoint struct {
func (f dropPeerNetworkFailpoint) Inject(ctx context.Context, t *testing.T, lg *zap.Logger, clus *e2e.EtcdProcessCluster, baseTime time.Time, ids identity.Provider) ([]report.ClientReport, error) {
member := clus.Procs[rand.Int()%len(clus.Procs)]
proxy := member.PeerProxy()
forwardProxy := member.PeerForwardProxy()
proxy.ModifyRx(f.modifyPacket)
proxy.ModifyTx(f.modifyPacket)
forwardProxy.ModifyRx(f.modifyPacket)
forwardProxy.ModifyTx(f.modifyPacket)
lg.Info("Dropping traffic from and to member", zap.String("member", member.Config().Name), zap.Int("probability", f.dropProbabilityPercent))
time.Sleep(f.duration)
lg.Info("Traffic drop removed", zap.String("member", member.Config().Name))
proxy.UnmodifyRx()
proxy.UnmodifyTx()
forwardProxy.UnmodifyRx()
forwardProxy.UnmodifyTx()
return nil, nil
}
@ -215,5 +209,5 @@ func (f dropPeerNetworkFailpoint) Name() string {
}
func (f dropPeerNetworkFailpoint) Available(config e2e.EtcdProcessClusterConfig, clus e2e.EtcdProcess, profile traffic.Profile) bool {
return config.ClusterSize > 1 && clus.PeerProxy() != nil
return config.ClusterSize > 1 && clus.PeerForwardProxy() != nil
}