etcdmain: check TLS on gateway SRV records

This commit is contained in:
Anthony Romano 2016-08-02 16:52:05 -07:00
parent e218834b58
commit ab4ac828f3

View File

@ -21,15 +21,18 @@ import (
"time" "time"
"github.com/coreos/etcd/client" "github.com/coreos/etcd/client"
"github.com/coreos/etcd/pkg/transport"
"github.com/coreos/etcd/proxy/tcpproxy" "github.com/coreos/etcd/proxy/tcpproxy"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
var ( var (
gatewayListenAddr string gatewayListenAddr string
gatewayEndpoints []string gatewayEndpoints []string
gatewayDNSCluster string gatewayDNSCluster string
getewayRetryDelay time.Duration gatewayInsecureDiscovery bool
getewayRetryDelay time.Duration
gatewayCA string
) )
var ( var (
@ -64,6 +67,8 @@ func newGatewayStartCommand() *cobra.Command {
cmd.Flags().StringVar(&gatewayListenAddr, "listen-addr", "127.0.0.1:23790", "listen address") cmd.Flags().StringVar(&gatewayListenAddr, "listen-addr", "127.0.0.1:23790", "listen address")
cmd.Flags().StringVar(&gatewayDNSCluster, "discovery-srv", "", "DNS domain used to bootstrap initial cluster") cmd.Flags().StringVar(&gatewayDNSCluster, "discovery-srv", "", "DNS domain used to bootstrap initial cluster")
cmd.Flags().BoolVar(&gatewayInsecureDiscovery, "insecure-discovery", false, "accept insecure SRV records")
cmd.Flags().StringVar(&gatewayCA, "trusted-ca-file", "", "path to the client server TLS CA file.")
cmd.Flags().StringSliceVar(&gatewayEndpoints, "endpoints", []string{"127.0.0.1:2379"}, "comma separated etcd cluster endpoints") cmd.Flags().StringSliceVar(&gatewayEndpoints, "endpoints", []string{"127.0.0.1:2379"}, "comma separated etcd cluster endpoints")
@ -81,6 +86,23 @@ func startGateway(cmd *cobra.Command, args []string) {
os.Exit(1) os.Exit(1)
} }
plog.Infof("discovered the cluster %s from %s", eps, gatewayDNSCluster) plog.Infof("discovered the cluster %s from %s", eps, gatewayDNSCluster)
// confirm TLS connections are good
if !gatewayInsecureDiscovery {
tlsInfo := transport.TLSInfo{
TrustedCAFile: gatewayCA,
ServerName: gatewayDNSCluster,
}
plog.Infof("validating discovered endpoints %v", eps)
endpoints, err = transport.ValidateSecureEndpoints(tlsInfo, eps)
if err != nil {
plog.Warningf("%v", err)
}
plog.Infof("using discovered endpoints %v", endpoints)
}
}
if len(endpoints) == 0 {
plog.Fatalf("no endpoints found")
} }
l, err := net.Listen("tcp", gatewayListenAddr) l, err := net.Listen("tcp", gatewayListenAddr)