From 4a13c9f9b31c65c7cc567055771fa2c873c7e7ed Mon Sep 17 00:00:00 2001 From: Anthony Romano Date: Tue, 7 Jun 2016 16:28:08 -0700 Subject: [PATCH] clientv3: use grpc balancer --- clientv3/balancer.go | 64 +++++++++++++++++++++++++++++++++ clientv3/client.go | 86 +++++++++++++++++++++++++++++--------------- 2 files changed, 121 insertions(+), 29 deletions(-) create mode 100644 clientv3/balancer.go diff --git a/clientv3/balancer.go b/clientv3/balancer.go new file mode 100644 index 000000000..31871b8a4 --- /dev/null +++ b/clientv3/balancer.go @@ -0,0 +1,64 @@ +// Copyright 2016 The etcd Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package clientv3 + +import ( + "net/url" + "strings" + "sync/atomic" + + "golang.org/x/net/context" + "google.golang.org/grpc" +) + +// simpleBalancer does the bare minimum to expose multiple eps +// to the grpc reconnection code path +type simpleBalancer struct { + // eps are the client's endpoints stripped of any URL scheme + eps []string + ch chan []grpc.Address + numGets uint32 +} + +func newSimpleBalancer(eps []string) grpc.Balancer { + ch := make(chan []grpc.Address, 1) + addrs := make([]grpc.Address, len(eps)) + for i := range eps { + addrs[i].Addr = getHost(eps[i]) + } + ch <- addrs + return &simpleBalancer{eps: eps, ch: ch} +} + +func (b *simpleBalancer) Start(target string) error { return nil } +func (b *simpleBalancer) Up(addr grpc.Address) func(error) { return func(error) {} } +func (b *simpleBalancer) Get(ctx context.Context, opts grpc.BalancerGetOptions) (grpc.Address, func(), error) { + v := atomic.AddUint32(&b.numGets, 1) + ep := b.eps[v%uint32(len(b.eps))] + return grpc.Address{Addr: getHost(ep)}, func() {}, nil +} +func (b *simpleBalancer) Notify() <-chan []grpc.Address { return b.ch } +func (b *simpleBalancer) Close() error { + close(b.ch) + return nil +} + +func getHost(ep string) string { + url, uerr := url.Parse(ep) + if uerr != nil || !strings.Contains(ep, "://") { + return ep + } + return url.Host +} diff --git a/clientv3/client.go b/clientv3/client.go index b1abeb503..b13a770c1 100644 --- a/clientv3/client.go +++ b/clientv3/client.go @@ -46,9 +46,9 @@ type Client struct { Auth Maintenance - conn *grpc.ClientConn - cfg Config - creds *credentials.TransportAuthenticator + conn *grpc.ClientConn + cfg Config + creds *credentials.TransportAuthenticator ctx context.Context cancel context.CancelFunc @@ -110,43 +110,70 @@ func (cred authTokenCredential) GetRequestMetadata(ctx context.Context, s ...str }, nil } -// Dial establishes a connection for a given endpoint using the client's config +func (c *Client) dialTarget(endpoint string) (proto string, host string, creds *credentials.TransportAuthenticator) { + proto = "tcp" + host = endpoint + creds = c.creds + url, uerr := url.Parse(endpoint) + if uerr != nil || !strings.Contains(endpoint, "://") { + return + } + // strip scheme:// prefix since grpc dials by host + host = url.Host + switch url.Scheme { + case "unix": + proto = "unix" + case "http": + creds = nil + case "https": + if creds != nil { + break + } + tlsconfig := &tls.Config{} + emptyCreds := credentials.NewTLS(tlsconfig) + creds = &emptyCreds + default: + return "", "", nil + } + return +} + +// Dial connects to a single endpoint using the client's config. func (c *Client) Dial(endpoint string) (*grpc.ClientConn, error) { + return c.dial(endpoint) +} + +func (c *Client) dial(endpoint string, dopts ...grpc.DialOption) (*grpc.ClientConn, error) { opts := []grpc.DialOption{ grpc.WithBlock(), grpc.WithTimeout(c.cfg.DialTimeout), } + opts = append(opts, dopts...) - proto := "tcp" - creds := c.creds - if url, uerr := url.Parse(endpoint); uerr == nil && strings.Contains(endpoint, "://") { - switch url.Scheme { - case "unix": - proto = "unix" - case "http": - creds = nil - case "https": - if creds == nil { - tlsconfig := &tls.Config{InsecureSkipVerify: true} - emptyCreds := credentials.NewTLS(tlsconfig) - creds = &emptyCreds - } - default: - return nil, fmt.Errorf("unknown scheme %q for %q", url.Scheme, endpoint) - } - // strip scheme:// prefix since grpc dials by host - endpoint = url.Host + // grpc issues TLS cert checks using the string passed into dial so + // that string must be the host. To recover the full scheme://host URL, + // have a map from hosts to the original endpoint. + host2ep := make(map[string]string) + for i := range c.cfg.Endpoints { + _, host, _ := c.dialTarget(c.cfg.Endpoints[i]) + host2ep[host] = c.cfg.Endpoints[i] } - f := func(a string, t time.Duration) (net.Conn, error) { + + f := func(host string, t time.Duration) (net.Conn, error) { + proto, host, _ := c.dialTarget(host2ep[host]) + if proto == "" { + return nil, fmt.Errorf("unknown scheme for %q", endpoint) + } select { case <-c.ctx.Done(): return nil, c.ctx.Err() default: } - return net.DialTimeout(proto, a, t) + return net.DialTimeout(proto, host, t) } opts = append(opts, grpc.WithDialer(f)) + _, host, creds := c.dialTarget(endpoint) if creds != nil { opts = append(opts, grpc.WithTransportCredentials(*creds)) } else { @@ -154,7 +181,7 @@ func (c *Client) Dial(endpoint string) (*grpc.ClientConn, error) { } if c.Username != "" && c.Password != "" { - auth, err := newAuthenticator(endpoint, opts) + auth, err := newAuthenticator(host, opts) if err != nil { return nil, err } @@ -168,7 +195,7 @@ func (c *Client) Dial(endpoint string) (*grpc.ClientConn, error) { opts = append(opts, grpc.WithPerRPCCredentials(authTokenCredential{token: resp.Token})) } - conn, err := grpc.Dial(endpoint, opts...) + conn, err := grpc.Dial(host, opts...) if err != nil { return nil, err } @@ -205,8 +232,9 @@ func newClient(cfg *Config) (*Client, error) { client.Username = cfg.Username client.Password = cfg.Password } - // TODO: use grpc balancer - conn, err := client.Dial(cfg.Endpoints[0]) + + b := newSimpleBalancer(cfg.Endpoints) + conn, err := client.dial(cfg.Endpoints[0], grpc.WithBalancer(b)) if err != nil { return nil, err }