diff --git a/server/embed/etcd.go b/server/embed/etcd.go index 86605021c..bbb39671c 100644 --- a/server/embed/etcd.go +++ b/server/embed/etcd.go @@ -20,6 +20,7 @@ import ( "fmt" "io/ioutil" defaultLog "log" + "math" "net" "net/http" "net/url" @@ -32,6 +33,7 @@ import ( "go.etcd.io/etcd/api/v3/version" "go.etcd.io/etcd/client/pkg/v3/transport" "go.etcd.io/etcd/client/pkg/v3/types" + "go.etcd.io/etcd/client/v3/credentials" "go.etcd.io/etcd/pkg/v3/debugutil" runtimeutil "go.etcd.io/etcd/pkg/v3/runtime" "go.etcd.io/etcd/server/v3/config" @@ -48,6 +50,7 @@ import ( "github.com/soheilhy/cmux" "go.uber.org/zap" "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" ) @@ -743,12 +746,57 @@ func (e *Etcd) serveClients() (err error) { // start client servers in each goroutine for _, sctx := range e.sctxs { go func(s *serveCtx) { - e.errHandler(s.serve(e.Server, &e.cfg.ClientTLSInfo, h, e.errHandler, gopts...)) + e.errHandler(s.serve(e.Server, &e.cfg.ClientTLSInfo, h, e.errHandler, e.grpcGatewayDial(), gopts...)) }(sctx) } return nil } +func (e *Etcd) grpcGatewayDial() (grpcDial func(ctx context.Context) (*grpc.ClientConn, error)) { + if !e.cfg.EnableGRPCGateway { + return nil + } + sctx := e.pickGrpcGatewayServeContext() + addr := sctx.addr + if network := sctx.network; network == "unix" { + // explicitly define unix network for gRPC socket support + addr = fmt.Sprintf("%s://%s", network, addr) + } + + opts := []grpc.DialOption{grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(math.MaxInt32))} + if sctx.secure { + tlscfg, tlsErr := e.cfg.ClientTLSInfo.ServerConfig() + if tlsErr != nil { + return func(ctx context.Context) (*grpc.ClientConn, error) { + return nil, tlsErr + } + } + dtls := tlscfg.Clone() + // trust local server + dtls.InsecureSkipVerify = true + bundle := credentials.NewBundle(credentials.Config{TLSConfig: dtls}) + opts = append(opts, grpc.WithTransportCredentials(bundle.TransportCredentials())) + } else { + opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) + } + + return func(ctx context.Context) (*grpc.ClientConn, error) { + conn, err := grpc.DialContext(ctx, addr, opts...) + if err != nil { + sctx.lg.Error("grpc gateway failed to dial", zap.String("addr", addr), zap.Error(err)) + return nil, err + } + return conn, err + } +} + +func (e *Etcd) pickGrpcGatewayServeContext() *serveCtx { + for _, sctx := range e.sctxs { + return sctx + } + panic("Expect at least one context able to serve grpc") +} + func (e *Etcd) serveMetrics() (err error) { if e.cfg.Metrics == "extensive" { grpc_prometheus.EnableHandlingTimeHistogram() diff --git a/server/embed/serve.go b/server/embed/serve.go index 6cfaeb01e..a53cb38e9 100644 --- a/server/embed/serve.go +++ b/server/embed/serve.go @@ -19,14 +19,12 @@ import ( "fmt" "io/ioutil" defaultLog "log" - "math" "net" "net/http" "strings" etcdservergw "go.etcd.io/etcd/api/v3/etcdserverpb/gw" "go.etcd.io/etcd/client/pkg/v3/transport" - "go.etcd.io/etcd/client/v3/credentials" "go.etcd.io/etcd/pkg/v3/debugutil" "go.etcd.io/etcd/pkg/v3/httputil" "go.etcd.io/etcd/server/v3/config" @@ -95,6 +93,7 @@ func (sctx *serveCtx) serve( tlsinfo *transport.TLSInfo, handler http.Handler, errHandler func(error), + grpcDialForRestGatewayBackends func(ctx context.Context) (*grpc.ClientConn, error), gopts ...grpc.ServerOption) (err error) { logger := defaultLog.New(ioutil.Discard, "etcdhttp", 0) <-s.ReadyNotify() @@ -106,6 +105,18 @@ func (sctx *serveCtx) serve( servElection := v3election.NewElectionServer(v3c) servLock := v3lock.NewLockServer(v3c) + // Make sure serversC is closed even if we prematurely exit the function. + defer close(sctx.serversC) + var gwmux *gw.ServeMux + if s.Cfg.EnableGRPCGateway { + // GRPC gateway connects to grpc server via connection provided by grpc dial. + gwmux, err = sctx.registerGateway(grpcDialForRestGatewayBackends) + if err != nil { + sctx.lg.Error("registerGateway failed", zap.Error(err)) + return err + } + } + if sctx.insecure { gs := v3rpc.Server(s, nil, nil, gopts...) v3electionpb.RegisterElectionServer(gs, servElection) @@ -127,14 +138,6 @@ func (sctx *serveCtx) serve( errHandler(gs.Serve(grpcLis)) }(gs, grpcl) - var gwmux *gw.ServeMux - if s.Cfg.EnableGRPCGateway { - gwmux, err = sctx.registerGateway([]grpc.DialOption{grpc.WithInsecure()}) - if err != nil { - return err - } - } - httpmux := sctx.createMux(gwmux, handler) srvhttp := &http.Server{ @@ -180,20 +183,6 @@ func (sctx *serveCtx) serve( }(gs) handler = grpcHandlerFunc(gs, handler) - - var gwmux *gw.ServeMux - if s.Cfg.EnableGRPCGateway { - dtls := tlscfg.Clone() - // trust local server - dtls.InsecureSkipVerify = true - bundle := credentials.NewBundle(credentials.Config{TLSConfig: dtls}) - opts := []grpc.DialOption{grpc.WithTransportCredentials(bundle.TransportCredentials())} - gwmux, err = sctx.registerGateway(opts) - if err != nil { - return err - } - } - var tlsl net.Listener tlsl, err = transport.NewTLSListener(m.Match(cmux.Any()), tlsinfo) if err != nil { @@ -255,20 +244,10 @@ func grpcHandlerFunc(grpcServer *grpc.Server, otherHandler http.Handler) http.Ha type registerHandlerFunc func(context.Context, *gw.ServeMux, *grpc.ClientConn) error -func (sctx *serveCtx) registerGateway(opts []grpc.DialOption) (*gw.ServeMux, error) { +func (sctx *serveCtx) registerGateway(dial func(ctx context.Context) (*grpc.ClientConn, error)) (*gw.ServeMux, error) { ctx := sctx.ctx - addr := sctx.addr - if network := sctx.network; network == "unix" { - // explicitly define unix network for gRPC socket support - addr = fmt.Sprintf("%s://%s", network, addr) - } - - opts = append(opts, grpc.WithDefaultCallOptions([]grpc.CallOption{ - grpc.MaxCallRecvMsgSize(math.MaxInt32), - }...)) - - conn, err := grpc.DialContext(ctx, addr, opts...) + conn, err := dial(ctx) if err != nil { return nil, err }