diff --git a/integration/cluster.go b/integration/cluster.go index ed245eca2..7948566ac 100644 --- a/integration/cluster.go +++ b/integration/cluster.go @@ -31,6 +31,7 @@ import ( "testing" "time" + "github.com/cockroachdb/cmux" "golang.org/x/net/context" "google.golang.org/grpc" @@ -475,13 +476,14 @@ type member struct { // ClientTLSInfo enables client TLS when set ClientTLSInfo *transport.TLSInfo - raftHandler *testutil.PauseableHandler - s *etcdserver.EtcdServer - hss []*httptest.Server + raftHandler *testutil.PauseableHandler + s *etcdserver.EtcdServer + serverClosers []func() - grpcServer *grpc.Server - grpcAddr string - grpcBridge *bridge + grpcServer *grpc.Server + grpcServerPeer *grpc.Server + grpcAddr string + grpcBridge *bridge // serverClient is a clientv3 that directly calls the etcdserver. serverClient *clientv3.Client @@ -649,23 +651,80 @@ func (m *member) Launch() error { m.s.SyncTicker = time.NewTicker(500 * time.Millisecond) m.s.Start() - m.raftHandler = &testutil.PauseableHandler{Next: etcdhttp.NewPeerHandler(m.s)} - - for _, ln := range m.PeerListeners { - hs := &httptest.Server{ - Listener: ln, - Config: &http.Server{Handler: m.raftHandler}, + var peerTLScfg *tls.Config + if m.PeerTLSInfo != nil && !m.PeerTLSInfo.Empty() { + if peerTLScfg, err = m.PeerTLSInfo.ServerConfig(); err != nil { + return err } - if m.PeerTLSInfo == nil { - hs.Start() - } else { - hs.TLS, err = m.PeerTLSInfo.ServerConfig() + } + + if m.grpcListener != nil { + var ( + tlscfg *tls.Config + ) + if m.ClientTLSInfo != nil && !m.ClientTLSInfo.Empty() { + tlscfg, err = m.ClientTLSInfo.ServerConfig() if err != nil { return err } - hs.StartTLS() } - m.hss = append(m.hss, hs) + m.grpcServer = v3rpc.Server(m.s, tlscfg) + m.grpcServerPeer = v3rpc.Server(m.s, peerTLScfg) + m.serverClient = v3client.New(m.s) + lockpb.RegisterLockServer(m.grpcServer, v3lock.NewLockServer(m.serverClient)) + epb.RegisterElectionServer(m.grpcServer, v3election.NewElectionServer(m.serverClient)) + go m.grpcServer.Serve(m.grpcListener) + } + + m.raftHandler = &testutil.PauseableHandler{Next: etcdhttp.NewPeerHandler(m.s)} + + h := (http.Handler)(m.raftHandler) + if m.grpcListener != nil { + h = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.ProtoMajor == 2 && strings.Contains(r.Header.Get("Content-Type"), "application/grpc") { + m.grpcServerPeer.ServeHTTP(w, r) + } else { + m.raftHandler.ServeHTTP(w, r) + } + }) + } + + for _, ln := range m.PeerListeners { + cm := cmux.New(ln) + // don't hang on matcher after closing listener + cm.SetReadTimeout(time.Second) + + if m.grpcServer != nil { + grpcl := cm.Match(cmux.HTTP2()) + go m.grpcServerPeer.Serve(grpcl) + } + + // serve http1/http2 rafthttp/grpc + ll := cm.Match(cmux.Any()) + if peerTLScfg != nil { + if ll, err = transport.NewTLSListener(ll, m.PeerTLSInfo); err != nil { + return err + } + } + hs := &httptest.Server{ + Listener: ll, + Config: &http.Server{Handler: h, TLSConfig: peerTLScfg}, + TLS: peerTLScfg, + } + hs.Start() + + donec := make(chan struct{}) + go func() { + defer close(donec) + cm.Serve() + }() + closer := func() { + ll.Close() + hs.CloseClientConnections() + hs.Close() + <-donec + } + m.serverClosers = append(m.serverClosers, closer) } for _, ln := range m.ClientListeners { hs := &httptest.Server{ @@ -681,23 +740,12 @@ func (m *member) Launch() error { } hs.StartTLS() } - m.hss = append(m.hss, hs) - } - if m.grpcListener != nil { - var ( - tlscfg *tls.Config - ) - if m.ClientTLSInfo != nil && !m.ClientTLSInfo.Empty() { - tlscfg, err = m.ClientTLSInfo.ServerConfig() - if err != nil { - return err - } + closer := func() { + ln.Close() + hs.CloseClientConnections() + hs.Close() } - m.grpcServer = v3rpc.Server(m.s, tlscfg) - m.serverClient = v3client.New(m.s) - lockpb.RegisterLockServer(m.grpcServer, v3lock.NewLockServer(m.serverClient)) - epb.RegisterElectionServer(m.grpcServer, v3election.NewElectionServer(m.serverClient)) - go m.grpcServer.Serve(m.grpcListener) + m.serverClosers = append(m.serverClosers, closer) } plog.Printf("launched %s (%s)", m.Name, m.grpcAddr) @@ -745,13 +793,16 @@ func (m *member) Close() { m.serverClient = nil } if m.grpcServer != nil { + m.grpcServer.Stop() m.grpcServer.GracefulStop() m.grpcServer = nil + m.grpcServerPeer.Stop() + m.grpcServerPeer.GracefulStop() + m.grpcServerPeer = nil } m.s.HardStop() - for _, hs := range m.hss { - hs.CloseClientConnections() - hs.Close() + for _, f := range m.serverClosers { + f() } } @@ -759,7 +810,7 @@ func (m *member) Close() { func (m *member) Stop(t *testing.T) { plog.Printf("stopping %s (%s)", m.Name, m.grpcAddr) m.Close() - m.hss = nil + m.serverClosers = nil plog.Printf("stopped %s (%s)", m.Name, m.grpcAddr) }