diff --git a/clientv3/client.go b/clientv3/client.go
index 28041ffb3..ef5e7ea24 100644
--- a/clientv3/client.go
+++ b/clientv3/client.go
@@ -384,5 +384,7 @@ func dialEndpointList(c *Client) (*grpc.ClientConn, error) {
 // progress can be made, even after reconnecting.
 func isHaltErr(ctx context.Context, err error) bool {
 	isRPCError := strings.HasPrefix(grpc.ErrorDesc(err), "etcdserver: ")
-	return isRPCError || ctx.Err() != nil
+	// ErrConnClosed is described as "clientv3: ...", so the prefix check above
+	// does not match it; test for it explicitly.
+	return isRPCError || ctx.Err() != nil || err == rpctypes.ErrConnClosed
 }
diff --git a/clientv3/integration/kv_test.go b/clientv3/integration/kv_test.go
index 391022bef..8a5050aab 100644
--- a/clientv3/integration/kv_test.go
+++ b/clientv3/integration/kv_test.go
@@ -279,6 +279,49 @@ func TestKVRange(t *testing.T) {
 	}
 }
 
+// TestKVGetErrConnClosed ensures kv.Get fails with rpctypes.ErrConnClosed
+// once the client has been closed.
+func TestKVGetErrConnClosed(t *testing.T) {
+	defer testutil.AfterTest(t)
+
+	clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 1})
+	defer clus.Terminate(t)
+
+	cli := clus.Client(0)
+	kv := clientv3.NewKV(cli)
+
+	closed, donec := make(chan struct{}), make(chan struct{})
+	go func() {
+		// t.Fatal must only be called from the goroutine running the test, so
+		// failures here use t.Error and donec is closed on every return path.
+		defer close(donec)
+
+		select {
+		case <-time.After(3 * time.Second):
+			t.Error("cli.Close took too long")
+			return
+		case <-closed:
+		}
+
+		if _, err := kv.Get(context.TODO(), "foo"); err != rpctypes.ErrConnClosed {
+			t.Errorf("expected %v, got %v", rpctypes.ErrConnClosed, err)
+		}
+	}()
+
+	if err := cli.Close(); err != nil {
+		t.Fatal(err)
+	}
+	// The client is already closed; stop Terminate from closing it again.
+	clus.TakeClient(0)
+	close(closed)
+
+	select {
+	case <-time.After(3 * time.Second):
+		t.Fatal("kv.Get took too long")
+	case <-donec:
+	}
+}
+
 func TestKVDeleteRange(t *testing.T) {
 	defer testutil.AfterTest(t)
 
diff --git a/clientv3/kv.go b/clientv3/kv.go
index 7b4082a08..41ac5113f 100644
--- a/clientv3/kv.go
+++ b/clientv3/kv.go
@@ -189,6 +189,7 @@ func (kv *kv) do(ctx context.Context, op Op) (OpResponse, error) {
 	return OpResponse{}, err
 }
 
+// getRemote must be followed by a kv.rc.release() call.
func (kv *kv) getRemote(ctx context.Context) (pb.KVClient, error) {
 	if err := kv.rc.acquire(ctx); err != nil {
 		return nil, err
diff --git a/clientv3/remote_client.go b/clientv3/remote_client.go
index 216d7ed08..b8209b8a5 100644
--- a/clientv3/remote_client.go
+++ b/clientv3/remote_client.go
@@ -17,6 +17,8 @@ package clientv3
 import (
 	"sync"
 
+	"github.com/coreos/etcd/etcdserver/api/v3rpc/rpctypes"
+
 	"golang.org/x/net/context"
 	"google.golang.org/grpc"
 )
@@ -81,10 +83,18 @@ func (r *remoteClient) tryUpdate() bool {
 func (r *remoteClient) acquire(ctx context.Context) error {
 	for {
 		r.client.mu.RLock()
+		closed := r.client.cancel == nil
 		c := r.client.conn
 		r.mu.Lock()
 		match := r.conn == c
 		r.mu.Unlock()
+		if closed {
+			// The read lock taken above is only kept on success; callers
+			// such as kv.getRemote return straight away on error, so
+			// unlock here rather than leaving the lock held.
+			r.client.mu.RUnlock()
+			return rpctypes.ErrConnClosed
+		}
 		if match {
 			return nil
 		}
diff --git a/etcdserver/api/v3rpc/rpctypes/error.go b/etcdserver/api/v3rpc/rpctypes/error.go
index 42c8c5a13..3868eff26 100644
--- a/etcdserver/api/v3rpc/rpctypes/error.go
+++ b/etcdserver/api/v3rpc/rpctypes/error.go
@@ -104,6 +104,9 @@ var (
 
 	ErrNoLeader   = Error(ErrGRPCNoLeader)
 	ErrNotCapable = Error(ErrGRPCNotCapable)
+
+	// ErrConnClosed is returned by clientv3 once the client has been closed.
+	ErrConnClosed = EtcdError{code: codes.Unavailable, desc: "clientv3: connection closed"}
 )
 
 // EtcdError defines gRPC server errors.
diff --git a/integration/cluster.go b/integration/cluster.go
index 68f6bd112..c339e37c9 100644
--- a/integration/cluster.go
+++ b/integration/cluster.go
@@ -27,6 +27,7 @@ import (
 	"sort"
 	"strconv"
 	"strings"
+	"sync"
 	"sync/atomic"
 	"testing"
 	"time"
@@ -749,6 +750,8 @@ func (p SortableMemberSliceByPeerURLs) Swap(i, j int) { p[i], p[j] = p[j], p[i]
 
 type ClusterV3 struct {
 	*cluster
+
+	mu      sync.Mutex
 	clients []*clientv3.Client
 }
 
@@ -756,7 +759,9 @@ type ClusterV3 struct {
 // for each cluster member.
func NewClusterV3(t *testing.T, cfg *ClusterConfig) *ClusterV3 {
 	cfg.UseGRPC = true
-	clus := &ClusterV3{cluster: NewClusterByConfig(t, cfg)}
+	clus := &ClusterV3{
+		cluster: NewClusterByConfig(t, cfg),
+	}
 	for _, m := range clus.Members {
 		client, err := NewClientV3(m)
 		if err != nil {
@@ -769,11 +774,27 @@ func NewClusterV3(t *testing.T, cfg *ClusterConfig) *ClusterV3 {
 	return clus
 }
 
+// TakeClient clears the cluster's slot for the client at idx by setting it
+// to nil, so that Terminate skips it. Use it when the test has taken over
+// (for example, already closed) that client.
+func (c *ClusterV3) TakeClient(idx int) {
+	c.mu.Lock()
+	c.clients[idx] = nil
+	c.mu.Unlock()
+}
+
+// Terminate closes every client still held by the cluster, reporting any
+// close error through t, and then terminates the underlying cluster.
 func (c *ClusterV3) Terminate(t *testing.T) {
+	c.mu.Lock()
 	for _, client := range c.clients {
+		if client == nil {
+			continue
+		}
 		if err := client.Close(); err != nil {
 			t.Error(err)
 		}
 	}
+	c.mu.Unlock()
 	c.cluster.Terminate(t)
 }