From 4960324876b693b181a80291aaca3bf1492b12c7 Mon Sep 17 00:00:00 2001 From: Xiang Li Date: Thu, 29 Jan 2015 09:38:43 -0800 Subject: [PATCH] pkg/transport: fix tlskeepalive --- pkg/transport/keepalive_listener.go | 53 +++++++++++++++++++++--- pkg/transport/keepalive_listener_test.go | 27 ++++++++++++ 2 files changed, 75 insertions(+), 5 deletions(-) diff --git a/pkg/transport/keepalive_listener.go b/pkg/transport/keepalive_listener.go index 516addf83..cc7ed9e71 100644 --- a/pkg/transport/keepalive_listener.go +++ b/pkg/transport/keepalive_listener.go @@ -15,6 +15,7 @@ package transport import ( + "crypto/tls" "net" "time" ) @@ -22,18 +23,26 @@ import ( // NewKeepAliveListener returns a listener that listens on the given address. // http://tldp.org/HOWTO/TCP-Keepalive-HOWTO/overview.html func NewKeepAliveListener(addr string, scheme string, info TLSInfo) (net.Listener, error) { - ln, err := NewListener(addr, scheme, info) + l, err := net.Listen("tcp", addr) if err != nil { return nil, err } + + if !info.Empty() && scheme == "https" { + cfg, err := info.ServerConfig() + if err != nil { + return nil, err + } + + return newTLSKeepaliveListener(l, cfg), nil + } + return &keepaliveListener{ - Listener: ln, + Listener: l, }, nil } -type keepaliveListener struct { - net.Listener -} +type keepaliveListener struct{ net.Listener } func (kln *keepaliveListener) Accept() (net.Conn, error) { c, err := kln.Listener.Accept() @@ -48,3 +57,37 @@ func (kln *keepaliveListener) Accept() (net.Conn, error) { tcpc.SetKeepAlivePeriod(30 * time.Second) return tcpc, nil } + +// A tlsKeepaliveListener implements a network listener (net.Listener) for TLS connections. +type tlsKeepaliveListener struct { + net.Listener + config *tls.Config +} + +// Accept waits for and returns the next incoming TLS connection. +// The returned connection c is a *tls.Conn. +func (l *tlsKeepaliveListener) Accept() (c net.Conn, err error) { + c, err = l.Listener.Accept() + if err != nil { + return + } + tcpc := c.(*net.TCPConn) + // detection time: tcp_keepalive_time + tcp_keepalive_probes + tcp_keepalive_intvl + // default on linux: 30 + 8 * 30 + // default on osx: 30 + 8 * 75 + tcpc.SetKeepAlive(true) + tcpc.SetKeepAlivePeriod(30 * time.Second) + c = tls.Server(c, l.config) + return +} + +// NewListener creates a Listener which accepts connections from an inner +// Listener and wraps each connection with Server. +// The configuration config must be non-nil and must have +// at least one certificate. +func newTLSKeepaliveListener(inner net.Listener, config *tls.Config) net.Listener { + l := &tlsKeepaliveListener{} + l.Listener = inner + l.config = config + return l +} diff --git a/pkg/transport/keepalive_listener_test.go b/pkg/transport/keepalive_listener_test.go index 599ad10f7..f9458436a 100644 --- a/pkg/transport/keepalive_listener_test.go +++ b/pkg/transport/keepalive_listener_test.go @@ -15,7 +15,9 @@ package transport import ( + "crypto/tls" "net/http" + "os" "testing" ) @@ -34,4 +36,29 @@ func TestNewKeepAliveListener(t *testing.T) { t.Fatalf("unexpected Accept error: %v", err) } conn.Close() + ln.Close() + + // tls + tmp, err := createTempFile([]byte("XXX")) + if err != nil { + t.Fatalf("unable to create tmpfile: %v", err) + } + defer os.Remove(tmp) + tlsInfo := TLSInfo{CertFile: tmp, KeyFile: tmp} + tlsInfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil) + tlsln, err := NewKeepAliveListener("127.0.0.1:0", "https", tlsInfo) + if err != nil { + t.Fatalf("unexpected NewKeepAliveListener error: %v", err) + } + + go http.Get("https://" + tlsln.Addr().String()) + conn, err = tlsln.Accept() + if err != nil { + t.Fatalf("unexpected Accept error: %v", err) + } + if _, ok := conn.(*tls.Conn); !ok { + t.Errorf("failed to accept *tls.Conn") + } + conn.Close() + tlsln.Close() }