diff --git a/pkg/transport/listener.go b/pkg/transport/listener.go index d94757b2c..78407ad8b 100644 --- a/pkg/transport/listener.go +++ b/pkg/transport/listener.go @@ -34,27 +34,30 @@ import ( "github.com/coreos/etcd/pkg/tlsutil" ) -func NewListener(addr string, scheme string, tlscfg *tls.Config) (l net.Listener, err error) { - if scheme == "unix" || scheme == "unixs" { - // unix sockets via unix://laddr - l, err = NewUnixListener(addr) - } else { - l, err = net.Listen("tcp", addr) - } - - if err != nil { +func NewListener(addr, scheme string, tlscfg *tls.Config) (l net.Listener, err error) { + if l, err = newListener(addr, scheme); err != nil { return nil, err } + return wrapTLS(addr, scheme, tlscfg, l) +} - if scheme == "https" || scheme == "unixs" { - if tlscfg == nil { - return nil, fmt.Errorf("cannot listen on TLS for %s: KeyFile and CertFile are not presented", scheme+"://"+addr) - } - - l = tls.NewListener(l, tlscfg) +func newListener(addr string, scheme string) (net.Listener, error) { + if scheme == "unix" || scheme == "unixs" { + // unix sockets via unix://laddr + return NewUnixListener(addr) } + return net.Listen("tcp", addr) +} - return l, nil +func wrapTLS(addr, scheme string, tlscfg *tls.Config, l net.Listener) (net.Listener, error) { + if scheme != "https" && scheme != "unixs" { + return l, nil + } + if tlscfg == nil { + l.Close() + return nil, fmt.Errorf("cannot listen on TLS for %s: KeyFile and CertFile are not presented", scheme+"://"+addr) + } + return tls.NewListener(l, tlscfg), nil } type TLSInfo struct { diff --git a/pkg/transport/timeout_listener.go b/pkg/transport/timeout_listener.go index f176c43b9..0f4df5fbe 100644 --- a/pkg/transport/timeout_listener.go +++ b/pkg/transport/timeout_listener.go @@ -24,15 +24,19 @@ import ( // If read/write on the accepted connection blocks longer than its time limit, // it will return timeout error. func NewTimeoutListener(addr string, scheme string, tlscfg *tls.Config, rdtimeoutd, wtimeoutd time.Duration) (net.Listener, error) { - ln, err := NewListener(addr, scheme, tlscfg) + ln, err := newListener(addr, scheme) if err != nil { return nil, err } - return &rwTimeoutListener{ + ln = &rwTimeoutListener{ Listener: ln, rdtimeoutd: rdtimeoutd, wtimeoutd: wtimeoutd, - }, nil + } + if ln, err = wrapTLS(addr, scheme, tlscfg, ln); err != nil { + return nil, err + } + return ln, nil } type rwTimeoutListener struct {