From f45542394b386ac651634008d7f9dc503b57c920 Mon Sep 17 00:00:00 2001 From: Gyu-Ho Lee Date: Mon, 3 Oct 2016 01:03:28 -0700 Subject: [PATCH 1/2] clientv3: handle 'https' scheme in endpoint --- clientv3/client.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/clientv3/client.go b/clientv3/client.go index d8b04a43f..df4ab3478 100644 --- a/clientv3/client.go +++ b/clientv3/client.go @@ -151,14 +151,14 @@ func (cred authTokenCredential) GetRequestMetadata(ctx context.Context, s ...str }, nil } -func parseEndpoint(endpoint string) (proto string, host string, scheme bool) { +func parseEndpoint(endpoint string) (proto string, host string, scheme string) { proto = "tcp" host = endpoint url, uerr := url.Parse(endpoint) if uerr != nil || !strings.Contains(endpoint, "://") { return } - scheme = true + scheme = url.Scheme // strip scheme:// prefix since grpc dials by host host = url.Host @@ -172,9 +172,9 @@ func parseEndpoint(endpoint string) (proto string, host string, scheme bool) { return } -func (c *Client) processCreds(protocol string) (creds *credentials.TransportCredentials) { +func (c *Client) processCreds(scheme string) (creds *credentials.TransportCredentials) { creds = c.creds - switch protocol { + switch scheme { case "unix": case "http": creds = nil @@ -213,8 +213,8 @@ func (c *Client) dialSetupOpts(endpoint string, dopts ...grpc.DialOption) (opts opts = append(opts, grpc.WithDialer(f)) creds := c.creds - if proto, _, scheme := parseEndpoint(endpoint); scheme { - creds = c.processCreds(proto) + if _, _, scheme := parseEndpoint(endpoint); len(scheme) != 0 { + creds = c.processCreds(scheme) } if creds != nil { opts = append(opts, grpc.WithTransportCredentials(*creds)) From a96a28d6030807ffe0f0cd6c708581cdaebdbb87 Mon Sep 17 00:00:00 2001 From: Gyu-Ho Lee Date: Mon, 3 Oct 2016 01:25:32 -0700 Subject: [PATCH 2/2] clientv3/integration: add TestDialWithHTTPS --- clientv3/integration/dial_test.go | 72 +++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/clientv3/integration/dial_test.go b/clientv3/integration/dial_test.go index 9d9e6b47f..5f758ae0c 100644 --- a/clientv3/integration/dial_test.go +++ b/clientv3/integration/dial_test.go @@ -15,11 +15,17 @@ package integration import ( + "fmt" + "io/ioutil" "math/rand" + "net/url" + "os" + "sync" "testing" "time" "github.com/coreos/etcd/clientv3" + "github.com/coreos/etcd/embed" "github.com/coreos/etcd/integration" "github.com/coreos/etcd/pkg/testutil" "golang.org/x/net/context" @@ -58,3 +64,69 @@ func TestDialSetEndpoints(t *testing.T) { } cancel() } + +var ( + testMu sync.Mutex + testPort = 31000 +) + +// TestDialWithHTTPS ensures that client can handle 'https' scheme in endpoints. +func TestDialWithHTTPS(t *testing.T) { + defer testutil.AfterTest(t) + + testMu.Lock() + port := testPort + testPort += 10 // to avoid port conflicts + testMu.Unlock() + + dir, err := ioutil.TempDir(os.TempDir(), "dial-test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + // set up single-node cluster with client auto TLS + cfg := embed.NewConfig() + cfg.Dir = dir + + cfg.ClientAutoTLS = true + clientURL := url.URL{Scheme: "https", Host: fmt.Sprintf("localhost:%d", port)} + cfg.LCUrls, cfg.ACUrls = []url.URL{clientURL}, []url.URL{clientURL} + + peerURL := url.URL{Scheme: "http", Host: fmt.Sprintf("localhost:%d", port+1)} + cfg.LPUrls, cfg.APUrls = []url.URL{peerURL}, []url.URL{peerURL} + cfg.InitialCluster = cfg.Name + "=" + peerURL.String() + + srv, err := embed.StartEtcd(cfg) + if err != nil { + t.Fatal(err) + } + nc := srv.Config() // overwrite config after processing ClientTLSInfo + cfg = &nc + + <-srv.Server.ReadyNotify() + defer func() { + srv.Close() + <-srv.Err() + }() + + // wait for leader election to finish + time.Sleep(500 * time.Millisecond) + + ccfg := clientv3.Config{Endpoints: []string{clientURL.String()}} + tcfg, err := cfg.ClientTLSInfo.ClientConfig() + if err != nil { + t.Fatal(err) + } + ccfg.TLS = tcfg + + cli, err := clientv3.New(ccfg) + if err != nil { + t.Fatal(err) + } + defer cli.Close() + + if _, err = cli.Get(context.Background(), "foo"); err != nil { + t.Fatal(err) + } +}