From 70a9929b5d33ea1369ed8687ca8c0369ad6fdd01 Mon Sep 17 00:00:00 2001 From: Anthony Romano Date: Fri, 7 Apr 2017 13:29:54 -0700 Subject: [PATCH] transport: use actual certs for listener tests --- pkg/transport/keepalive_listener_test.go | 7 +- pkg/transport/listener_test.go | 94 +++++++++++------------- pkg/transport/listener_tls.go | 1 - 3 files changed, 47 insertions(+), 55 deletions(-) diff --git a/pkg/transport/keepalive_listener_test.go b/pkg/transport/keepalive_listener_test.go index f8c062cf7..425f53368 100644 --- a/pkg/transport/keepalive_listener_test.go +++ b/pkg/transport/keepalive_listener_test.go @@ -18,7 +18,6 @@ import ( "crypto/tls" "net" "net/http" - "os" "testing" ) @@ -50,12 +49,12 @@ func TestNewKeepAliveListener(t *testing.T) { } // tls - tmp, err := createTempFile([]byte("XXX")) + tlsinfo, del, err := createSelfCert() if err != nil { t.Fatalf("unable to create tmpfile: %v", err) } - defer os.Remove(tmp) - tlsInfo := TLSInfo{CertFile: tmp, KeyFile: tmp} + defer del() + tlsInfo := TLSInfo{CertFile: tlsinfo.CertFile, KeyFile: tlsinfo.KeyFile} tlsInfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil) tlscfg, err := tlsInfo.ServerConfig() if err != nil { diff --git a/pkg/transport/listener_test.go b/pkg/transport/listener_test.go index d5ff31789..ff0fe94cf 100644 --- a/pkg/transport/listener_test.go +++ b/pkg/transport/listener_test.go @@ -24,18 +24,16 @@ import ( "time" ) -func createTempFile(b []byte) (string, error) { - f, err := ioutil.TempFile("", "etcd-test-tls-") +func createSelfCert() (*TLSInfo, func(), error) { + d, terr := ioutil.TempDir("", "etcd-test-tls-") + if terr != nil { + return nil, nil, terr + } + info, err := SelfCert(d, []string{"127.0.0.1"}) if err != nil { - return "", err + return nil, nil, err } - defer f.Close() - - if _, err = f.Write(b); err != nil { - return "", err - } - - return f.Name(), nil + return &info, func() { os.RemoveAll(d) }, nil } func fakeCertificateParserFunc(cert tls.Certificate, err error) func(certPEMBlock, keyPEMBlock []byte) (tls.Certificate, error) { @@ -47,28 +45,25 @@ func fakeCertificateParserFunc(cert tls.Certificate, err error) func(certPEMBloc // TestNewListenerTLSInfo tests that NewListener with valid TLSInfo returns // a TLS listener that accepts TLS connections. func TestNewListenerTLSInfo(t *testing.T) { - tmp, err := createTempFile([]byte("XXX")) + tlsInfo, del, err := createSelfCert() if err != nil { - t.Fatalf("unable to create tmpfile: %v", err) + t.Fatalf("unable to create cert: %v", err) } - defer os.Remove(tmp) - tlsInfo := TLSInfo{CertFile: tmp, KeyFile: tmp} - tlsInfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil) - testNewListenerTLSInfoAccept(t, tlsInfo) + defer del() + testNewListenerTLSInfoAccept(t, *tlsInfo) } func testNewListenerTLSInfoAccept(t *testing.T, tlsInfo TLSInfo) { - tlscfg, err := tlsInfo.ServerConfig() - if err != nil { - t.Fatalf("unexpected serverConfig error: %v", err) - } - ln, err := NewListener("127.0.0.1:0", "https", tlscfg) + ln, err := NewListener("127.0.0.1:0", "https", &tlsInfo) if err != nil { t.Fatalf("unexpected NewListener error: %v", err) } defer ln.Close() - go http.Get("https://" + ln.Addr().String()) + tr := &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + cli := &http.Client{Transport: tr} + go cli.Get("https://" + ln.Addr().String()) + conn, err := ln.Accept() if err != nil { t.Fatalf("unexpected Accept error: %v", err) @@ -87,25 +82,25 @@ func TestNewListenerTLSEmptyInfo(t *testing.T) { } func TestNewTransportTLSInfo(t *testing.T) { - tmp, err := createTempFile([]byte("XXX")) + tlsinfo, del, err := createSelfCert() if err != nil { - t.Fatalf("Unable to prepare tmpfile: %v", err) + t.Fatalf("unable to create cert: %v", err) } - defer os.Remove(tmp) + defer del() tests := []TLSInfo{ {}, { - CertFile: tmp, - KeyFile: tmp, + CertFile: tlsinfo.CertFile, + KeyFile: tlsinfo.KeyFile, }, { - CertFile: tmp, - KeyFile: tmp, - CAFile: tmp, + CertFile: tlsinfo.CertFile, + KeyFile: tlsinfo.KeyFile, + CAFile: tlsinfo.CAFile, }, { - CAFile: tmp, + CAFile: tlsinfo.CAFile, }, } @@ -159,17 +154,17 @@ func TestTLSInfoEmpty(t *testing.T) { } func TestTLSInfoMissingFields(t *testing.T) { - tmp, err := createTempFile([]byte("XXX")) + tlsinfo, del, err := createSelfCert() if err != nil { - t.Fatalf("Unable to prepare tmpfile: %v", err) + t.Fatalf("unable to create cert: %v", err) } - defer os.Remove(tmp) + defer del() tests := []TLSInfo{ - {CertFile: tmp}, - {KeyFile: tmp}, - {CertFile: tmp, CAFile: tmp}, - {KeyFile: tmp, CAFile: tmp}, + {CertFile: tlsinfo.CertFile}, + {KeyFile: tlsinfo.KeyFile}, + {CertFile: tlsinfo.CertFile, CAFile: tlsinfo.CAFile}, + {KeyFile: tlsinfo.KeyFile, CAFile: tlsinfo.CAFile}, } for i, info := range tests { @@ -184,30 +179,29 @@ func TestTLSInfoMissingFields(t *testing.T) { } func TestTLSInfoParseFuncError(t *testing.T) { - tmp, err := createTempFile([]byte("XXX")) + tlsinfo, del, err := createSelfCert() if err != nil { - t.Fatalf("Unable to prepare tmpfile: %v", err) + t.Fatalf("unable to create cert: %v", err) } - defer os.Remove(tmp) + defer del() - info := TLSInfo{CertFile: tmp, KeyFile: tmp, CAFile: tmp} - info.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, errors.New("fake")) + tlsinfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, errors.New("fake")) - if _, err = info.ServerConfig(); err == nil { + if _, err = tlsinfo.ServerConfig(); err == nil { t.Errorf("expected non-nil error from ServerConfig()") } - if _, err = info.ClientConfig(); err == nil { + if _, err = tlsinfo.ClientConfig(); err == nil { t.Errorf("expected non-nil error from ClientConfig()") } } func TestTLSInfoConfigFuncs(t *testing.T) { - tmp, err := createTempFile([]byte("XXX")) + tlsinfo, del, err := createSelfCert() if err != nil { - t.Fatalf("Unable to prepare tmpfile: %v", err) + t.Fatalf("unable to create cert: %v", err) } - defer os.Remove(tmp) + defer del() tests := []struct { info TLSInfo @@ -215,13 +209,13 @@ func TestTLSInfoConfigFuncs(t *testing.T) { wantCAs bool }{ { - info: TLSInfo{CertFile: tmp, KeyFile: tmp}, + info: TLSInfo{CertFile: tlsinfo.CertFile, KeyFile: tlsinfo.KeyFile}, clientAuth: tls.NoClientCert, wantCAs: false, }, { - info: TLSInfo{CertFile: tmp, KeyFile: tmp, CAFile: tmp}, + info: TLSInfo{CertFile: tlsinfo.CertFile, KeyFile: tlsinfo.KeyFile, CAFile: tlsinfo.CertFile}, clientAuth: tls.RequireAndVerifyClientCert, wantCAs: true, }, diff --git a/pkg/transport/listener_tls.go b/pkg/transport/listener_tls.go index 39228bb4a..53e6a1048 100644 --- a/pkg/transport/listener_tls.go +++ b/pkg/transport/listener_tls.go @@ -121,7 +121,6 @@ func (l *tlsListener) acceptLoop() { } } } - select { case l.connc <- tlsConn: conn = nil