fixup: add ListenerOptions

Signed-off-by: Sam Batschelet <sbatsche@redhat.com>
This commit is contained in:
Sam Batschelet 2021-03-08 09:01:58 -05:00
parent 49078c683b
commit 5b49fb41c8
10 changed files with 288 additions and 96 deletions

View File

@ -40,34 +40,66 @@ import (
// NewListener creates a new listner.
func NewListener(addr, scheme string, tlsinfo *TLSInfo) (l net.Listener, err error) {
if l, err = newListener(addr, scheme, nil); err != nil {
return nil, err
}
return wrapTLS(scheme, tlsinfo, l)
return newListener(addr, scheme, WithTLSInfo(tlsinfo))
}
// NewListenerWithSocketOpts creates new listener with support for socket options.
func NewListenerWithSocketOpts(addr, scheme string, tlsinfo *TLSInfo, sopts *SocketOpts) (net.Listener, error) {
ln, err := newListener(addr, scheme, sopts)
if err != nil {
return nil, err
}
if tlsinfo != nil {
wrapTLS(scheme, tlsinfo, ln)
}
return ln, nil
// NewListenerWithOpts creates a new listener which accpets listener options.
func NewListenerWithOpts(addr, scheme string, opts ...ListenerOption) (net.Listener, error) {
return newListener(addr, scheme, opts...)
}
func newListener(addr string, scheme string, sopts *SocketOpts) (net.Listener, error) {
func newListener(addr, scheme string, opts ...ListenerOption) (net.Listener, error) {
if scheme == "unix" || scheme == "unixs" {
// unix sockets via unix://laddr
return NewUnixListener(addr)
}
config, err := newListenConfig(sopts)
if err != nil {
return nil, err
lnOpts := newListenOpts(opts...)
switch {
case lnOpts.IsSocketOpts():
// new ListenConfig with socket options.
config, err := newListenConfig(lnOpts.socketOpts)
if err != nil {
return nil, err
}
lnOpts.ListenConfig = config
// check for timeout
fallthrough
case lnOpts.IsTimeout(), lnOpts.IsSocketOpts():
// timeout listener with socket options.
ln, err := lnOpts.ListenConfig.Listen(context.TODO(), "tcp", addr)
if err != nil {
return nil, err
}
lnOpts.Listener = &rwTimeoutListener{
Listener: ln,
readTimeout: lnOpts.readTimeout,
writeTimeout: lnOpts.writeTimeout,
}
case lnOpts.IsTimeout():
ln, err := net.Listen("tcp", addr)
if err != nil {
return nil, err
}
lnOpts.Listener = &rwTimeoutListener{
Listener: ln,
readTimeout: lnOpts.readTimeout,
writeTimeout: lnOpts.writeTimeout,
}
default:
ln, err := net.Listen("tcp", addr)
if err != nil {
return nil, err
}
lnOpts.Listener = ln
}
return config.Listen(context.TODO(), "tcp", addr)
// only skip if not passing TLSInfo
if lnOpts.skipTLSInfoCheck && !lnOpts.IsTLS() {
return lnOpts.Listener, nil
}
return wrapTLS(scheme, lnOpts.tlsInfo, lnOpts.Listener)
}
func wrapTLS(scheme string, tlsinfo *TLSInfo, l net.Listener) (net.Listener, error) {

View File

@ -0,0 +1,76 @@
package transport
import (
"net"
"time"
)
type ListenerOptions struct {
Listener net.Listener
ListenConfig net.ListenConfig
socketOpts *SocketOpts
tlsInfo *TLSInfo
skipTLSInfoCheck bool
writeTimeout time.Duration
readTimeout time.Duration
}
func newListenOpts(opts ...ListenerOption) *ListenerOptions {
lnOpts := &ListenerOptions{}
lnOpts.applyOpts(opts)
return lnOpts
}
func (lo *ListenerOptions) applyOpts(opts []ListenerOption) {
for _, opt := range opts {
opt(lo)
}
}
// IsTimeout returns true if the listener has a read/write timeout defined.
func (lo *ListenerOptions) IsTimeout() bool { return lo.readTimeout != 0 || lo.writeTimeout != 0 }
// IsSocketOpts returns true if the listener options includes socket options.
func (lo *ListenerOptions) IsSocketOpts() bool {
if lo.socketOpts == nil {
return false
}
return lo.socketOpts.ReusePort == true || lo.socketOpts.ReuseAddress == true
}
// IsTLS returns true if listner options includes TLSInfo.
func (lo *ListenerOptions) IsTLS() bool {
if lo.tlsInfo == nil {
return false
}
return lo.tlsInfo.Empty() == false
}
// ListenerOption are options which can be applied to the listener.
type ListenerOption func(*ListenerOptions)
// WithTimeout allows for a read or write timeout to be applied to the listener.
func WithTimeout(read, write time.Duration) ListenerOption {
return func(lo *ListenerOptions) {
lo.writeTimeout = write
lo.readTimeout = read
}
}
// WithSocketOpts defines socket options that will be applied to the listener.
func WithSocketOpts(s *SocketOpts) ListenerOption {
return func(lo *ListenerOptions) { lo.socketOpts = s }
}
// WithTLSInfo adds TLS credentials to the listener.
func WithTLSInfo(t *TLSInfo) ListenerOption {
return func(lo *ListenerOptions) { lo.tlsInfo = t }
}
// WithSkipTLSInfoCheck when true a transport can be created with an https scheme
// without passing TLSInfo, circumventing not presented error. Skipping this check
// also requires that TLSInfo is not passed.
func WithSkipTLSInfoCheck(skip bool) ListenerOption {
return func(lo *ListenerOptions) { lo.skipTLSInfoCheck = skip }
}

View File

@ -61,53 +61,156 @@ func TestNewListenerTLSInfo(t *testing.T) {
testNewListenerTLSInfoAccept(t, *tlsInfo)
}
func TestNewListenerWithOpts(t *testing.T) {
tlsInfo, del, err := createSelfCert()
if err != nil {
t.Fatalf("unable to create cert: %v", err)
}
defer del()
tests := map[string]struct {
opts []ListenerOption
scheme string
expectedErr bool
}{
"https scheme no TLSInfo": {
opts: []ListenerOption{},
expectedErr: true,
scheme: "https",
},
"https scheme no TLSInfo with skip check": {
opts: []ListenerOption{WithSkipTLSInfoCheck(true)},
expectedErr: false,
scheme: "https",
},
"https scheme empty TLSInfo with skip check": {
opts: []ListenerOption{
WithSkipTLSInfoCheck(true),
WithTLSInfo(&TLSInfo{}),
},
expectedErr: false,
scheme: "https",
},
"https scheme empty TLSInfo no skip check": {
opts: []ListenerOption{
WithTLSInfo(&TLSInfo{}),
},
expectedErr: true,
scheme: "https",
},
"https scheme with TLSInfo and skip check": {
opts: []ListenerOption{
WithSkipTLSInfoCheck(true),
WithTLSInfo(tlsInfo),
},
expectedErr: false,
scheme: "https",
},
}
for testName, test := range tests {
t.Run(testName, func(t *testing.T) {
ln, err := NewListenerWithOpts("127.0.0.1:0", test.scheme, test.opts...)
if ln != nil {
defer ln.Close()
}
if test.expectedErr && err == nil {
t.Fatalf("expected error")
}
if !test.expectedErr && err != nil {
t.Fatalf("unexpected error: %v", err)
}
})
}
}
func TestNewListenerWithSocketOpts(t *testing.T) {
tlsInfo, del, err := createSelfCert()
if err != nil {
t.Fatalf("unable to create cert: %v", err)
}
defer del()
tests := map[string]struct {
socketOpts *SocketOpts
opts []ListenerOption
scheme string
expectedErr bool
}{
"nil": {
socketOpts: nil,
"nil socketopts": {
opts: []ListenerOption{WithSocketOpts(nil)},
expectedErr: true,
scheme: "http",
},
"empty": {
socketOpts: &SocketOpts{},
"empty socketopts": {
opts: []ListenerOption{WithSocketOpts(&SocketOpts{})},
expectedErr: true,
scheme: "http",
},
"reuse address": {
socketOpts: &SocketOpts{ReuseAddress: true},
opts: []ListenerOption{WithSocketOpts(&SocketOpts{ReuseAddress: true})},
scheme: "http",
expectedErr: true,
},
"reuse address and reuse port": {
socketOpts: &SocketOpts{ReuseAddress: true, ReusePort: true},
"reuse address with TLS": {
opts: []ListenerOption{
WithSocketOpts(&SocketOpts{ReuseAddress: true}),
WithTLSInfo(tlsInfo),
},
scheme: "https",
expectedErr: true,
},
"reuse address and port": {
opts: []ListenerOption{WithSocketOpts(&SocketOpts{ReuseAddress: true, ReusePort: true})},
scheme: "http",
expectedErr: false,
},
"reuse address and port with TLS": {
opts: []ListenerOption{
WithSocketOpts(&SocketOpts{ReuseAddress: true, ReusePort: true}),
WithTLSInfo(tlsInfo),
},
scheme: "https",
expectedErr: false,
},
"reuse port with TLS and timeout": {
opts: []ListenerOption{
WithSocketOpts(&SocketOpts{ReusePort: true}),
WithTLSInfo(tlsInfo),
WithTimeout(5*time.Second, 5*time.Second),
},
scheme: "https",
expectedErr: false,
},
"reuse port with https scheme and no TLSInfo skip check": {
opts: []ListenerOption{
WithSocketOpts(&SocketOpts{ReusePort: true}),
WithSkipTLSInfoCheck(true),
},
scheme: "https",
expectedErr: false,
},
"reuse port": {
socketOpts: &SocketOpts{ReusePort: true},
opts: []ListenerOption{WithSocketOpts(&SocketOpts{ReusePort: true})},
scheme: "http",
expectedErr: false,
},
}
for testName, test := range tests {
t.Run(testName, func(t *testing.T) {
ln, err := NewListenerWithSocketOpts("127.0.0.1:0", "https", tlsInfo, test.socketOpts)
ln, err := NewListenerWithOpts("127.0.0.1:0", test.scheme, test.opts...)
if err != nil {
t.Fatalf("unexpected NewListenerWithSocketOpts error: %v", err)
}
defer ln.Close()
ln2, err := NewListenerWithSocketOpts(ln.Addr().String(), "https", tlsInfo, test.socketOpts)
ln2, err := NewListenerWithOpts(ln.Addr().String(), test.scheme, test.opts...)
if ln2 != nil {
ln2.Close()
}
if test.expectedErr && err == nil {
t.Fatalf("expected error")
}
if !test.expectedErr && err != nil {
t.Fatalf("unexpected NewListenerWithSocketOpts error: %v", err)
}
if ln2 != nil {
ln2.Close()
t.Fatalf("unexpected error: %v", err)
}
})
}

View File

@ -21,13 +21,13 @@ import (
type timeoutConn struct {
net.Conn
wtimeoutd time.Duration
rdtimeoutd time.Duration
writeTimeout time.Duration
readTimeout time.Duration
}
func (c timeoutConn) Write(b []byte) (n int, err error) {
if c.wtimeoutd > 0 {
if err := c.SetWriteDeadline(time.Now().Add(c.wtimeoutd)); err != nil {
if c.writeTimeout > 0 {
if err := c.SetWriteDeadline(time.Now().Add(c.writeTimeout)); err != nil {
return 0, err
}
}
@ -35,8 +35,8 @@ func (c timeoutConn) Write(b []byte) (n int, err error) {
}
func (c timeoutConn) Read(b []byte) (n int, err error) {
if c.rdtimeoutd > 0 {
if err := c.SetReadDeadline(time.Now().Add(c.rdtimeoutd)); err != nil {
if c.readTimeout > 0 {
if err := c.SetReadDeadline(time.Now().Add(c.readTimeout)); err != nil {
return 0, err
}
}

View File

@ -28,9 +28,9 @@ type rwTimeoutDialer struct {
func (d *rwTimeoutDialer) Dial(network, address string) (net.Conn, error) {
conn, err := d.Dialer.Dial(network, address)
tconn := &timeoutConn{
rdtimeoutd: d.rdtimeoutd,
wtimeoutd: d.wtimeoutd,
Conn: conn,
readTimeout: d.rdtimeoutd,
writeTimeout: d.wtimeoutd,
Conn: conn,
}
return tconn, err
}

View File

@ -22,39 +22,14 @@ import (
// NewTimeoutListener returns a listener that listens on the given address.
// If read/write on the accepted connection blocks longer than its time limit,
// it will return timeout error.
func NewTimeoutListener(addr string, scheme string, tlsinfo *TLSInfo, rdtimeoutd, wtimeoutd time.Duration) (net.Listener, error) {
ln, err := newListener(addr, scheme, nil)
if err != nil {
return nil, err
}
return newTimeoutListener(ln, scheme, rdtimeoutd, wtimeoutd, tlsinfo)
}
// NewTimeoutListerWithSocketOpts returns a listener that listens on the given address.
// If read/write on the accepted connection blocks longer than its time limit,
// it will return timeout error. Socket options can be passed and will be applied to the
// ListenConfig.
func NewTimeoutListerWithSocketOpts(addr string, scheme string, tlsinfo *TLSInfo, rdtimeoutd, wtimeoutd time.Duration, sopts *SocketOpts) (net.Listener, error) {
ln, err := newListener(addr, scheme, sopts)
if err != nil {
return nil, err
}
return newTimeoutListener(ln, scheme, rdtimeoutd, wtimeoutd, tlsinfo)
}
func newTimeoutListener(ln net.Listener, scheme string, rdtimeoutd, wtimeoutd time.Duration, tlsinfo *TLSInfo) (net.Listener, error) {
timeoutListener := &rwTimeoutListener{
Listener: ln,
rdtimeoutd: rdtimeoutd,
wtimeoutd: wtimeoutd,
}
return wrapTLS(scheme, tlsinfo, timeoutListener)
func NewTimeoutListener(addr string, scheme string, tlsinfo *TLSInfo, readTimeout, writeTimeout time.Duration) (net.Listener, error) {
return newListener(addr, scheme, WithTimeout(readTimeout, writeTimeout), WithTLSInfo(tlsinfo))
}
type rwTimeoutListener struct {
net.Listener
wtimeoutd time.Duration
rdtimeoutd time.Duration
writeTimeout time.Duration
readTimeout time.Duration
}
func (rwln *rwTimeoutListener) Accept() (net.Conn, error) {
@ -63,8 +38,8 @@ func (rwln *rwTimeoutListener) Accept() (net.Conn, error) {
return nil, err
}
return timeoutConn{
Conn: c,
wtimeoutd: rwln.wtimeoutd,
rdtimeoutd: rwln.rdtimeoutd,
Conn: c,
writeTimeout: rwln.writeTimeout,
readTimeout: rwln.readTimeout,
}, nil
}

View File

@ -29,11 +29,11 @@ func TestNewTimeoutListener(t *testing.T) {
}
defer l.Close()
tln := l.(*rwTimeoutListener)
if tln.rdtimeoutd != time.Hour {
t.Errorf("read timeout = %s, want %s", tln.rdtimeoutd, time.Hour)
if tln.readTimeout != time.Hour {
t.Errorf("read timeout = %s, want %s", tln.readTimeout, time.Hour)
}
if tln.wtimeoutd != time.Hour {
t.Errorf("write timeout = %s, want %s", tln.wtimeoutd, time.Hour)
if tln.writeTimeout != time.Hour {
t.Errorf("write timeout = %s, want %s", tln.writeTimeout, time.Hour)
}
}
@ -43,9 +43,9 @@ func TestWriteReadTimeoutListener(t *testing.T) {
t.Fatalf("unexpected listen error: %v", err)
}
wln := rwTimeoutListener{
Listener: ln,
wtimeoutd: 10 * time.Millisecond,
rdtimeoutd: 10 * time.Millisecond,
Listener: ln,
writeTimeout: 10 * time.Millisecond,
readTimeout: 10 * time.Millisecond,
}
stop := make(chan struct{}, 1)
@ -78,7 +78,7 @@ func TestWriteReadTimeoutListener(t *testing.T) {
select {
case <-done:
// It waits 1s more to avoid delay in low-end system.
case <-time.After(wln.wtimeoutd*10 + time.Second):
case <-time.After(wln.writeTimeout*10 + time.Second):
stop <- struct{}{}
t.Fatal("wait timeout")
}
@ -104,7 +104,7 @@ func TestWriteReadTimeoutListener(t *testing.T) {
select {
case <-done:
case <-time.After(wln.rdtimeoutd * 10):
case <-time.After(wln.readTimeout * 10):
stop <- struct{}{}
t.Fatal("wait timeout")
}

View File

@ -47,11 +47,11 @@ func TestNewTimeoutTransport(t *testing.T) {
if !ok {
t.Fatalf("failed to dial out *timeoutConn")
}
if tconn.rdtimeoutd != time.Hour {
t.Errorf("read timeout = %s, want %s", tconn.rdtimeoutd, time.Hour)
if tconn.readTimeout != time.Hour {
t.Errorf("read timeout = %s, want %s", tconn.readTimeout, time.Hour)
}
if tconn.wtimeoutd != time.Hour {
t.Errorf("write timeout = %s, want %s", tconn.wtimeoutd, time.Hour)
if tconn.writeTimeout != time.Hour {
t.Errorf("write timeout = %s, want %s", tconn.writeTimeout, time.Hour)
}
// ensure not reuse timeout connection

View File

@ -466,7 +466,11 @@ func configurePeerListeners(cfg *Config) (peers []*peerListener, err error) {
}
}
peers[i] = &peerListener{close: func(context.Context) error { return nil }}
peers[i].Listener, err = rafthttp.NewListenerWithSocketOpts(u, &cfg.PeerTLSInfo, &cfg.SocketOpts)
peers[i].Listener, err = transport.NewListenerWithOpts(u.Host, u.Scheme,
transport.WithTLSInfo(&cfg.PeerTLSInfo),
transport.WithSocketOpts(&cfg.SocketOpts),
transport.WithTimeout(rafthttp.ConnReadTimeout, rafthttp.ConnWriteTimeout),
)
if err != nil {
return nil, err
}
@ -573,7 +577,10 @@ func configureClientListeners(cfg *Config) (sctxs map[string]*serveCtx, err erro
continue
}
if sctx.l, err = transport.NewListenerWithSocketOpts(addr, u.Scheme, nil, &cfg.SocketOpts); err != nil {
if sctx.l, err = transport.NewListenerWithOpts(addr, u.Scheme,
transport.WithSocketOpts(&cfg.SocketOpts),
transport.WithSkipTLSInfoCheck(true),
); err != nil {
return nil, err
}
// net.Listener will rewrite ipv4 0.0.0.0 to ipv6 [::], breaking
@ -686,7 +693,10 @@ func (e *Etcd) serveMetrics() (err error) {
if murl.Scheme == "http" {
tlsInfo = nil
}
ml, err := transport.NewListenerWithSocketOpts(murl.Host, murl.Scheme, tlsInfo, &e.cfg.SocketOpts)
ml, err := transport.NewListenerWithOpts(murl.Host, murl.Scheme,
transport.WithTLSInfo(tlsInfo),
transport.WithSocketOpts(&e.cfg.SocketOpts),
)
if err != nil {
return err
}

View File

@ -39,11 +39,7 @@ var (
// NewListener returns a listener for raft message transfer between peers.
// It uses timeout listener to identify broken streams promptly.
func NewListener(u url.URL, tlsinfo *transport.TLSInfo) (net.Listener, error) {
return transport.NewTimeoutListener(u.Host, u.Scheme, tlsinfo, ConnReadTimeout, ConnWriteTimeout)
}
func NewListenerWithSocketOpts(u url.URL, tlsinfo *transport.TLSInfo, sopts *transport.SocketOpts) (net.Listener, error) {
return transport.NewTimeoutListerWithSocketOpts(u.Host, u.Scheme, tlsinfo, ConnReadTimeout, ConnWriteTimeout, sopts)
return transport.NewListenerWithOpts(u.Host, u.Scheme, transport.WithTLSInfo(tlsinfo), transport.WithTimeout(ConnReadTimeout, ConnWriteTimeout))
}
// NewRoundTripper returns a roundTripper used to send requests