From df55438a609f095abb534b26c3438f36f629d7ee Mon Sep 17 00:00:00 2001
From: fanmin shi
Date: Fri, 13 Jan 2017 14:31:27 -0800
Subject: [PATCH] clientv3: balancer uses one connection at a time

FIX #7080
---
 clientv3/balancer.go      | 57 +++++++++++++++++++++------------------
 clientv3/balancer_test.go | 18 +++++++++++++
 2 files changed, 49 insertions(+), 26 deletions(-)

diff --git a/clientv3/balancer.go b/clientv3/balancer.go
index 0fef9c549..5ab9eb987 100644
--- a/clientv3/balancer.go
+++ b/clientv3/balancer.go
@@ -43,8 +43,7 @@ type simpleBalancer struct {
 	// mu protects upEps, pinAddr, and connectingAddr
 	mu sync.RWMutex
 
-	// upEps holds the current endpoints that have an active connection
-	upEps map[string]struct{}
+	// upc closes when pinAddr transitions from empty to non-empty or the balancer closes.
 
 	upc chan struct{}
@@ -71,7 +70,6 @@ func newSimpleBalancer(eps []string) *simpleBalancer {
 		addrs:    addrs,
 		notifyCh: notifyCh,
 		readyc:   make(chan struct{}),
-		upEps:    make(map[string]struct{}),
 		upc:      make(chan struct{}),
 		host2ep:  getHost2ep(eps),
 	}
@@ -140,48 +138,45 @@ func (b *simpleBalancer) Up(addr grpc.Address) func(error) {
 		return func(err error) {}
 	}
 
-	if len(b.upEps) == 0 {
+	if b.pinAddr == "" {
 		// notify waiting Get()s and pin first connected address
 		close(b.upc)
 		b.pinAddr = addr.Addr
+		// notify client that a connection is up
+		b.readyOnce.Do(func() { close(b.readyc) })
+		// close opened connections that are not pinAddr;
+		// this ensures only one connection is open per client
+		b.notifyCh <- []grpc.Address{addr}
 	}
-	b.upEps[addr.Addr] = struct{}{}
-
-	// notify client that a connection is up
-	b.readyOnce.Do(func() { close(b.readyc) })
 
 	return func(err error) {
 		b.mu.Lock()
-		delete(b.upEps, addr.Addr)
-		if len(b.upEps) == 0 && b.pinAddr != "" {
+		if b.pinAddr == addr.Addr {
 			b.upc = make(chan struct{})
-		} else if b.pinAddr == addr.Addr {
-			// choose new random up endpoint
-			for k := range b.upEps {
-				b.pinAddr = k
-				break
-			}
+			b.pinAddr = ""
+			b.notifyCh <- b.addrs
 		}
 		b.mu.Unlock()
 	}
 }
 
 func (b *simpleBalancer) Get(ctx context.Context, opts grpc.BalancerGetOptions) (grpc.Address, func(), error) {
-	var addr string
+	var (
+		addr   string
+		closed bool
+	)
 	// If opts.BlockingWait is false (for fail-fast RPCs), it should return
 	// an address it has notified via Notify immediately instead of blocking.
 	if !opts.BlockingWait {
 		b.mu.RLock()
-		closed := b.closed
+		closed = b.closed
 		addr = b.pinAddr
-		upEps := len(b.upEps)
 		b.mu.RUnlock()
 		if closed {
 			return grpc.Address{Addr: ""}, nil, grpc.ErrClientConnClosing
 		}
-
-		if upEps == 0 {
+		if addr == "" {
 			return grpc.Address{Addr: ""}, nil, ErrNoAddrAvilable
 		}
 		return grpc.Address{Addr: addr}, func() {}, nil
@@ -197,13 +192,14 @@ func (b *simpleBalancer) Get(ctx context.Context, opts grpc.BalancerGetOptions)
 			return grpc.Address{Addr: ""}, nil, ctx.Err()
 		}
 		b.mu.RLock()
+		closed = b.closed
 		addr = b.pinAddr
-		upEps := len(b.upEps)
 		b.mu.RUnlock()
-		if addr == "" {
+		// Close(), which sets b.closed = true, can be called before Get() returns; Get() must exit if the balancer is closed.
+		if closed {
 			return grpc.Address{Addr: ""}, nil, grpc.ErrClientConnClosing
 		}
-		if upEps > 0 {
+		if addr != "" {
 			break
 		}
 	}
@@ -222,9 +218,18 @@ func (b *simpleBalancer) Close() error {
 	}
 	b.closed = true
 	close(b.notifyCh)
-	// terminate all waiting Get()s
 	b.pinAddr = ""
-	if len(b.upEps) == 0 {
+
+	// In the following scenario:
+	// 1. upc is not closed; no pinned address
+	// 2. client issues an RPC, calling invoke(), which calls Get(), enters the for loop, and blocks
+	// 3. clientconn.Close() calls balancer.Close(); closed = true
+	// 4. the for loop in Get() never exits, since ctx is the context passed in by the client and may not be canceled
+	// we must close upc so Get() exits from blocking on upc
+	select {
+	case <-b.upc:
+	default:
+		// terminate all waiting Get()s
 		close(b.upc)
 	}
 	return nil
diff --git a/clientv3/balancer_test.go b/clientv3/balancer_test.go
index f75ceaaa5..5009d94eb 100644
--- a/clientv3/balancer_test.go
+++ b/clientv3/balancer_test.go
@@ -29,6 +29,9 @@ var (
 
 func TestBalancerGetUnblocking(t *testing.T) {
 	sb := newSimpleBalancer(endpoints)
+	if addrs := <-sb.Notify(); len(addrs) != len(endpoints) {
+		t.Errorf("Initializing newSimpleBalancer should have triggered Notify() chan, but it didn't")
+	}
 	unblockingOpts := grpc.BalancerGetOptions{BlockingWait: false}
 
 	_, _, err := sb.Get(context.Background(), unblockingOpts)
@@ -37,6 +40,9 @@ func TestBalancerGetUnblocking(t *testing.T) {
 	}
 
 	down1 := sb.Up(grpc.Address{Addr: endpoints[1]})
+	if addrs := <-sb.Notify(); len(addrs) != 1 {
+		t.Errorf("first Up() should have triggered the balancer to send the first connected address via Notify chan so that other connections can be closed")
+	}
 	down2 := sb.Up(grpc.Address{Addr: endpoints[2]})
 	addrFirst, putFun, err := sb.Get(context.Background(), unblockingOpts)
 	if err != nil {
@@ -54,6 +60,9 @@ func TestBalancerGetUnblocking(t *testing.T) {
 	}
 
 	down1(errors.New("error"))
+	if addrs := <-sb.Notify(); len(addrs) != len(endpoints) {
+		t.Errorf("closing the only connection should have triggered the balancer to send all endpoints via Notify chan so that we can establish a connection")
+	}
 	down2(errors.New("error"))
 	_, _, err = sb.Get(context.Background(), unblockingOpts)
 	if err != ErrNoAddrAvilable {
@@ -63,6 +72,9 @@ func TestBalancerGetUnblocking(t *testing.T) {
 
 func TestBalancerGetBlocking(t *testing.T) {
 	sb := newSimpleBalancer(endpoints)
+	if addrs := <-sb.Notify(); len(addrs) != len(endpoints) {
+		t.Errorf("Initializing newSimpleBalancer should have triggered Notify() chan, but it didn't")
+	}
 	blockingOpts := grpc.BalancerGetOptions{BlockingWait: true}
 
 	ctx, _ := context.WithTimeout(context.Background(), time.Millisecond*100)
@@ -77,6 +89,9 @@ func TestBalancerGetBlocking(t *testing.T) {
 		// ensure sb.Up() will be called after sb.Get() to see if Up() releases blocking Get()
 		time.Sleep(time.Millisecond * 100)
 		downC <- sb.Up(grpc.Address{Addr: endpoints[1]})
+		if addrs := <-sb.Notify(); len(addrs) != 1 {
+			t.Errorf("first Up() should have triggered the balancer to send the first connected address via Notify chan so that other connections can be closed")
+		}
 	}()
 	addrFirst, putFun, err := sb.Get(context.Background(), blockingOpts)
 	if err != nil {
@@ -97,6 +112,9 @@ func TestBalancerGetBlocking(t *testing.T) {
 	}
 
 	down1(errors.New("error"))
+	if addrs := <-sb.Notify(); len(addrs) != len(endpoints) {
+		t.Errorf("closing the only connection should have triggered the balancer to send all endpoints via Notify chan so that we can establish a connection")
+	}
 	down2(errors.New("error"))
 	ctx, _ = context.WithTimeout(context.Background(), time.Millisecond*100)
 	_, _, err = sb.Get(ctx, blockingOpts)
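
The scheme the patch implements can be summarized apart from the diff: pin the first address that reports Up(), narrow Notify() to that single address so gRPC closes every other connection, and re-broadcast all endpoints once the pinned connection goes down. Below is a minimal, self-contained Go sketch of that pattern. It is not the etcd implementation: pinBalancer, newPinBalancer, and the plain-string notify channel are hypothetical stand-ins for simpleBalancer and its grpc.Address plumbing, and the readyc/upc machinery is omitted.

package main

import (
	"fmt"
	"sync"
)

// pinBalancer is a toy model of simpleBalancer's pinning logic.
type pinBalancer struct {
	mu       sync.Mutex
	addrs    []string      // all known endpoints
	pinAddr  string        // the single pinned address; "" if none
	notifyCh chan []string // stand-in for the grpc.Address Notify channel
}

func newPinBalancer(addrs []string) *pinBalancer {
	b := &pinBalancer{addrs: addrs, notifyCh: make(chan []string, 1)}
	b.notifyCh <- addrs // initially ask gRPC to try every endpoint
	return b
}

// Up mirrors simpleBalancer.Up: pin the first connected address and
// narrow the notify list to it; the returned func is the "down" callback.
func (b *pinBalancer) Up(addr string) func() {
	b.mu.Lock()
	defer b.mu.Unlock()
	if b.pinAddr != "" {
		return func() {} // already pinned; the extra connection will be closed
	}
	b.pinAddr = addr
	b.notifyCh <- []string{addr} // keep only the pinned connection open
	return func() {
		b.mu.Lock()
		defer b.mu.Unlock()
		if b.pinAddr == addr {
			b.pinAddr = ""
			b.notifyCh <- b.addrs // pinned conn died: try all endpoints again
		}
	}
}

func main() {
	b := newPinBalancer([]string{"a:2379", "b:2379", "c:2379"})
	fmt.Println(<-b.notifyCh) // all endpoints: nothing connected yet
	down := b.Up("b:2379")
	fmt.Println(<-b.notifyCh) // only b:2379: other connections get closed
	down()
	fmt.Println(<-b.notifyCh) // all endpoints again: establish a new connection
}

Reading the notify channel in the order main() does here — all endpoints, then exactly one, then all endpoints again — is the same sequence the new assertions in balancer_test.go check after newSimpleBalancer, the first Up(), and the down of the only connection.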