diff --git a/clientv3/client.go b/clientv3/client.go index 28041ffb3..ef5e7ea24 100644 --- a/clientv3/client.go +++ b/clientv3/client.go @@ -384,5 +384,5 @@ func dialEndpointList(c *Client) (*grpc.ClientConn, error) { // progress can be made, even after reconnecting. func isHaltErr(ctx context.Context, err error) bool { isRPCError := strings.HasPrefix(grpc.ErrorDesc(err), "etcdserver: ") - return isRPCError || ctx.Err() != nil + return isRPCError || ctx.Err() != nil || err == rpctypes.ErrConnClosed } diff --git a/clientv3/integration/kv_test.go b/clientv3/integration/kv_test.go index 391022bef..8a5050aab 100644 --- a/clientv3/integration/kv_test.go +++ b/clientv3/integration/kv_test.go @@ -279,6 +279,42 @@ func TestKVRange(t *testing.T) { } } +func TestKVGetErrConnClosed(t *testing.T) { + defer testutil.AfterTest(t) + + clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 1}) + defer clus.Terminate(t) + + cli := clus.Client(0) + kv := clientv3.NewKV(cli) + + closed, donec := make(chan struct{}), make(chan struct{}) + go func() { + select { + case <-time.After(3 * time.Second): + t.Fatal("cli.Close took too long") + case <-closed: + } + + if _, err := kv.Get(context.TODO(), "foo"); err != rpctypes.ErrConnClosed { + t.Fatalf("expected %v, got %v", rpctypes.ErrConnClosed, err) + } + close(donec) + }() + + if err := cli.Close(); err != nil { + t.Fatal(err) + } + clus.TakeClient(0) + close(closed) + + select { + case <-time.After(3 * time.Second): + t.Fatal("kv.Get took too long") + case <-donec: + } +} + func TestKVDeleteRange(t *testing.T) { defer testutil.AfterTest(t) diff --git a/clientv3/remote_client.go b/clientv3/remote_client.go index 216d7ed08..b8209b8a5 100644 --- a/clientv3/remote_client.go +++ b/clientv3/remote_client.go @@ -17,6 +17,8 @@ package clientv3 import ( "sync" + "github.com/coreos/etcd/etcdserver/api/v3rpc/rpctypes" + "golang.org/x/net/context" "google.golang.org/grpc" ) @@ -81,10 +83,14 @@ func (r *remoteClient) tryUpdate() bool { func (r *remoteClient) acquire(ctx context.Context) error { for { r.client.mu.RLock() + closed := r.client.cancel == nil c := r.client.conn r.mu.Lock() match := r.conn == c r.mu.Unlock() + if closed { + return rpctypes.ErrConnClosed + } if match { return nil }