diff --git a/etcdmain/gateway.go b/etcdmain/gateway.go index 18a2f78c8..56c4ae355 100644 --- a/etcdmain/gateway.go +++ b/etcdmain/gateway.go @@ -20,14 +20,19 @@ import ( "os" "time" + "github.com/coreos/etcd/client" + "github.com/coreos/etcd/pkg/transport" "github.com/coreos/etcd/proxy/tcpproxy" "github.com/spf13/cobra" ) var ( - gatewayListenAddr string - gatewayEndpoints []string - getewayRetryDelay time.Duration + gatewayListenAddr string + gatewayEndpoints []string + gatewayDNSCluster string + gatewayInsecureDiscovery bool + getewayRetryDelay time.Duration + gatewayCA string ) var ( @@ -61,6 +66,10 @@ func newGatewayStartCommand() *cobra.Command { } cmd.Flags().StringVar(&gatewayListenAddr, "listen-addr", "127.0.0.1:23790", "listen address") + cmd.Flags().StringVar(&gatewayDNSCluster, "discovery-srv", "", "DNS domain used to bootstrap initial cluster") + cmd.Flags().BoolVar(&gatewayInsecureDiscovery, "insecure-discovery", false, "accept insecure SRV records") + cmd.Flags().StringVar(&gatewayCA, "trusted-ca-file", "", "path to the client server TLS CA file.") + cmd.Flags().StringSliceVar(&gatewayEndpoints, "endpoints", []string{"127.0.0.1:2379"}, "comma separated etcd cluster endpoints") cmd.Flags().DurationVar(&getewayRetryDelay, "retry-delay", time.Minute, "duration of delay before retrying failed endpoints") @@ -68,6 +77,33 @@ func newGatewayStartCommand() *cobra.Command { } func startGateway(cmd *cobra.Command, args []string) { + endpoints := gatewayEndpoints + if gatewayDNSCluster != "" { + eps, err := client.NewSRVDiscover().Discover(gatewayDNSCluster) + if err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } + plog.Infof("discovered the cluster %s from %s", eps, gatewayDNSCluster) + // confirm TLS connections are good + if !gatewayInsecureDiscovery { + tlsInfo := transport.TLSInfo{ + TrustedCAFile: gatewayCA, + ServerName: gatewayDNSCluster, + } + plog.Infof("validating discovered endpoints %v", eps) + endpoints, err = transport.ValidateSecureEndpoints(tlsInfo, eps) + if err != nil { + plog.Warningf("%v", err) + } + plog.Infof("using discovered endpoints %v", endpoints) + } + } + + if len(endpoints) == 0 { + plog.Fatalf("no endpoints found") + } + l, err := net.Listen("tcp", gatewayListenAddr) if err != nil { fmt.Fprintln(os.Stderr, err)