From 49078c683bb52e0440c6fdfc9671022f8cc6ef39 Mon Sep 17 00:00:00 2001 From: Sam Batschelet Date: Fri, 19 Feb 2021 08:19:57 -0500 Subject: [PATCH 1/2] *: add support for socket options Signed-off-by: Sam Batschelet --- pkg/transport/listener.go | 34 +++++++++++-- pkg/transport/listener_test.go | 70 ++++++++++++++++++++++++++ pkg/transport/sockopt.go | 45 +++++++++++++++++ pkg/transport/sockopt_unix.go | 20 ++++++++ pkg/transport/sockopt_windows.go | 18 +++++++ pkg/transport/timeout_listener.go | 25 ++++++--- server/embed/config.go | 5 ++ server/embed/etcd.go | 14 ++++-- server/etcdmain/config.go | 2 + server/etcdmain/help.go | 4 ++ server/etcdserver/api/rafthttp/util.go | 4 ++ server/etcdserver/config.go | 3 ++ 12 files changed, 232 insertions(+), 12 deletions(-) create mode 100644 pkg/transport/sockopt.go create mode 100644 pkg/transport/sockopt_unix.go create mode 100644 pkg/transport/sockopt_windows.go diff --git a/pkg/transport/listener.go b/pkg/transport/listener.go index df9a895bb..31ba4876f 100644 --- a/pkg/transport/listener.go +++ b/pkg/transport/listener.go @@ -15,6 +15,7 @@ package transport import ( + "context" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" @@ -39,18 +40,34 @@ import ( // NewListener creates a new listner. func NewListener(addr, scheme string, tlsinfo *TLSInfo) (l net.Listener, err error) { - if l, err = newListener(addr, scheme); err != nil { + if l, err = newListener(addr, scheme, nil); err != nil { return nil, err } return wrapTLS(scheme, tlsinfo, l) } -func newListener(addr string, scheme string) (net.Listener, error) { +// NewListenerWithSocketOpts creates new listener with support for socket options. +func NewListenerWithSocketOpts(addr, scheme string, tlsinfo *TLSInfo, sopts *SocketOpts) (net.Listener, error) { + ln, err := newListener(addr, scheme, sopts) + if err != nil { + return nil, err + } + if tlsinfo != nil { + wrapTLS(scheme, tlsinfo, ln) + } + return ln, nil +} + +func newListener(addr string, scheme string, sopts *SocketOpts) (net.Listener, error) { if scheme == "unix" || scheme == "unixs" { // unix sockets via unix://laddr return NewUnixListener(addr) } - return net.Listen("tcp", addr) + config, err := newListenConfig(sopts) + if err != nil { + return nil, err + } + return config.Listen(context.TODO(), "tcp", addr) } func wrapTLS(scheme string, tlsinfo *TLSInfo, l net.Listener) (net.Listener, error) { @@ -63,6 +80,17 @@ func wrapTLS(scheme string, tlsinfo *TLSInfo, l net.Listener) (net.Listener, err return newTLSListener(l, tlsinfo, checkSAN) } +func newListenConfig(sopts *SocketOpts) (net.ListenConfig, error) { + lc := net.ListenConfig{} + if sopts != nil { + ctls := getControls(sopts) + if len(ctls) > 0 { + lc.Control = ctls.Control + } + } + return lc, nil +} + type TLSInfo struct { CertFile string KeyFile string diff --git a/pkg/transport/listener_test.go b/pkg/transport/listener_test.go index a34d97055..b79f34a31 100644 --- a/pkg/transport/listener_test.go +++ b/pkg/transport/listener_test.go @@ -61,6 +61,58 @@ func TestNewListenerTLSInfo(t *testing.T) { testNewListenerTLSInfoAccept(t, *tlsInfo) } +func TestNewListenerWithSocketOpts(t *testing.T) { + tlsInfo, del, err := createSelfCert() + if err != nil { + t.Fatalf("unable to create cert: %v", err) + } + defer del() + tests := map[string]struct { + socketOpts *SocketOpts + expectedErr bool + }{ + "nil": { + socketOpts: nil, + expectedErr: true, + }, + "empty": { + socketOpts: &SocketOpts{}, + expectedErr: true, + }, + "reuse address": { + socketOpts: &SocketOpts{ReuseAddress: true}, + expectedErr: true, + }, + "reuse address and reuse port": { + socketOpts: &SocketOpts{ReuseAddress: true, ReusePort: true}, + expectedErr: false, + }, + "reuse port": { + socketOpts: &SocketOpts{ReusePort: true}, + expectedErr: false, + }, + } + for testName, test := range tests { + t.Run(testName, func(t *testing.T) { + ln, err := NewListenerWithSocketOpts("127.0.0.1:0", "https", tlsInfo, test.socketOpts) + if err != nil { + t.Fatalf("unexpected NewListenerWithSocketOpts error: %v", err) + } + defer ln.Close() + ln2, err := NewListenerWithSocketOpts(ln.Addr().String(), "https", tlsInfo, test.socketOpts) + if test.expectedErr && err == nil { + t.Fatalf("expected error") + } + if !test.expectedErr && err != nil { + t.Fatalf("unexpected NewListenerWithSocketOpts error: %v", err) + } + if ln2 != nil { + ln2.Close() + } + }) + } +} + func testNewListenerTLSInfoAccept(t *testing.T, tlsInfo TLSInfo) { ln, err := NewListener("127.0.0.1:0", "https", &tlsInfo) if err != nil { @@ -401,3 +453,21 @@ func TestIsClosedConnError(t *testing.T) { t.Fatalf("expect true, got false (%v)", err) } } + +func TestSocktOptsEmpty(t *testing.T) { + tests := []struct { + sopts SocketOpts + want bool + }{ + {SocketOpts{}, true}, + {SocketOpts{ReuseAddress: true, ReusePort: false}, false}, + {SocketOpts{ReusePort: true}, false}, + } + + for i, tt := range tests { + got := tt.sopts.Empty() + if tt.want != got { + t.Errorf("#%d: result of Empty() incorrect: want=%t got=%t", i, tt.want, got) + } + } +} diff --git a/pkg/transport/sockopt.go b/pkg/transport/sockopt.go new file mode 100644 index 000000000..9941048a8 --- /dev/null +++ b/pkg/transport/sockopt.go @@ -0,0 +1,45 @@ +package transport + +import ( + "syscall" +) + +type Controls []func(network, addr string, conn syscall.RawConn) error + +func (ctls Controls) Control(network, addr string, conn syscall.RawConn) error { + for _, s := range ctls { + if err := s(network, addr, conn); err != nil { + return err + } + } + return nil +} + +type SocketOpts struct { + // ReusePort enables socket option SO_REUSEPORT [1] which allows rebind of + // a port already in use. User should keep in mind that flock can fail + // in which case lock on data file could result in unexpected + // condition. User should take caution to protect against lock race. + // [1] https://man7.org/linux/man-pages/man7/socket.7.html + ReusePort bool + // ReuseAddress enables a socket option SO_REUSEADDR which allows + // binding to an address in `TIME_WAIT` state. Useful to improve MTTR + // in cases where etcd slow to restart due to excessive `TIME_WAIT`. + // [1] https://man7.org/linux/man-pages/man7/socket.7.html + ReuseAddress bool +} + +func getControls(sopts *SocketOpts) Controls { + ctls := Controls{} + if sopts.ReuseAddress { + ctls = append(ctls, setReuseAddress) + } + if sopts.ReusePort { + ctls = append(ctls, setReusePort) + } + return ctls +} + +func (sopts *SocketOpts) Empty() bool { + return sopts.ReuseAddress == false && sopts.ReusePort == false +} diff --git a/pkg/transport/sockopt_unix.go b/pkg/transport/sockopt_unix.go new file mode 100644 index 000000000..a3322fded --- /dev/null +++ b/pkg/transport/sockopt_unix.go @@ -0,0 +1,20 @@ +// +build !windows + +package transport + +import ( + "golang.org/x/sys/unix" + "syscall" +) + +func setReusePort(network, address string, conn syscall.RawConn) error { + return conn.Control(func(fd uintptr) { + syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1) + }) +} + +func setReuseAddress(network, address string, conn syscall.RawConn) error { + return conn.Control(func(fd uintptr) { + syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEADDR, 1) + }) +} diff --git a/pkg/transport/sockopt_windows.go b/pkg/transport/sockopt_windows.go new file mode 100644 index 000000000..000077991 --- /dev/null +++ b/pkg/transport/sockopt_windows.go @@ -0,0 +1,18 @@ +// +build windows + +package transport + +import ( + "fmt" + "syscall" +) + +func setReusePort(network, address string, c syscall.RawConn) error { + return fmt.Errorf("port reuse is not supported on Windows") +} + +// Windows supports SO_REUSEADDR, but it may cause undefined behavior, as +// there is no protection against port hijacking. +func setReuseAddress(network, addr string, conn syscall.RawConn) error { + return fmt.Errorf("address reuse is not supported on Windows") +} diff --git a/pkg/transport/timeout_listener.go b/pkg/transport/timeout_listener.go index 273e99fe0..29a62d997 100644 --- a/pkg/transport/timeout_listener.go +++ b/pkg/transport/timeout_listener.go @@ -23,19 +23,32 @@ import ( // If read/write on the accepted connection blocks longer than its time limit, // it will return timeout error. func NewTimeoutListener(addr string, scheme string, tlsinfo *TLSInfo, rdtimeoutd, wtimeoutd time.Duration) (net.Listener, error) { - ln, err := newListener(addr, scheme) + ln, err := newListener(addr, scheme, nil) if err != nil { return nil, err } - ln = &rwTimeoutListener{ + return newTimeoutListener(ln, scheme, rdtimeoutd, wtimeoutd, tlsinfo) +} + +// NewTimeoutListerWithSocketOpts returns a listener that listens on the given address. +// If read/write on the accepted connection blocks longer than its time limit, +// it will return timeout error. Socket options can be passed and will be applied to the +// ListenConfig. +func NewTimeoutListerWithSocketOpts(addr string, scheme string, tlsinfo *TLSInfo, rdtimeoutd, wtimeoutd time.Duration, sopts *SocketOpts) (net.Listener, error) { + ln, err := newListener(addr, scheme, sopts) + if err != nil { + return nil, err + } + return newTimeoutListener(ln, scheme, rdtimeoutd, wtimeoutd, tlsinfo) +} + +func newTimeoutListener(ln net.Listener, scheme string, rdtimeoutd, wtimeoutd time.Duration, tlsinfo *TLSInfo) (net.Listener, error) { + timeoutListener := &rwTimeoutListener{ Listener: ln, rdtimeoutd: rdtimeoutd, wtimeoutd: wtimeoutd, } - if ln, err = wrapTLS(scheme, tlsinfo, ln); err != nil { - return nil, err - } - return ln, nil + return wrapTLS(scheme, tlsinfo, timeoutListener) } type rwTimeoutListener struct { diff --git a/server/embed/config.go b/server/embed/config.go index e91518cbc..5558254c1 100644 --- a/server/embed/config.go +++ b/server/embed/config.go @@ -232,6 +232,9 @@ type Config struct { // before closing a non-responsive connection. 0 to disable. GRPCKeepAliveTimeout time.Duration `json:"grpc-keepalive-timeout"` + // SocketOpts are socket options passed to listener config. + SocketOpts transport.SocketOpts + // PreVote is true to enable Raft Pre-Vote. // If enabled, Raft runs an additional election phase // to check whether it would get enough votes to win @@ -398,6 +401,8 @@ func NewConfig() *Config { GRPCKeepAliveInterval: DefaultGRPCKeepAliveInterval, GRPCKeepAliveTimeout: DefaultGRPCKeepAliveTimeout, + SocketOpts: transport.SocketOpts{}, + TickMs: 100, ElectionMs: 1000, InitialElectionTickAdvance: true, diff --git a/server/embed/etcd.go b/server/embed/etcd.go index c5b99a9ff..3d75bf1b2 100644 --- a/server/embed/etcd.go +++ b/server/embed/etcd.go @@ -110,6 +110,13 @@ func StartEtcd(inCfg *Config) (e *Etcd, err error) { e = nil }() + if !cfg.SocketOpts.Empty() { + cfg.logger.Info( + "configuring socket options", + zap.Bool("reuse-address", cfg.SocketOpts.ReuseAddress), + zap.Bool("reuse-port", cfg.SocketOpts.ReusePort), + ) + } e.cfg.logger.Info( "configuring peer listeners", zap.Strings("listen-peer-urls", e.cfg.getLPURLs()), @@ -181,6 +188,7 @@ func StartEtcd(inCfg *Config) (e *Etcd, err error) { BackendBatchInterval: cfg.BackendBatchInterval, MaxTxnOps: cfg.MaxTxnOps, MaxRequestBytes: cfg.MaxRequestBytes, + SocketOpts: cfg.SocketOpts, StrictReconfigCheck: cfg.StrictReconfigCheck, ClientCertAuthEnabled: cfg.ClientTLSInfo.ClientCertAuth, AuthToken: cfg.AuthToken, @@ -458,7 +466,7 @@ func configurePeerListeners(cfg *Config) (peers []*peerListener, err error) { } } peers[i] = &peerListener{close: func(context.Context) error { return nil }} - peers[i].Listener, err = rafthttp.NewListener(u, &cfg.PeerTLSInfo) + peers[i].Listener, err = rafthttp.NewListenerWithSocketOpts(u, &cfg.PeerTLSInfo, &cfg.SocketOpts) if err != nil { return nil, err } @@ -565,7 +573,7 @@ func configureClientListeners(cfg *Config) (sctxs map[string]*serveCtx, err erro continue } - if sctx.l, err = net.Listen(network, addr); err != nil { + if sctx.l, err = transport.NewListenerWithSocketOpts(addr, u.Scheme, nil, &cfg.SocketOpts); err != nil { return nil, err } // net.Listener will rewrite ipv4 0.0.0.0 to ipv6 [::], breaking @@ -678,7 +686,7 @@ func (e *Etcd) serveMetrics() (err error) { if murl.Scheme == "http" { tlsInfo = nil } - ml, err := transport.NewListener(murl.Host, murl.Scheme, tlsInfo) + ml, err := transport.NewListenerWithSocketOpts(murl.Host, murl.Scheme, tlsInfo, &e.cfg.SocketOpts) if err != nil { return err } diff --git a/server/etcdmain/config.go b/server/etcdmain/config.go index ef7dcc283..ce62ae69e 100644 --- a/server/etcdmain/config.go +++ b/server/etcdmain/config.go @@ -160,6 +160,8 @@ func newConfig() *config { fs.DurationVar(&cfg.ec.GRPCKeepAliveMinTime, "grpc-keepalive-min-time", cfg.ec.GRPCKeepAliveMinTime, "Minimum interval duration that a client should wait before pinging server.") fs.DurationVar(&cfg.ec.GRPCKeepAliveInterval, "grpc-keepalive-interval", cfg.ec.GRPCKeepAliveInterval, "Frequency duration of server-to-client ping to check if a connection is alive (0 to disable).") fs.DurationVar(&cfg.ec.GRPCKeepAliveTimeout, "grpc-keepalive-timeout", cfg.ec.GRPCKeepAliveTimeout, "Additional duration of wait before closing a non-responsive connection (0 to disable).") + fs.BoolVar(&cfg.ec.SocketOpts.ReusePort, "socket-reuse-port", cfg.ec.SocketOpts.ReusePort, "Enable to set socket option SO_REUSEPORT on listeners allowing rebinding of a port already in use.") + fs.BoolVar(&cfg.ec.SocketOpts.ReuseAddress, "socket-reuse-address", cfg.ec.SocketOpts.ReuseAddress, "Enable to set socket option SO_REUSEADDR on listeners allowing binding to an address in `TIME_WAIT` state.") // clustering fs.Var( diff --git a/server/etcdmain/help.go b/server/etcdmain/help.go index 0834881ab..1579f5558 100644 --- a/server/etcdmain/help.go +++ b/server/etcdmain/help.go @@ -85,6 +85,10 @@ Member: Frequency duration of server-to-client ping to check if a connection is alive (0 to disable). --grpc-keepalive-timeout '20s' Additional duration of wait before closing a non-responsive connection (0 to disable). + --socket-reuse-port 'false' + Enable to set socket option SO_REUSEPORT on listeners allowing rebinding of a port already in use. + --socket-reuse-address 'false' + Enable to set socket option SO_REUSEADDR on listeners allowing binding to an address in TIME_WAIT state. Clustering: --initial-advertise-peer-urls 'http://localhost:2380' diff --git a/server/etcdserver/api/rafthttp/util.go b/server/etcdserver/api/rafthttp/util.go index 37bdac8e6..905974451 100644 --- a/server/etcdserver/api/rafthttp/util.go +++ b/server/etcdserver/api/rafthttp/util.go @@ -42,6 +42,10 @@ func NewListener(u url.URL, tlsinfo *transport.TLSInfo) (net.Listener, error) { return transport.NewTimeoutListener(u.Host, u.Scheme, tlsinfo, ConnReadTimeout, ConnWriteTimeout) } +func NewListenerWithSocketOpts(u url.URL, tlsinfo *transport.TLSInfo, sopts *transport.SocketOpts) (net.Listener, error) { + return transport.NewTimeoutListerWithSocketOpts(u.Host, u.Scheme, tlsinfo, ConnReadTimeout, ConnWriteTimeout, sopts) +} + // NewRoundTripper returns a roundTripper used to send requests // to rafthttp listener of remote peers. func NewRoundTripper(tlsInfo transport.TLSInfo, dialTimeout time.Duration) (http.RoundTripper, error) { diff --git a/server/etcdserver/config.go b/server/etcdserver/config.go index 49fe04005..4a6716509 100644 --- a/server/etcdserver/config.go +++ b/server/etcdserver/config.go @@ -138,6 +138,9 @@ type ServerConfig struct { // PreVote is true to enable Raft Pre-Vote. PreVote bool + // SocketOpts are socket options passed to listener config. + SocketOpts transport.SocketOpts + // Logger logs server-side operations. // If not nil, it disables "capnslog" and uses the given logger. Logger *zap.Logger From 5b49fb41c8c29a9f041a686240ee0f165e243295 Mon Sep 17 00:00:00 2001 From: Sam Batschelet Date: Mon, 8 Mar 2021 09:01:58 -0500 Subject: [PATCH 2/2] fixup: add ListenerOptions Signed-off-by: Sam Batschelet --- pkg/transport/listener.go | 70 +++++++++---- pkg/transport/listener_opts.go | 76 ++++++++++++++ pkg/transport/listener_test.go | 133 +++++++++++++++++++++--- pkg/transport/timeout_conn.go | 12 +-- pkg/transport/timeout_dialer.go | 6 +- pkg/transport/timeout_listener.go | 39 ++----- pkg/transport/timeout_listener_test.go | 18 ++-- pkg/transport/timeout_transport_test.go | 8 +- server/embed/etcd.go | 16 ++- server/etcdserver/api/rafthttp/util.go | 6 +- 10 files changed, 288 insertions(+), 96 deletions(-) create mode 100644 pkg/transport/listener_opts.go diff --git a/pkg/transport/listener.go b/pkg/transport/listener.go index 31ba4876f..c2bfdd869 100644 --- a/pkg/transport/listener.go +++ b/pkg/transport/listener.go @@ -40,34 +40,66 @@ import ( // NewListener creates a new listner. func NewListener(addr, scheme string, tlsinfo *TLSInfo) (l net.Listener, err error) { - if l, err = newListener(addr, scheme, nil); err != nil { - return nil, err - } - return wrapTLS(scheme, tlsinfo, l) + return newListener(addr, scheme, WithTLSInfo(tlsinfo)) } -// NewListenerWithSocketOpts creates new listener with support for socket options. -func NewListenerWithSocketOpts(addr, scheme string, tlsinfo *TLSInfo, sopts *SocketOpts) (net.Listener, error) { - ln, err := newListener(addr, scheme, sopts) - if err != nil { - return nil, err - } - if tlsinfo != nil { - wrapTLS(scheme, tlsinfo, ln) - } - return ln, nil +// NewListenerWithOpts creates a new listener which accpets listener options. +func NewListenerWithOpts(addr, scheme string, opts ...ListenerOption) (net.Listener, error) { + return newListener(addr, scheme, opts...) } -func newListener(addr string, scheme string, sopts *SocketOpts) (net.Listener, error) { +func newListener(addr, scheme string, opts ...ListenerOption) (net.Listener, error) { if scheme == "unix" || scheme == "unixs" { // unix sockets via unix://laddr return NewUnixListener(addr) } - config, err := newListenConfig(sopts) - if err != nil { - return nil, err + + lnOpts := newListenOpts(opts...) + + switch { + case lnOpts.IsSocketOpts(): + // new ListenConfig with socket options. + config, err := newListenConfig(lnOpts.socketOpts) + if err != nil { + return nil, err + } + lnOpts.ListenConfig = config + // check for timeout + fallthrough + case lnOpts.IsTimeout(), lnOpts.IsSocketOpts(): + // timeout listener with socket options. + ln, err := lnOpts.ListenConfig.Listen(context.TODO(), "tcp", addr) + if err != nil { + return nil, err + } + lnOpts.Listener = &rwTimeoutListener{ + Listener: ln, + readTimeout: lnOpts.readTimeout, + writeTimeout: lnOpts.writeTimeout, + } + case lnOpts.IsTimeout(): + ln, err := net.Listen("tcp", addr) + if err != nil { + return nil, err + } + lnOpts.Listener = &rwTimeoutListener{ + Listener: ln, + readTimeout: lnOpts.readTimeout, + writeTimeout: lnOpts.writeTimeout, + } + default: + ln, err := net.Listen("tcp", addr) + if err != nil { + return nil, err + } + lnOpts.Listener = ln } - return config.Listen(context.TODO(), "tcp", addr) + + // only skip if not passing TLSInfo + if lnOpts.skipTLSInfoCheck && !lnOpts.IsTLS() { + return lnOpts.Listener, nil + } + return wrapTLS(scheme, lnOpts.tlsInfo, lnOpts.Listener) } func wrapTLS(scheme string, tlsinfo *TLSInfo, l net.Listener) (net.Listener, error) { diff --git a/pkg/transport/listener_opts.go b/pkg/transport/listener_opts.go new file mode 100644 index 000000000..d6c6830ce --- /dev/null +++ b/pkg/transport/listener_opts.go @@ -0,0 +1,76 @@ +package transport + +import ( + "net" + "time" +) + +type ListenerOptions struct { + Listener net.Listener + ListenConfig net.ListenConfig + + socketOpts *SocketOpts + tlsInfo *TLSInfo + skipTLSInfoCheck bool + writeTimeout time.Duration + readTimeout time.Duration +} + +func newListenOpts(opts ...ListenerOption) *ListenerOptions { + lnOpts := &ListenerOptions{} + lnOpts.applyOpts(opts) + return lnOpts +} + +func (lo *ListenerOptions) applyOpts(opts []ListenerOption) { + for _, opt := range opts { + opt(lo) + } +} + +// IsTimeout returns true if the listener has a read/write timeout defined. +func (lo *ListenerOptions) IsTimeout() bool { return lo.readTimeout != 0 || lo.writeTimeout != 0 } + +// IsSocketOpts returns true if the listener options includes socket options. +func (lo *ListenerOptions) IsSocketOpts() bool { + if lo.socketOpts == nil { + return false + } + return lo.socketOpts.ReusePort == true || lo.socketOpts.ReuseAddress == true +} + +// IsTLS returns true if listner options includes TLSInfo. +func (lo *ListenerOptions) IsTLS() bool { + if lo.tlsInfo == nil { + return false + } + return lo.tlsInfo.Empty() == false +} + +// ListenerOption are options which can be applied to the listener. +type ListenerOption func(*ListenerOptions) + +// WithTimeout allows for a read or write timeout to be applied to the listener. +func WithTimeout(read, write time.Duration) ListenerOption { + return func(lo *ListenerOptions) { + lo.writeTimeout = write + lo.readTimeout = read + } +} + +// WithSocketOpts defines socket options that will be applied to the listener. +func WithSocketOpts(s *SocketOpts) ListenerOption { + return func(lo *ListenerOptions) { lo.socketOpts = s } +} + +// WithTLSInfo adds TLS credentials to the listener. +func WithTLSInfo(t *TLSInfo) ListenerOption { + return func(lo *ListenerOptions) { lo.tlsInfo = t } +} + +// WithSkipTLSInfoCheck when true a transport can be created with an https scheme +// without passing TLSInfo, circumventing not presented error. Skipping this check +// also requires that TLSInfo is not passed. +func WithSkipTLSInfoCheck(skip bool) ListenerOption { + return func(lo *ListenerOptions) { lo.skipTLSInfoCheck = skip } +} diff --git a/pkg/transport/listener_test.go b/pkg/transport/listener_test.go index b79f34a31..0a7b0ad16 100644 --- a/pkg/transport/listener_test.go +++ b/pkg/transport/listener_test.go @@ -61,53 +61,156 @@ func TestNewListenerTLSInfo(t *testing.T) { testNewListenerTLSInfoAccept(t, *tlsInfo) } +func TestNewListenerWithOpts(t *testing.T) { + tlsInfo, del, err := createSelfCert() + if err != nil { + t.Fatalf("unable to create cert: %v", err) + } + defer del() + + tests := map[string]struct { + opts []ListenerOption + scheme string + expectedErr bool + }{ + "https scheme no TLSInfo": { + opts: []ListenerOption{}, + expectedErr: true, + scheme: "https", + }, + "https scheme no TLSInfo with skip check": { + opts: []ListenerOption{WithSkipTLSInfoCheck(true)}, + expectedErr: false, + scheme: "https", + }, + "https scheme empty TLSInfo with skip check": { + opts: []ListenerOption{ + WithSkipTLSInfoCheck(true), + WithTLSInfo(&TLSInfo{}), + }, + expectedErr: false, + scheme: "https", + }, + "https scheme empty TLSInfo no skip check": { + opts: []ListenerOption{ + WithTLSInfo(&TLSInfo{}), + }, + expectedErr: true, + scheme: "https", + }, + "https scheme with TLSInfo and skip check": { + opts: []ListenerOption{ + WithSkipTLSInfoCheck(true), + WithTLSInfo(tlsInfo), + }, + expectedErr: false, + scheme: "https", + }, + } + for testName, test := range tests { + t.Run(testName, func(t *testing.T) { + ln, err := NewListenerWithOpts("127.0.0.1:0", test.scheme, test.opts...) + if ln != nil { + defer ln.Close() + } + if test.expectedErr && err == nil { + t.Fatalf("expected error") + } + if !test.expectedErr && err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + func TestNewListenerWithSocketOpts(t *testing.T) { tlsInfo, del, err := createSelfCert() if err != nil { t.Fatalf("unable to create cert: %v", err) } defer del() + tests := map[string]struct { - socketOpts *SocketOpts + opts []ListenerOption + scheme string expectedErr bool }{ - "nil": { - socketOpts: nil, + "nil socketopts": { + opts: []ListenerOption{WithSocketOpts(nil)}, expectedErr: true, + scheme: "http", }, - "empty": { - socketOpts: &SocketOpts{}, + "empty socketopts": { + opts: []ListenerOption{WithSocketOpts(&SocketOpts{})}, expectedErr: true, + scheme: "http", }, + "reuse address": { - socketOpts: &SocketOpts{ReuseAddress: true}, + opts: []ListenerOption{WithSocketOpts(&SocketOpts{ReuseAddress: true})}, + scheme: "http", expectedErr: true, }, - "reuse address and reuse port": { - socketOpts: &SocketOpts{ReuseAddress: true, ReusePort: true}, + "reuse address with TLS": { + opts: []ListenerOption{ + WithSocketOpts(&SocketOpts{ReuseAddress: true}), + WithTLSInfo(tlsInfo), + }, + scheme: "https", + expectedErr: true, + }, + "reuse address and port": { + opts: []ListenerOption{WithSocketOpts(&SocketOpts{ReuseAddress: true, ReusePort: true})}, + scheme: "http", + expectedErr: false, + }, + "reuse address and port with TLS": { + opts: []ListenerOption{ + WithSocketOpts(&SocketOpts{ReuseAddress: true, ReusePort: true}), + WithTLSInfo(tlsInfo), + }, + scheme: "https", + expectedErr: false, + }, + "reuse port with TLS and timeout": { + opts: []ListenerOption{ + WithSocketOpts(&SocketOpts{ReusePort: true}), + WithTLSInfo(tlsInfo), + WithTimeout(5*time.Second, 5*time.Second), + }, + scheme: "https", + expectedErr: false, + }, + "reuse port with https scheme and no TLSInfo skip check": { + opts: []ListenerOption{ + WithSocketOpts(&SocketOpts{ReusePort: true}), + WithSkipTLSInfoCheck(true), + }, + scheme: "https", expectedErr: false, }, "reuse port": { - socketOpts: &SocketOpts{ReusePort: true}, + opts: []ListenerOption{WithSocketOpts(&SocketOpts{ReusePort: true})}, + scheme: "http", expectedErr: false, }, } for testName, test := range tests { t.Run(testName, func(t *testing.T) { - ln, err := NewListenerWithSocketOpts("127.0.0.1:0", "https", tlsInfo, test.socketOpts) + ln, err := NewListenerWithOpts("127.0.0.1:0", test.scheme, test.opts...) if err != nil { t.Fatalf("unexpected NewListenerWithSocketOpts error: %v", err) } defer ln.Close() - ln2, err := NewListenerWithSocketOpts(ln.Addr().String(), "https", tlsInfo, test.socketOpts) + ln2, err := NewListenerWithOpts(ln.Addr().String(), test.scheme, test.opts...) + if ln2 != nil { + ln2.Close() + } if test.expectedErr && err == nil { t.Fatalf("expected error") } if !test.expectedErr && err != nil { - t.Fatalf("unexpected NewListenerWithSocketOpts error: %v", err) - } - if ln2 != nil { - ln2.Close() + t.Fatalf("unexpected error: %v", err) } }) } diff --git a/pkg/transport/timeout_conn.go b/pkg/transport/timeout_conn.go index 7e8c02030..80e329394 100644 --- a/pkg/transport/timeout_conn.go +++ b/pkg/transport/timeout_conn.go @@ -21,13 +21,13 @@ import ( type timeoutConn struct { net.Conn - wtimeoutd time.Duration - rdtimeoutd time.Duration + writeTimeout time.Duration + readTimeout time.Duration } func (c timeoutConn) Write(b []byte) (n int, err error) { - if c.wtimeoutd > 0 { - if err := c.SetWriteDeadline(time.Now().Add(c.wtimeoutd)); err != nil { + if c.writeTimeout > 0 { + if err := c.SetWriteDeadline(time.Now().Add(c.writeTimeout)); err != nil { return 0, err } } @@ -35,8 +35,8 @@ func (c timeoutConn) Write(b []byte) (n int, err error) { } func (c timeoutConn) Read(b []byte) (n int, err error) { - if c.rdtimeoutd > 0 { - if err := c.SetReadDeadline(time.Now().Add(c.rdtimeoutd)); err != nil { + if c.readTimeout > 0 { + if err := c.SetReadDeadline(time.Now().Add(c.readTimeout)); err != nil { return 0, err } } diff --git a/pkg/transport/timeout_dialer.go b/pkg/transport/timeout_dialer.go index 6ae39ecfc..9c0245d31 100644 --- a/pkg/transport/timeout_dialer.go +++ b/pkg/transport/timeout_dialer.go @@ -28,9 +28,9 @@ type rwTimeoutDialer struct { func (d *rwTimeoutDialer) Dial(network, address string) (net.Conn, error) { conn, err := d.Dialer.Dial(network, address) tconn := &timeoutConn{ - rdtimeoutd: d.rdtimeoutd, - wtimeoutd: d.wtimeoutd, - Conn: conn, + readTimeout: d.rdtimeoutd, + writeTimeout: d.wtimeoutd, + Conn: conn, } return tconn, err } diff --git a/pkg/transport/timeout_listener.go b/pkg/transport/timeout_listener.go index 29a62d997..5d74bd70c 100644 --- a/pkg/transport/timeout_listener.go +++ b/pkg/transport/timeout_listener.go @@ -22,39 +22,14 @@ import ( // NewTimeoutListener returns a listener that listens on the given address. // If read/write on the accepted connection blocks longer than its time limit, // it will return timeout error. -func NewTimeoutListener(addr string, scheme string, tlsinfo *TLSInfo, rdtimeoutd, wtimeoutd time.Duration) (net.Listener, error) { - ln, err := newListener(addr, scheme, nil) - if err != nil { - return nil, err - } - return newTimeoutListener(ln, scheme, rdtimeoutd, wtimeoutd, tlsinfo) -} - -// NewTimeoutListerWithSocketOpts returns a listener that listens on the given address. -// If read/write on the accepted connection blocks longer than its time limit, -// it will return timeout error. Socket options can be passed and will be applied to the -// ListenConfig. -func NewTimeoutListerWithSocketOpts(addr string, scheme string, tlsinfo *TLSInfo, rdtimeoutd, wtimeoutd time.Duration, sopts *SocketOpts) (net.Listener, error) { - ln, err := newListener(addr, scheme, sopts) - if err != nil { - return nil, err - } - return newTimeoutListener(ln, scheme, rdtimeoutd, wtimeoutd, tlsinfo) -} - -func newTimeoutListener(ln net.Listener, scheme string, rdtimeoutd, wtimeoutd time.Duration, tlsinfo *TLSInfo) (net.Listener, error) { - timeoutListener := &rwTimeoutListener{ - Listener: ln, - rdtimeoutd: rdtimeoutd, - wtimeoutd: wtimeoutd, - } - return wrapTLS(scheme, tlsinfo, timeoutListener) +func NewTimeoutListener(addr string, scheme string, tlsinfo *TLSInfo, readTimeout, writeTimeout time.Duration) (net.Listener, error) { + return newListener(addr, scheme, WithTimeout(readTimeout, writeTimeout), WithTLSInfo(tlsinfo)) } type rwTimeoutListener struct { net.Listener - wtimeoutd time.Duration - rdtimeoutd time.Duration + writeTimeout time.Duration + readTimeout time.Duration } func (rwln *rwTimeoutListener) Accept() (net.Conn, error) { @@ -63,8 +38,8 @@ func (rwln *rwTimeoutListener) Accept() (net.Conn, error) { return nil, err } return timeoutConn{ - Conn: c, - wtimeoutd: rwln.wtimeoutd, - rdtimeoutd: rwln.rdtimeoutd, + Conn: c, + writeTimeout: rwln.writeTimeout, + readTimeout: rwln.readTimeout, }, nil } diff --git a/pkg/transport/timeout_listener_test.go b/pkg/transport/timeout_listener_test.go index f2eaad7b7..0c4f20837 100644 --- a/pkg/transport/timeout_listener_test.go +++ b/pkg/transport/timeout_listener_test.go @@ -29,11 +29,11 @@ func TestNewTimeoutListener(t *testing.T) { } defer l.Close() tln := l.(*rwTimeoutListener) - if tln.rdtimeoutd != time.Hour { - t.Errorf("read timeout = %s, want %s", tln.rdtimeoutd, time.Hour) + if tln.readTimeout != time.Hour { + t.Errorf("read timeout = %s, want %s", tln.readTimeout, time.Hour) } - if tln.wtimeoutd != time.Hour { - t.Errorf("write timeout = %s, want %s", tln.wtimeoutd, time.Hour) + if tln.writeTimeout != time.Hour { + t.Errorf("write timeout = %s, want %s", tln.writeTimeout, time.Hour) } } @@ -43,9 +43,9 @@ func TestWriteReadTimeoutListener(t *testing.T) { t.Fatalf("unexpected listen error: %v", err) } wln := rwTimeoutListener{ - Listener: ln, - wtimeoutd: 10 * time.Millisecond, - rdtimeoutd: 10 * time.Millisecond, + Listener: ln, + writeTimeout: 10 * time.Millisecond, + readTimeout: 10 * time.Millisecond, } stop := make(chan struct{}, 1) @@ -78,7 +78,7 @@ func TestWriteReadTimeoutListener(t *testing.T) { select { case <-done: // It waits 1s more to avoid delay in low-end system. - case <-time.After(wln.wtimeoutd*10 + time.Second): + case <-time.After(wln.writeTimeout*10 + time.Second): stop <- struct{}{} t.Fatal("wait timeout") } @@ -104,7 +104,7 @@ func TestWriteReadTimeoutListener(t *testing.T) { select { case <-done: - case <-time.After(wln.rdtimeoutd * 10): + case <-time.After(wln.readTimeout * 10): stop <- struct{}{} t.Fatal("wait timeout") } diff --git a/pkg/transport/timeout_transport_test.go b/pkg/transport/timeout_transport_test.go index f64fd01f3..d2dfe5f6f 100644 --- a/pkg/transport/timeout_transport_test.go +++ b/pkg/transport/timeout_transport_test.go @@ -47,11 +47,11 @@ func TestNewTimeoutTransport(t *testing.T) { if !ok { t.Fatalf("failed to dial out *timeoutConn") } - if tconn.rdtimeoutd != time.Hour { - t.Errorf("read timeout = %s, want %s", tconn.rdtimeoutd, time.Hour) + if tconn.readTimeout != time.Hour { + t.Errorf("read timeout = %s, want %s", tconn.readTimeout, time.Hour) } - if tconn.wtimeoutd != time.Hour { - t.Errorf("write timeout = %s, want %s", tconn.wtimeoutd, time.Hour) + if tconn.writeTimeout != time.Hour { + t.Errorf("write timeout = %s, want %s", tconn.writeTimeout, time.Hour) } // ensure not reuse timeout connection diff --git a/server/embed/etcd.go b/server/embed/etcd.go index 3d75bf1b2..e5ca64b12 100644 --- a/server/embed/etcd.go +++ b/server/embed/etcd.go @@ -466,7 +466,11 @@ func configurePeerListeners(cfg *Config) (peers []*peerListener, err error) { } } peers[i] = &peerListener{close: func(context.Context) error { return nil }} - peers[i].Listener, err = rafthttp.NewListenerWithSocketOpts(u, &cfg.PeerTLSInfo, &cfg.SocketOpts) + peers[i].Listener, err = transport.NewListenerWithOpts(u.Host, u.Scheme, + transport.WithTLSInfo(&cfg.PeerTLSInfo), + transport.WithSocketOpts(&cfg.SocketOpts), + transport.WithTimeout(rafthttp.ConnReadTimeout, rafthttp.ConnWriteTimeout), + ) if err != nil { return nil, err } @@ -573,7 +577,10 @@ func configureClientListeners(cfg *Config) (sctxs map[string]*serveCtx, err erro continue } - if sctx.l, err = transport.NewListenerWithSocketOpts(addr, u.Scheme, nil, &cfg.SocketOpts); err != nil { + if sctx.l, err = transport.NewListenerWithOpts(addr, u.Scheme, + transport.WithSocketOpts(&cfg.SocketOpts), + transport.WithSkipTLSInfoCheck(true), + ); err != nil { return nil, err } // net.Listener will rewrite ipv4 0.0.0.0 to ipv6 [::], breaking @@ -686,7 +693,10 @@ func (e *Etcd) serveMetrics() (err error) { if murl.Scheme == "http" { tlsInfo = nil } - ml, err := transport.NewListenerWithSocketOpts(murl.Host, murl.Scheme, tlsInfo, &e.cfg.SocketOpts) + ml, err := transport.NewListenerWithOpts(murl.Host, murl.Scheme, + transport.WithTLSInfo(tlsInfo), + transport.WithSocketOpts(&e.cfg.SocketOpts), + ) if err != nil { return err } diff --git a/server/etcdserver/api/rafthttp/util.go b/server/etcdserver/api/rafthttp/util.go index 905974451..bfc5cf6f4 100644 --- a/server/etcdserver/api/rafthttp/util.go +++ b/server/etcdserver/api/rafthttp/util.go @@ -39,11 +39,7 @@ var ( // NewListener returns a listener for raft message transfer between peers. // It uses timeout listener to identify broken streams promptly. func NewListener(u url.URL, tlsinfo *transport.TLSInfo) (net.Listener, error) { - return transport.NewTimeoutListener(u.Host, u.Scheme, tlsinfo, ConnReadTimeout, ConnWriteTimeout) -} - -func NewListenerWithSocketOpts(u url.URL, tlsinfo *transport.TLSInfo, sopts *transport.SocketOpts) (net.Listener, error) { - return transport.NewTimeoutListerWithSocketOpts(u.Host, u.Scheme, tlsinfo, ConnReadTimeout, ConnWriteTimeout, sopts) + return transport.NewListenerWithOpts(u.Host, u.Scheme, transport.WithTLSInfo(tlsinfo), transport.WithTimeout(ConnReadTimeout, ConnWriteTimeout)) } // NewRoundTripper returns a roundTripper used to send requests