diff --git a/clientv3/client.go b/clientv3/client.go index 559d33265..3df10272d 100644 --- a/clientv3/client.go +++ b/clientv3/client.go @@ -277,6 +277,7 @@ func (c *Client) retryConnection(err error) (newConn *grpc.ClientConn, dialErr e // wait so grpc doesn't leak sleeping goroutines c.conn.WaitForStateChange(context.Background(), st) } + c.conn = nil } if c.cancel == nil { // client has called Close() so don't try to dial out diff --git a/clientv3/kv.go b/clientv3/kv.go index efa377b1a..7b4082a08 100644 --- a/clientv3/kv.go +++ b/clientv3/kv.go @@ -105,7 +105,12 @@ func (kv *kv) Delete(ctx context.Context, key string, opts ...OpOption) (*Delete } func (kv *kv) Compact(ctx context.Context, rev int64) error { - _, err := kv.getRemote().Compact(ctx, &pb.CompactionRequest{Revision: rev}) + remote, err := kv.getRemote(ctx) + if err != nil { + return rpctypes.Error(err) + } + defer kv.rc.release() + _, err = remote.Compact(ctx, &pb.CompactionRequest{Revision: rev}) if err == nil { return nil } @@ -125,58 +130,68 @@ func (kv *kv) Txn(ctx context.Context) Txn { func (kv *kv) Do(ctx context.Context, op Op) (OpResponse, error) { for { - var err error - remote := kv.getRemote() - switch op.t { - // TODO: handle other ops - case tRange: - var resp *pb.RangeResponse - r := &pb.RangeRequest{Key: op.key, RangeEnd: op.end, Limit: op.limit, Revision: op.rev, Serializable: op.serializable} - if op.sort != nil { - r.SortOrder = pb.RangeRequest_SortOrder(op.sort.Order) - r.SortTarget = pb.RangeRequest_SortTarget(op.sort.Target) - } - - resp, err = remote.Range(ctx, r) - if err == nil { - return OpResponse{get: (*GetResponse)(resp)}, nil - } - case tPut: - var resp *pb.PutResponse - r := &pb.PutRequest{Key: op.key, Value: op.val, Lease: int64(op.leaseID)} - resp, err = remote.Put(ctx, r) - if err == nil { - return OpResponse{put: (*PutResponse)(resp)}, nil - } - case tDeleteRange: - var resp *pb.DeleteRangeResponse - r := &pb.DeleteRangeRequest{Key: op.key, RangeEnd: op.end} - resp, err = remote.DeleteRange(ctx, r) - if err == nil { - return OpResponse{del: (*DeleteResponse)(resp)}, nil - } - default: - panic("Unknown op") + resp, err := kv.do(ctx, op) + if err == nil { + return resp, nil } - if isHaltErr(ctx, err) { - return OpResponse{}, rpctypes.Error(err) + return resp, rpctypes.Error(err) } - // do not retry on modifications if op.isWrite() { kv.rc.reconnect(err) - return OpResponse{}, rpctypes.Error(err) + return resp, rpctypes.Error(err) } - if nerr := kv.rc.reconnectWait(ctx, err); nerr != nil { - return OpResponse{}, nerr + return resp, rpctypes.Error(nerr) } } } -func (kv *kv) getRemote() pb.KVClient { - kv.rc.mu.Lock() - defer kv.rc.mu.Unlock() - return kv.remote +func (kv *kv) do(ctx context.Context, op Op) (OpResponse, error) { + remote, err := kv.getRemote(ctx) + if err != nil { + return OpResponse{}, err + } + defer kv.rc.release() + + switch op.t { + // TODO: handle other ops + case tRange: + var resp *pb.RangeResponse + r := &pb.RangeRequest{Key: op.key, RangeEnd: op.end, Limit: op.limit, Revision: op.rev, Serializable: op.serializable} + if op.sort != nil { + r.SortOrder = pb.RangeRequest_SortOrder(op.sort.Order) + r.SortTarget = pb.RangeRequest_SortTarget(op.sort.Target) + } + + resp, err = remote.Range(ctx, r) + if err == nil { + return OpResponse{get: (*GetResponse)(resp)}, nil + } + case tPut: + var resp *pb.PutResponse + r := &pb.PutRequest{Key: op.key, Value: op.val, Lease: int64(op.leaseID)} + resp, err = remote.Put(ctx, r) + if err == nil { + return OpResponse{put: (*PutResponse)(resp)}, nil + } + case tDeleteRange: + var resp *pb.DeleteRangeResponse + r := &pb.DeleteRangeRequest{Key: op.key, RangeEnd: op.end} + resp, err = remote.DeleteRange(ctx, r) + if err == nil { + return OpResponse{del: (*DeleteResponse)(resp)}, nil + } + default: + panic("Unknown op") + } + return OpResponse{}, err +} + +func (kv *kv) getRemote(ctx context.Context) (pb.KVClient, error) { + if err := kv.rc.acquire(ctx); err != nil { + return nil, err + } + return kv.remote, nil } diff --git a/clientv3/remote_client.go b/clientv3/remote_client.go index bfdb7f937..216d7ed08 100644 --- a/clientv3/remote_client.go +++ b/clientv3/remote_client.go @@ -77,3 +77,22 @@ func (r *remoteClient) tryUpdate() bool { r.updateConn(activeConn) return true } + +func (r *remoteClient) acquire(ctx context.Context) error { + for { + r.client.mu.RLock() + c := r.client.conn + r.mu.Lock() + match := r.conn == c + r.mu.Unlock() + if match { + return nil + } + r.client.mu.RUnlock() + if err := r.reconnectWait(ctx, nil); err != nil { + return err + } + } +} + +func (r *remoteClient) release() { r.client.mu.RUnlock() } diff --git a/clientv3/txn.go b/clientv3/txn.go index dd84289a5..d08a46eb8 100644 --- a/clientv3/txn.go +++ b/clientv3/txn.go @@ -137,27 +137,35 @@ func (txn *txn) Else(ops ...Op) Txn { func (txn *txn) Commit() (*TxnResponse, error) { txn.mu.Lock() defer txn.mu.Unlock() - - kv := txn.kv - for { - r := &pb.TxnRequest{Compare: txn.cmps, Success: txn.sus, Failure: txn.fas} - resp, err := kv.getRemote().Txn(txn.ctx, r) + resp, err := txn.commit() if err == nil { - return (*TxnResponse)(resp), nil + return resp, err } - if isHaltErr(txn.ctx, err) { return nil, rpctypes.Error(err) } - if txn.isWrite { - kv.rc.reconnect(err) + txn.kv.rc.reconnect(err) return nil, rpctypes.Error(err) } - - if nerr := kv.rc.reconnectWait(txn.ctx, err); nerr != nil { + if nerr := txn.kv.rc.reconnectWait(txn.ctx, err); nerr != nil { return nil, nerr } } } + +func (txn *txn) commit() (*TxnResponse, error) { + rem, rerr := txn.kv.getRemote(txn.ctx) + if rerr != nil { + return nil, rerr + } + defer txn.kv.rc.release() + + r := &pb.TxnRequest{Compare: txn.cmps, Success: txn.sus, Failure: txn.fas} + resp, err := rem.Txn(txn.ctx, r) + if err != nil { + return nil, err + } + return (*TxnResponse)(resp), nil +}