diff --git a/server/embed/etcd.go b/server/embed/etcd.go index c9621a5c3..dc271f928 100644 --- a/server/embed/etcd.go +++ b/server/embed/etcd.go @@ -19,6 +19,7 @@ import ( "fmt" "io" defaultLog "log" + "math" "net" "net/http" "net/url" @@ -32,6 +33,7 @@ import ( "go.etcd.io/etcd/api/v3/version" "go.etcd.io/etcd/client/pkg/v3/transport" "go.etcd.io/etcd/client/pkg/v3/types" + "go.etcd.io/etcd/client/v3/credentials" "go.etcd.io/etcd/pkg/v3/debugutil" runtimeutil "go.etcd.io/etcd/pkg/v3/runtime" "go.etcd.io/etcd/server/v3/config" @@ -45,6 +47,7 @@ import ( "github.com/soheilhy/cmux" "go.uber.org/zap" "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" ) @@ -736,12 +739,56 @@ func (e *Etcd) serveClients() (err error) { // start client servers in each goroutine for _, sctx := range e.sctxs { go func(s *serveCtx) { - e.errHandler(s.serve(e.Server, &e.cfg.ClientTLSInfo, mux, e.errHandler, gopts...)) + e.errHandler(s.serve(e.Server, &e.cfg.ClientTLSInfo, mux, e.errHandler, e.grpcGatewayDial(), gopts...)) }(sctx) } return nil } +func (e *Etcd) grpcGatewayDial() (grpcDial func(ctx context.Context) (*grpc.ClientConn, error)) { + if !e.cfg.EnableGRPCGateway { + return nil + } + sctx := e.pickGrpcGatewayServeContext() + addr := sctx.addr + if network := sctx.network; network == "unix" { + // explicitly define unix network for gRPC socket support + addr = fmt.Sprintf("%s:%s", network, addr) + } + opts := []grpc.DialOption{grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(math.MaxInt32))} + if sctx.secure { + tlscfg, tlsErr := e.cfg.ClientTLSInfo.ServerConfig() + if tlsErr != nil { + return func(ctx context.Context) (*grpc.ClientConn, error) { + return nil, tlsErr + } + } + dtls := tlscfg.Clone() + // trust local server + dtls.InsecureSkipVerify = true + bundle := credentials.NewBundle(credentials.Config{TLSConfig: dtls}) + opts = append(opts, grpc.WithTransportCredentials(bundle.TransportCredentials())) + } else { + opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) + } + + return func(ctx context.Context) (*grpc.ClientConn, error) { + conn, err := grpc.DialContext(ctx, addr, opts...) + if err != nil { + sctx.lg.Error("grpc gateway failed to dial", zap.String("addr", addr), zap.Error(err)) + return nil, err + } + return conn, err + } +} + +func (e *Etcd) pickGrpcGatewayServeContext() *serveCtx { + for _, sctx := range e.sctxs { + return sctx + } + panic("Expect at least one context able to serve grpc") +} + func (e *Etcd) serveMetrics() (err error) { if e.cfg.Metrics == "extensive" { grpc_prometheus.EnableHandlingTimeHistogram() diff --git a/server/embed/serve.go b/server/embed/serve.go index 577430518..044a88d77 100644 --- a/server/embed/serve.go +++ b/server/embed/serve.go @@ -19,7 +19,6 @@ import ( "fmt" "io" defaultLog "log" - "math" "net" "net/http" "strings" @@ -27,7 +26,6 @@ import ( etcdservergw "go.etcd.io/etcd/api/v3/etcdserverpb/gw" "go.etcd.io/etcd/client/pkg/v3/transport" - "go.etcd.io/etcd/client/v3/credentials" "go.etcd.io/etcd/pkg/v3/debugutil" "go.etcd.io/etcd/pkg/v3/httputil" "go.etcd.io/etcd/server/v3/config" @@ -48,7 +46,6 @@ import ( "golang.org/x/net/http2" "golang.org/x/net/trace" "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" ) type serveCtx struct { @@ -97,6 +94,7 @@ func (sctx *serveCtx) serve( tlsinfo *transport.TLSInfo, handler http.Handler, errHandler func(error), + grpcDialForRestGatewayBackends func(ctx context.Context) (*grpc.ClientConn, error), gopts ...grpc.ServerOption) (err error) { logger := defaultLog.New(io.Discard, "etcdhttp", 0) @@ -118,6 +116,15 @@ func (sctx *serveCtx) serve( // Make sure serversC is closed even if we prematurely exit the function. defer close(sctx.serversC) + var gwmux *gw.ServeMux + if s.Cfg.EnableGRPCGateway { + // GRPC gateway connects to grpc server via connection provided by grpc dial. + gwmux, err = sctx.registerGateway(grpcDialForRestGatewayBackends) + if err != nil { + sctx.lg.Error("registerGateway failed", zap.Error(err)) + return err + } + } if sctx.insecure { gs := v3rpc.Server(s, nil, nil, gopts...) @@ -140,15 +147,6 @@ func (sctx *serveCtx) serve( errHandler(gs.Serve(grpcLis)) }(gs, grpcl) - var gwmux *gw.ServeMux - if s.Cfg.EnableGRPCGateway { - gwmux, err = sctx.registerGateway([]grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}) - if err != nil { - sctx.lg.Error("registerGateway failed", zap.Error(err)) - return err - } - } - httpmux := sctx.createMux(gwmux, handler) srvhttp := &http.Server{ @@ -194,20 +192,6 @@ func (sctx *serveCtx) serve( }(gs) handler = grpcHandlerFunc(gs, handler) - - var gwmux *gw.ServeMux - if s.Cfg.EnableGRPCGateway { - dtls := tlscfg.Clone() - // trust local server - dtls.InsecureSkipVerify = true - bundle := credentials.NewBundle(credentials.Config{TLSConfig: dtls}) - opts := []grpc.DialOption{grpc.WithTransportCredentials(bundle.TransportCredentials())} - gwmux, err = sctx.registerGateway(opts) - if err != nil { - return err - } - } - var tlsl net.Listener tlsl, err = transport.NewTLSListener(m.Match(cmux.Any()), tlsinfo) if err != nil { @@ -268,22 +252,11 @@ func grpcHandlerFunc(grpcServer *grpc.Server, otherHandler http.Handler) http.Ha type registerHandlerFunc func(context.Context, *gw.ServeMux, *grpc.ClientConn) error -func (sctx *serveCtx) registerGateway(opts []grpc.DialOption) (*gw.ServeMux, error) { +func (sctx *serveCtx) registerGateway(dial func(ctx context.Context) (*grpc.ClientConn, error)) (*gw.ServeMux, error) { ctx := sctx.ctx - addr := sctx.addr - if network := sctx.network; network == "unix" { - // explicitly define unix network for gRPC socket support - addr = fmt.Sprintf("%s:%s", network, addr) - } - - opts = append(opts, grpc.WithDefaultCallOptions([]grpc.CallOption{ - grpc.MaxCallRecvMsgSize(math.MaxInt32), - }...)) - - conn, err := grpc.DialContext(ctx, addr, opts...) + conn, err := dial(ctx) if err != nil { - sctx.lg.Error("registerGateway failed to dial", zap.String("addr", addr), zap.Error(err)) return nil, err } gwmux := gw.NewServeMux()