diff --git a/integration/v3_grpc_test.go b/integration/v3_grpc_test.go index ccae35b06..ac6d66467 100644 --- a/integration/v3_grpc_test.go +++ b/integration/v3_grpc_test.go @@ -16,7 +16,6 @@ package integration import ( "bytes" - "crypto/tls" "fmt" "io/ioutil" "math/rand" @@ -30,6 +29,7 @@ import ( "github.com/coreos/etcd/etcdserver/api/v3rpc/rpctypes" pb "github.com/coreos/etcd/etcdserver/etcdserverpb" "github.com/coreos/etcd/pkg/testutil" + "github.com/coreos/etcd/pkg/transport" "golang.org/x/net/context" "google.golang.org/grpc" @@ -1382,146 +1382,112 @@ func TestTLSGRPCAcceptSecureAll(t *testing.T) { // when all certs are atomically replaced by directory renaming. // And expects server to reject client requests, and vice versa. func TestTLSReloadAtomicReplace(t *testing.T) { - defer testutil.AfterTest(t) - - // clone valid,expired certs to separate directories for atomic renaming - vDir, err := ioutil.TempDir(os.TempDir(), "fixtures-valid") + tmpDir, err := ioutil.TempDir(os.TempDir(), "fixtures-tmp") if err != nil { t.Fatal(err) } - defer os.RemoveAll(vDir) - ts, err := copyTLSFiles(testTLSInfo, vDir) + os.RemoveAll(tmpDir) + defer os.RemoveAll(tmpDir) + + certsDir, err := ioutil.TempDir(os.TempDir(), "fixtures-to-load") if err != nil { t.Fatal(err) } - eDir, err := ioutil.TempDir(os.TempDir(), "fixtures-expired") + defer os.RemoveAll(certsDir) + + certsDirExp, err := ioutil.TempDir(os.TempDir(), "fixtures-expired") if err != nil { t.Fatal(err) } - defer os.RemoveAll(eDir) - if _, err = copyTLSFiles(testTLSInfoExpired, eDir); err != nil { - t.Fatal(err) - } + defer os.RemoveAll(certsDirExp) - tDir, err := ioutil.TempDir(os.TempDir(), "fixtures") - if err != nil { - t.Fatal(err) - } - os.RemoveAll(tDir) - defer os.RemoveAll(tDir) - - // start with valid certs - clus := NewClusterV3(t, &ClusterConfig{Size: 1, PeerTLS: &ts, ClientTLS: &ts}) - defer clus.Terminate(t) - - // concurrent client dialing while certs transition from valid to expired - errc := make(chan error, 1) - go func() { - for { - cc, err := ts.ClientConfig() - if err != nil { - if os.IsNotExist(err) { - // from concurrent renaming - continue - } - t.Fatal(err) - } - cli, cerr := clientv3.New(clientv3.Config{ - Endpoints: []string{clus.Members[0].GRPCAddr()}, - DialTimeout: time.Second, - TLS: cc, - }) - if cerr != nil { - errc <- cerr - return - } - cli.Close() + cloneFunc := func() transport.TLSInfo { + tlsInfo, err := copyTLSFiles(testTLSInfo, certsDir) + if err != nil { + t.Fatal(err) } - }() - - // replace certs directory with expired ones - if err = os.Rename(vDir, tDir); err != nil { - t.Fatal(err) - } - if err = os.Rename(eDir, vDir); err != nil { - t.Fatal(err) - } - - // after rename, - // 'vDir' contains expired certs - // 'tDir' contains valid certs - // 'eDir' does not exist - - select { - case err = <-errc: - if err != grpc.ErrClientConnTimeout { - t.Fatalf("expected %v, got %v", grpc.ErrClientConnTimeout, err) + if _, err = copyTLSFiles(testTLSInfoExpired, certsDirExp); err != nil { + t.Fatal(err) } - case <-time.After(5 * time.Second): - t.Fatal("failed to receive dial timeout error") + return tlsInfo } - - // now, replace expired certs back with valid ones - if err = os.Rename(tDir, eDir); err != nil { - t.Fatal(err) + replaceFunc := func() { + if err = os.Rename(certsDir, tmpDir); err != nil { + t.Fatal(err) + } + if err = os.Rename(certsDirExp, certsDir); err != nil { + t.Fatal(err) + } + // after rename, + // 'certsDir' contains expired certs + // 'tmpDir' contains valid certs + // 'certsDirExp' does not exist } - if err = os.Rename(vDir, tDir); err != nil { - t.Fatal(err) + revertFunc := func() { + if err = os.Rename(tmpDir, certsDirExp); err != nil { + t.Fatal(err) + } + if err = os.Rename(certsDir, tmpDir); err != nil { + t.Fatal(err) + } + if err = os.Rename(certsDirExp, certsDir); err != nil { + t.Fatal(err) + } } - if err = os.Rename(eDir, vDir); err != nil { - t.Fatal(err) - } - - // new incoming client request should trigger - // listener to reload valid certs - var tls *tls.Config - tls, err = ts.ClientConfig() - if err != nil { - t.Fatal(err) - } - var cl *clientv3.Client - cl, err = clientv3.New(clientv3.Config{ - Endpoints: []string{clus.Members[0].GRPCAddr()}, - DialTimeout: time.Second, - TLS: tls, - }) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - cl.Close() + testTLSReload(t, cloneFunc, replaceFunc, revertFunc) } // TestTLSReloadCopy ensures server reloads expired/valid certs // when new certs are copied over, one by one. And expects server // to reject client requests, and vice versa. func TestTLSReloadCopy(t *testing.T) { + certsDir, err := ioutil.TempDir(os.TempDir(), "fixtures-to-load") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(certsDir) + + cloneFunc := func() transport.TLSInfo { + tlsInfo, err := copyTLSFiles(testTLSInfo, certsDir) + if err != nil { + t.Fatal(err) + } + return tlsInfo + } + replaceFunc := func() { + if _, err = copyTLSFiles(testTLSInfoExpired, certsDir); err != nil { + t.Fatal(err) + } + } + revertFunc := func() { + if _, err = copyTLSFiles(testTLSInfo, certsDir); err != nil { + t.Fatal(err) + } + } + testTLSReload(t, cloneFunc, replaceFunc, revertFunc) +} + +func testTLSReload(t *testing.T, cloneFunc func() transport.TLSInfo, replaceFunc func(), revertFunc func()) { defer testutil.AfterTest(t) - // clone certs directory, free to overwrite - cDir, err := ioutil.TempDir(os.TempDir(), "fixtures-test") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(cDir) - ts, err := copyTLSFiles(testTLSInfo, cDir) - if err != nil { - t.Fatal(err) - } + // 1. separate copies for TLS assets modification + tlsInfo := cloneFunc() - // start with valid certs - clus := NewClusterV3(t, &ClusterConfig{Size: 1, PeerTLS: &ts, ClientTLS: &ts}) + // 2. start cluster with valid certs + clus := NewClusterV3(t, &ClusterConfig{Size: 1, PeerTLS: &tlsInfo, ClientTLS: &tlsInfo}) defer clus.Terminate(t) - // concurrent client dialing while certs transition from valid to expired + // 3. concurrent client dialing while certs become expired errc := make(chan error, 1) go func() { for { - cc, err := ts.ClientConfig() + cc, err := tlsInfo.ClientConfig() if err != nil { // errors in 'go/src/crypto/tls/tls.go' // tls: private key does not match public key // tls: failed to find any PEM data in key input // tls: failed to find any PEM data in certificate input + // Or 'does not exist', 'not found', etc t.Log(err) continue } @@ -1538,12 +1504,10 @@ func TestTLSReloadCopy(t *testing.T) { } }() - // overwrite valid certs with expired ones - // (e.g. simulate cert expiration in practice) - if _, err = copyTLSFiles(testTLSInfoExpired, cDir); err != nil { - t.Fatal(err) - } + // 4. replace certs with expired ones + replaceFunc() + // 5. expect dial time-out when loading expired certs select { case gerr := <-errc: if gerr != grpc.ErrClientConnTimeout { @@ -1553,26 +1517,21 @@ func TestTLSReloadCopy(t *testing.T) { t.Fatal("failed to receive dial timeout error") } - // now, replace expired certs back with valid ones - if _, err = copyTLSFiles(testTLSInfo, cDir); err != nil { - t.Fatal(err) - } + // 6. replace expired certs back with valid ones + revertFunc() - // new incoming client request should trigger - // listener to reload valid certs - var tls *tls.Config - tls, err = ts.ClientConfig() - if err != nil { - t.Fatal(err) + // 7. new requests should trigger listener to reload valid certs + tls, terr := tlsInfo.ClientConfig() + if terr != nil { + t.Fatal(terr) } - var cl *clientv3.Client - cl, err = clientv3.New(clientv3.Config{ + cl, cerr := clientv3.New(clientv3.Config{ Endpoints: []string{clus.Members[0].GRPCAddr()}, DialTimeout: time.Second, TLS: tls, }) - if err != nil { - t.Fatalf("expected no error, got %v", err) + if cerr != nil { + t.Fatalf("expected no error, got %v", cerr) } cl.Close() }