clientv3: rewrite based on 3.4

Signed-off-by: Gyuho Lee <leegyuho@amazon.com>
Gyuho Lee 2019-08-14 01:21:57 -07:00
parent a317433854
commit 9561f6b3b6
43 changed files with 3562 additions and 1681 deletions


@ -1,85 +0,0 @@
# etcd/clientv3
[![Godoc](https://img.shields.io/badge/go-documentation-blue.svg?style=flat-square)](https://godoc.org/github.com/coreos/etcd/clientv3)
`etcd/clientv3` is the official Go etcd client for v3.
## Install
```bash
go get github.com/coreos/etcd/clientv3
```
## Get started
Create a client using `clientv3.New`:
```go
cli, err := clientv3.New(clientv3.Config{
	Endpoints:   []string{"localhost:2379", "localhost:22379", "localhost:32379"},
	DialTimeout: 5 * time.Second,
})
if err != nil {
	// handle error!
}
defer cli.Close()
```
etcd v3 uses [`gRPC`](http://www.grpc.io) for remote procedure calls, and `clientv3` uses
[`grpc-go`](https://github.com/grpc/grpc-go) to connect to etcd. Make sure to close the client after using it;
otherwise the connection leaks goroutines. To bound the duration of client requests,
pass a `context.WithTimeout` context to APIs:
```go
ctx, cancel := context.WithTimeout(context.Background(), timeout)
resp, err := cli.Put(ctx, "sample_key", "sample_value")
cancel()
if err != nil {
	// handle error!
}
// use the response
```
etcd uses the `cmd/vendor` directory to store external dependencies, which are
compiled into etcd release binaries. `clientv3` can be imported without
vendoring. For full compatibility, however, it is recommended to build against
etcd's vendored packages, using tools like godep, as described in
[vendor directories](https://golang.org/cmd/go/#hdr-Vendor_Directories).
For more detail, please read the [Go vendor design](https://golang.org/s/go15vendor).
## Error Handling
The etcd client returns two types of errors:
1. context error: canceled or deadline exceeded.
2. gRPC error: see [api/v3rpc/rpctypes](https://godoc.org/github.com/coreos/etcd/etcdserver/api/v3rpc/rpctypes).
Here is the example code to handle client errors:
```go
resp, err := cli.Put(ctx, "", "")
if err != nil {
	switch err {
	case context.Canceled:
		log.Fatalf("ctx is canceled by another routine: %v", err)
	case context.DeadlineExceeded:
		log.Fatalf("ctx deadline is exceeded: %v", err)
	case rpctypes.ErrEmptyKey:
		log.Fatalf("client-side error: %v", err)
	default:
		log.Fatalf("bad cluster endpoints, which are not etcd servers: %v", err)
	}
}
```
## Metrics
The etcd client optionally exposes RPC metrics through [go-grpc-prometheus](https://github.com/grpc-ecosystem/go-grpc-prometheus). See the [examples](https://github.com/coreos/etcd/blob/master/clientv3/example_metrics_test.go).
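For reference, here is a minimal sketch of wiring the `go-grpc-prometheus` interceptors into a client through `DialOptions` (the endpoint and the `:9090` metrics listen address are placeholders):
```go
import (
	"net/http"

	"github.com/coreos/etcd/clientv3"
	grpcprom "github.com/grpc-ecosystem/go-grpc-prometheus"
	"github.com/prometheus/client_golang/prometheus/promhttp"
	"google.golang.org/grpc"
)

cli, err := clientv3.New(clientv3.Config{
	Endpoints: []string{"localhost:2379"},
	DialOptions: []grpc.DialOption{
		grpc.WithUnaryInterceptor(grpcprom.UnaryClientInterceptor),
		grpc.WithStreamInterceptor(grpcprom.StreamClientInterceptor),
	},
})
if err != nil {
	// handle error!
}
defer cli.Close()

// serve the collected client metrics for Prometheus to scrape
go http.ListenAndServe(":9090", promhttp.Handler())
```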
## Namespacing
The [namespace](https://godoc.org/github.com/coreos/etcd/clientv3/namespace) package provides `clientv3` interface wrappers to transparently isolate client requests to a user-defined prefix.
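For example, the wrappers can be applied to an existing client so that every key is transparently prefixed (the `"my-prefix/"` value is illustrative):
```go
import "github.com/coreos/etcd/clientv3/namespace"

// wrap the client's KV, Watcher, and Lease interfaces with a key prefix
cli.KV = namespace.NewKV(cli.KV, "my-prefix/")
cli.Watcher = namespace.NewWatcher(cli.Watcher, "my-prefix/")
cli.Lease = namespace.NewLease(cli.Lease, "my-prefix/")

// this now writes the key "my-prefix/abc" on the server
_, err = cli.Put(context.TODO(), "abc", "123")
```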
## Examples
More code examples can be found at [GoDoc](https://godoc.org/github.com/coreos/etcd/clientv3).


@ -216,8 +216,8 @@ func (auth *authenticator) close() {
auth.conn.Close()
}
func newAuthenticator(endpoint string, opts []grpc.DialOption, c *Client) (*authenticator, error) {
conn, err := grpc.Dial(endpoint, opts...)
func newAuthenticator(ctx context.Context, target string, opts []grpc.DialOption, c *Client) (*authenticator, error) {
conn, err := grpc.DialContext(ctx, target, opts...)
if err != nil {
return nil, err
}


@ -0,0 +1,293 @@
// Copyright 2018 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package balancer implements client balancer.
package balancer
import (
"strconv"
"sync"
"time"
"github.com/coreos/etcd/clientv3/balancer/connectivity"
"github.com/coreos/etcd/clientv3/balancer/picker"
"go.uber.org/zap"
"google.golang.org/grpc/balancer"
grpcconnectivity "google.golang.org/grpc/connectivity"
"google.golang.org/grpc/resolver"
_ "google.golang.org/grpc/resolver/dns" // register DNS resolver
_ "google.golang.org/grpc/resolver/passthrough" // register passthrough resolver
)
// Config defines balancer configurations.
type Config struct {
// Policy configures balancer policy.
Policy picker.Policy
// Picker implements gRPC picker.
// Leave empty if "Policy" field is not custom.
// TODO: currently custom policy is not supported.
// Picker picker.Picker
// Name defines an additional name for balancer.
// Useful for balancer testing to avoid register conflicts.
// If empty, defaults to policy name.
Name string
// Logger configures balancer logging.
// If nil, logs are discarded.
Logger *zap.Logger
}
// RegisterBuilder creates and registers a builder. Since this function calls balancer.Register, it
// must be invoked at initialization time.
func RegisterBuilder(cfg Config) {
bb := &builder{cfg}
balancer.Register(bb)
bb.cfg.Logger.Debug(
"registered balancer",
zap.String("policy", bb.cfg.Policy.String()),
zap.String("name", bb.cfg.Name),
)
}
type builder struct {
cfg Config
}
// Build is called initially when creating "ccBalancerWrapper".
// "grpc.Dial" is called to this client connection.
// Then, resolved addresses will be handled via "HandleResolvedAddrs".
func (b *builder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) balancer.Balancer {
bb := &baseBalancer{
id: strconv.FormatInt(time.Now().UnixNano(), 36),
policy: b.cfg.Policy,
name: b.cfg.Name,
lg: b.cfg.Logger,
addrToSc: make(map[resolver.Address]balancer.SubConn),
scToAddr: make(map[balancer.SubConn]resolver.Address),
scToSt: make(map[balancer.SubConn]grpcconnectivity.State),
currentConn: nil,
connectivityRecorder: connectivity.New(b.cfg.Logger),
// the initial picker always returns "ErrNoSubConnAvailable"
picker: picker.NewErr(balancer.ErrNoSubConnAvailable),
}
// TODO: support multiple connections
bb.mu.Lock()
bb.currentConn = cc
bb.mu.Unlock()
bb.lg.Info(
"built balancer",
zap.String("balancer-id", bb.id),
zap.String("policy", bb.policy.String()),
zap.String("resolver-target", cc.Target()),
)
return bb
}
// Name implements "grpc/balancer.Builder" interface.
func (b *builder) Name() string { return b.cfg.Name }
// Balancer defines client balancer interface.
type Balancer interface {
// Balancer is called on specified client connection. Client initiates gRPC
// connection with "grpc.Dial(addr, grpc.WithBalancerName)", and then those resolved
// addresses are passed to "grpc/balancer.Balancer.HandleResolvedAddrs".
// For each resolved address, balancer calls "balancer.ClientConn.NewSubConn".
// "grpc/balancer.Balancer.HandleSubConnStateChange" is called when connectivity state
// changes, thus requires failover logic in this method.
balancer.Balancer
// Picker calls "Pick" for every client request.
picker.Picker
}
type baseBalancer struct {
id string
policy picker.Policy
name string
lg *zap.Logger
mu sync.RWMutex
addrToSc map[resolver.Address]balancer.SubConn
scToAddr map[balancer.SubConn]resolver.Address
scToSt map[balancer.SubConn]grpcconnectivity.State
currentConn balancer.ClientConn
connectivityRecorder connectivity.Recorder
picker picker.Picker
}
// HandleResolvedAddrs implements "grpc/balancer.Balancer" interface.
// gRPC sends initial or updated resolved addresses from "Build".
func (bb *baseBalancer) HandleResolvedAddrs(addrs []resolver.Address, err error) {
if err != nil {
bb.lg.Warn("HandleResolvedAddrs called with error", zap.String("balancer-id", bb.id), zap.Error(err))
return
}
bb.lg.Info("resolved",
zap.String("picker", bb.picker.String()),
zap.String("balancer-id", bb.id),
zap.Strings("addresses", addrsToStrings(addrs)),
)
bb.mu.Lock()
defer bb.mu.Unlock()
resolved := make(map[resolver.Address]struct{})
for _, addr := range addrs {
resolved[addr] = struct{}{}
if _, ok := bb.addrToSc[addr]; !ok {
sc, err := bb.currentConn.NewSubConn([]resolver.Address{addr}, balancer.NewSubConnOptions{})
if err != nil {
bb.lg.Warn("NewSubConn failed", zap.String("picker", bb.picker.String()), zap.String("balancer-id", bb.id), zap.Error(err), zap.String("address", addr.Addr))
continue
}
bb.lg.Info("created subconn", zap.String("address", addr.Addr))
bb.addrToSc[addr] = sc
bb.scToAddr[sc] = addr
bb.scToSt[sc] = grpcconnectivity.Idle
sc.Connect()
}
}
for addr, sc := range bb.addrToSc {
if _, ok := resolved[addr]; !ok {
// was removed by resolver or failed to create subconn
bb.currentConn.RemoveSubConn(sc)
delete(bb.addrToSc, addr)
bb.lg.Info(
"removed subconn",
zap.String("picker", bb.picker.String()),
zap.String("balancer-id", bb.id),
zap.String("address", addr.Addr),
zap.String("subconn", scToString(sc)),
)
// Keep the state of this sc in bb.scToSt until sc's state becomes Shutdown.
// The entry will be deleted in HandleSubConnStateChange.
// (DO NOT) delete(bb.scToAddr, sc)
// (DO NOT) delete(bb.scToSt, sc)
}
}
}
// HandleSubConnStateChange implements "grpc/balancer.Balancer" interface.
func (bb *baseBalancer) HandleSubConnStateChange(sc balancer.SubConn, s grpcconnectivity.State) {
bb.mu.Lock()
defer bb.mu.Unlock()
old, ok := bb.scToSt[sc]
if !ok {
bb.lg.Warn(
"state change for an unknown subconn",
zap.String("picker", bb.picker.String()),
zap.String("balancer-id", bb.id),
zap.String("subconn", scToString(sc)),
zap.Int("subconn-size", len(bb.scToAddr)),
zap.String("state", s.String()),
)
return
}
bb.lg.Info(
"state changed",
zap.String("picker", bb.picker.String()),
zap.String("balancer-id", bb.id),
zap.Bool("connected", s == grpcconnectivity.Ready),
zap.String("subconn", scToString(sc)),
zap.Int("subconn-size", len(bb.scToAddr)),
zap.String("address", bb.scToAddr[sc].Addr),
zap.String("old-state", old.String()),
zap.String("new-state", s.String()),
)
bb.scToSt[sc] = s
switch s {
case grpcconnectivity.Idle:
sc.Connect()
case grpcconnectivity.Shutdown:
// When an address was removed by resolver, b called RemoveSubConn but
// kept the sc's state in scToSt. Remove state for this sc here.
delete(bb.scToAddr, sc)
delete(bb.scToSt, sc)
}
oldAggrState := bb.connectivityRecorder.GetCurrentState()
bb.connectivityRecorder.RecordTransition(old, s)
// Update balancer picker when one of the following happens:
// - this sc became ready from not-ready
// - this sc became not-ready from ready
// - the aggregated state of balancer became TransientFailure from non-TransientFailure
// - the aggregated state of balancer became non-TransientFailure from TransientFailure
if (s == grpcconnectivity.Ready) != (old == grpcconnectivity.Ready) ||
(bb.connectivityRecorder.GetCurrentState() == grpcconnectivity.TransientFailure) != (oldAggrState == grpcconnectivity.TransientFailure) {
bb.updatePicker()
}
bb.currentConn.UpdateBalancerState(bb.connectivityRecorder.GetCurrentState(), bb.picker)
}
func (bb *baseBalancer) updatePicker() {
if bb.connectivityRecorder.GetCurrentState() == grpcconnectivity.TransientFailure {
bb.picker = picker.NewErr(balancer.ErrTransientFailure)
bb.lg.Info(
"updated picker to transient error picker",
zap.String("picker", bb.picker.String()),
zap.String("balancer-id", bb.id),
zap.String("policy", bb.policy.String()),
)
return
}
// only pass ready subconns to picker
scToAddr := make(map[balancer.SubConn]resolver.Address)
for addr, sc := range bb.addrToSc {
if st, ok := bb.scToSt[sc]; ok && st == grpcconnectivity.Ready {
scToAddr[sc] = addr
}
}
bb.picker = picker.New(picker.Config{
Policy: bb.policy,
Logger: bb.lg,
SubConnToResolverAddress: scToAddr,
})
bb.lg.Info(
"updated picker",
zap.String("picker", bb.picker.String()),
zap.String("balancer-id", bb.id),
zap.String("policy", bb.policy.String()),
zap.Strings("subconn-ready", scsToStrings(scToAddr)),
zap.Int("subconn-size", len(scToAddr)),
)
}
// Close implements "grpc/balancer.Balancer" interface.
// Close is a nop because base balancer doesn't have internal state to clean up,
// and it doesn't need to call RemoveSubConn for the SubConns.
func (bb *baseBalancer) Close() {
// TODO
}


@ -0,0 +1,310 @@
// Copyright 2018 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package balancer
import (
"context"
"fmt"
"strings"
"testing"
"time"
"github.com/coreos/etcd/clientv3/balancer/picker"
"github.com/coreos/etcd/clientv3/balancer/resolver/endpoint"
pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
"github.com/coreos/etcd/pkg/mock/mockserver"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
)
// TestRoundRobinBalancedResolvableNoFailover ensures that
// requests to a resolvable endpoint are balanced across
// multiple nodes, if any, without requiring failover.
func TestRoundRobinBalancedResolvableNoFailover(t *testing.T) {
testCases := []struct {
name string
serverCount int
reqN int
network string
}{
{name: "rrBalanced_1", serverCount: 1, reqN: 5, network: "tcp"},
{name: "rrBalanced_1_unix_sockets", serverCount: 1, reqN: 5, network: "unix"},
{name: "rrBalanced_3", serverCount: 3, reqN: 7, network: "tcp"},
{name: "rrBalanced_5", serverCount: 5, reqN: 10, network: "tcp"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ms, err := mockserver.StartMockServersOnNetwork(tc.serverCount, tc.network)
if err != nil {
t.Fatalf("failed to start mock servers: %v", err)
}
defer ms.Stop()
var eps []string
for _, svr := range ms.Servers {
eps = append(eps, svr.ResolverAddress().Addr)
}
rsv, err := endpoint.NewResolverGroup("nofailover")
if err != nil {
t.Fatal(err)
}
defer rsv.Close()
rsv.SetEndpoints(eps)
name := genName()
cfg := Config{
Policy: picker.RoundrobinBalanced,
Name: name,
Logger: zap.NewExample(),
}
RegisterBuilder(cfg)
conn, err := grpc.Dial(fmt.Sprintf("endpoint://nofailover/*"), grpc.WithInsecure(), grpc.WithBalancerName(name))
if err != nil {
t.Fatalf("failed to dial mock server: %v", err)
}
defer conn.Close()
cli := pb.NewKVClient(conn)
reqFunc := func(ctx context.Context) (picked string, err error) {
var p peer.Peer
_, err = cli.Range(ctx, &pb.RangeRequest{Key: []byte("/x")}, grpc.Peer(&p))
if p.Addr != nil {
picked = p.Addr.String()
}
return picked, err
}
prev, switches := "", 0
for i := 0; i < tc.reqN; i++ {
picked, err := reqFunc(context.Background())
if err != nil {
t.Fatalf("#%d: unexpected failure %v", i, err)
}
if prev == "" {
prev = picked
continue
}
if prev != picked {
switches++
}
prev = picked
}
if tc.serverCount > 1 && switches < tc.reqN-3 { // -3 for initial resolutions
// TODO: FIX ME
t.Skipf("expected balanced loads for %d requests, got switches %d", tc.reqN, switches)
}
})
}
}
// TestRoundRobinBalancedResolvableFailoverFromServerFail ensures that
// loads are rebalanced while one server goes down and comes back.
func TestRoundRobinBalancedResolvableFailoverFromServerFail(t *testing.T) {
serverCount := 5
ms, err := mockserver.StartMockServers(serverCount)
if err != nil {
t.Fatalf("failed to start mock servers: %s", err)
}
defer ms.Stop()
var eps []string
for _, svr := range ms.Servers {
eps = append(eps, svr.ResolverAddress().Addr)
}
rsv, err := endpoint.NewResolverGroup("serverfail")
if err != nil {
t.Fatal(err)
}
defer rsv.Close()
rsv.SetEndpoints(eps)
name := genName()
cfg := Config{
Policy: picker.RoundrobinBalanced,
Name: name,
Logger: zap.NewExample(),
}
RegisterBuilder(cfg)
conn, err := grpc.Dial(fmt.Sprintf("endpoint://serverfail/mock.server"), grpc.WithInsecure(), grpc.WithBalancerName(name))
if err != nil {
t.Fatalf("failed to dial mock server: %s", err)
}
defer conn.Close()
cli := pb.NewKVClient(conn)
reqFunc := func(ctx context.Context) (picked string, err error) {
var p peer.Peer
_, err = cli.Range(ctx, &pb.RangeRequest{Key: []byte("/x")}, grpc.Peer(&p))
if p.Addr != nil {
picked = p.Addr.String()
}
return picked, err
}
// stop first server, loads should be redistributed
// stopped server should never be picked
ms.StopAt(0)
available := make(map[string]struct{})
for i := 1; i < serverCount; i++ {
available[eps[i]] = struct{}{}
}
reqN := 10
prev, switches := "", 0
for i := 0; i < reqN; i++ {
picked, err := reqFunc(context.Background())
if err != nil && strings.Contains(err.Error(), "transport is closing") {
continue
}
if prev == "" { // first failover
if eps[0] == picked {
t.Fatalf("expected failover from %q, picked %q", eps[0], picked)
}
prev = picked
continue
}
if _, ok := available[picked]; !ok {
t.Fatalf("picked unavailable address %q (available %v)", picked, available)
}
if prev != picked {
switches++
}
prev = picked
}
if switches < reqN-3 { // -3 for initial resolutions + failover
// TODO: FIX ME!
t.Skipf("expected balanced loads for %d requests, got switches %d", reqN, switches)
}
// now failed server comes back
ms.StartAt(0)
// enough time for reconnecting to recovered server
time.Sleep(time.Second)
prev, switches = "", 0
recoveredAddr, recovered := eps[0], 0
available[recoveredAddr] = struct{}{}
for i := 0; i < 2*reqN; i++ {
picked, err := reqFunc(context.Background())
if err != nil {
t.Fatalf("#%d: unexpected failure %v", i, err)
}
if prev == "" {
prev = picked
continue
}
if _, ok := available[picked]; !ok {
t.Fatalf("#%d: picked unavailable address %q (available %v)", i, picked, available)
}
if prev != picked {
switches++
}
if picked == recoveredAddr {
recovered++
}
prev = picked
}
if switches < reqN-3 { // -3 for initial resolutions
t.Fatalf("expected balanced loads for %d requests, got switches %d", reqN, switches)
}
if recovered < reqN/serverCount {
t.Fatalf("recovered server %q got only %d requests", recoveredAddr, recovered)
}
}
// TestRoundRobinBalancedResolvableFailoverFromRequestFail ensures that
// loads are rebalanced while some requests fail.
func TestRoundRobinBalancedResolvableFailoverFromRequestFail(t *testing.T) {
serverCount := 5
ms, err := mockserver.StartMockServers(serverCount)
if err != nil {
t.Fatalf("failed to start mock servers: %s", err)
}
defer ms.Stop()
var eps []string
available := make(map[string]struct{})
for _, svr := range ms.Servers {
eps = append(eps, svr.ResolverAddress().Addr)
available[svr.Address] = struct{}{}
}
rsv, err := endpoint.NewResolverGroup("requestfail")
if err != nil {
t.Fatal(err)
}
defer rsv.Close()
rsv.SetEndpoints(eps)
name := genName()
cfg := Config{
Policy: picker.RoundrobinBalanced,
Name: name,
Logger: zap.NewExample(),
}
RegisterBuilder(cfg)
conn, err := grpc.Dial(fmt.Sprintf("endpoint://requestfail/mock.server"), grpc.WithInsecure(), grpc.WithBalancerName(name))
if err != nil {
t.Fatalf("failed to dial mock server: %s", err)
}
defer conn.Close()
cli := pb.NewKVClient(conn)
reqFunc := func(ctx context.Context) (picked string, err error) {
var p peer.Peer
_, err = cli.Range(ctx, &pb.RangeRequest{Key: []byte("/x")}, grpc.Peer(&p))
if p.Addr != nil {
picked = p.Addr.String()
}
return picked, err
}
reqN := 20
prev, switches := "", 0
for i := 0; i < reqN; i++ {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if i%2 == 0 {
cancel()
}
picked, err := reqFunc(ctx)
if i%2 == 0 {
if s, ok := status.FromError(err); ok && s.Code() != codes.Canceled || picked != "" {
t.Fatalf("#%d: expected %v, got %v", i, context.Canceled, err)
}
continue
}
if prev == "" && picked != "" {
prev = picked
continue
}
if _, ok := available[picked]; !ok {
t.Fatalf("#%d: picked unavailable address %q (available %v)", i, picked, available)
}
if prev != picked {
switches++
}
prev = picked
}
if switches < reqN/2-3 { // -3 for initial resolutions + failover
t.Fatalf("expected balanced loads for %d requests, got switches %d", reqN, switches)
}
}


@ -0,0 +1,93 @@
// Copyright 2019 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package connectivity implements client connectivity operations.
package connectivity
import (
"sync"
"go.uber.org/zap"
"google.golang.org/grpc/connectivity"
)
// Recorder records gRPC connectivity.
type Recorder interface {
GetCurrentState() connectivity.State
RecordTransition(oldState, newState connectivity.State)
}
// New returns a new Recorder.
func New(lg *zap.Logger) Recorder {
return &recorder{lg: lg}
}
// recorder takes the connectivity states of multiple SubConns
// and returns one aggregated connectivity state.
// ref. https://github.com/grpc/grpc-go/blob/master/balancer/balancer.go
type recorder struct {
lg *zap.Logger
mu sync.RWMutex
cur connectivity.State
numReady uint64 // Number of addrConns in ready state.
numConnecting uint64 // Number of addrConns in connecting state.
numTransientFailure uint64 // Number of addrConns in transientFailure.
}
func (rc *recorder) GetCurrentState() (state connectivity.State) {
rc.mu.RLock()
defer rc.mu.RUnlock()
return rc.cur
}
// RecordTransition records state change happening in subConn and based on that
// it evaluates what aggregated state should be.
//
// - If at least one SubConn is in Ready, the aggregated state is Ready;
// - Else if at least one SubConn is in Connecting, the aggregated state is Connecting;
// - Else the aggregated state is TransientFailure.
//
// Idle and Shutdown are not considered.
//
// ref. https://github.com/grpc/grpc-go/blob/master/balancer/balancer.go
func (rc *recorder) RecordTransition(oldState, newState connectivity.State) {
rc.mu.Lock()
defer rc.mu.Unlock()
for idx, state := range []connectivity.State{oldState, newState} {
updateVal := 2*uint64(idx) - 1 // -1 for oldState and +1 for new.
switch state {
case connectivity.Ready:
rc.numReady += updateVal
case connectivity.Connecting:
rc.numConnecting += updateVal
case connectivity.TransientFailure:
rc.numTransientFailure += updateVal
default:
rc.lg.Warn("connectivity recorder received unknown state", zap.String("connectivity-state", state.String()))
}
}
switch { // must be exclusive, no overlap
case rc.numReady > 0:
rc.cur = connectivity.Ready
case rc.numConnecting > 0:
rc.cur = connectivity.Connecting
default:
rc.cur = connectivity.TransientFailure
}
}


@ -1,4 +1,4 @@
// Copyright 2017 The etcd Authors
// Copyright 2018 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -12,19 +12,5 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package clientv3
import "context"
// TODO: remove this when "FailFast=false" is fixed.
// See https://github.com/grpc/grpc-go/issues/1532.
func readyWait(rpcCtx, clientCtx context.Context, ready <-chan struct{}) error {
select {
case <-ready:
return nil
case <-rpcCtx.Done():
return rpcCtx.Err()
case <-clientCtx.Done():
return clientCtx.Err()
}
}
// Package picker defines/implements client balancer picker policy.
package picker


@ -0,0 +1,39 @@
// Copyright 2018 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package picker
import (
"context"
"google.golang.org/grpc/balancer"
)
// NewErr returns a picker that always returns err on "Pick".
func NewErr(err error) Picker {
return &errPicker{p: Error, err: err}
}
type errPicker struct {
p Policy
err error
}
func (ep *errPicker) String() string {
return ep.p.String()
}
func (ep *errPicker) Pick(context.Context, balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) {
return nil, nil, ep.err
}


@ -0,0 +1,91 @@
// Copyright 2018 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package picker
import (
"fmt"
"go.uber.org/zap"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/resolver"
)
// Picker defines balancer Picker methods.
type Picker interface {
balancer.Picker
String() string
}
// Config defines picker configuration.
type Config struct {
// Policy specifies etcd clientv3's built in balancer policy.
Policy Policy
// Logger defines picker logging object.
Logger *zap.Logger
// SubConnToResolverAddress maps each gRPC sub-connection to an address.
// Basically, it is a list of addresses that the Picker can pick from.
SubConnToResolverAddress map[balancer.SubConn]resolver.Address
}
// Policy defines balancer picker policy.
type Policy uint8
const (
// Error is error picker policy.
Error Policy = iota
// RoundrobinBalanced balances loads over multiple endpoints
// and implements failover in roundrobin fashion.
RoundrobinBalanced
// Custom defines custom balancer picker.
// TODO: custom picker is not supported yet.
Custom
)
func (p Policy) String() string {
switch p {
case Error:
return "picker-error"
case RoundrobinBalanced:
return "picker-roundrobin-balanced"
case Custom:
panic("'custom' picker policy is not supported yet")
default:
panic(fmt.Errorf("invalid balancer picker policy (%d)", p))
}
}
// New creates a new Picker.
func New(cfg Config) Picker {
switch cfg.Policy {
case Error:
panic("'error' picker policy is not supported here; use 'picker.NewErr'")
case RoundrobinBalanced:
return newRoundrobinBalanced(cfg)
case Custom:
panic("'custom' picker policy is not supported yet")
default:
panic(fmt.Errorf("invalid balancer picker policy (%d)", cfg.Policy))
}
}


@ -0,0 +1,95 @@
// Copyright 2018 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package picker
import (
"context"
"sync"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/resolver"
)
// newRoundrobinBalanced returns a new roundrobin balanced picker.
func newRoundrobinBalanced(cfg Config) Picker {
scs := make([]balancer.SubConn, 0, len(cfg.SubConnToResolverAddress))
for sc := range cfg.SubConnToResolverAddress {
scs = append(scs, sc)
}
return &rrBalanced{
p: RoundrobinBalanced,
lg: cfg.Logger,
scs: scs,
scToAddr: cfg.SubConnToResolverAddress,
}
}
type rrBalanced struct {
p Policy
lg *zap.Logger
mu sync.RWMutex
next int
scs []balancer.SubConn
scToAddr map[balancer.SubConn]resolver.Address
}
func (rb *rrBalanced) String() string { return rb.p.String() }
// Pick is called for every client request.
func (rb *rrBalanced) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) {
rb.mu.RLock()
n := len(rb.scs)
rb.mu.RUnlock()
if n == 0 {
return nil, nil, balancer.ErrNoSubConnAvailable
}
rb.mu.Lock()
cur := rb.next
sc := rb.scs[cur]
picked := rb.scToAddr[sc].Addr
rb.next = (rb.next + 1) % len(rb.scs)
rb.mu.Unlock()
rb.lg.Debug(
"picked",
zap.String("picker", rb.p.String()),
zap.String("address", picked),
zap.Int("subconn-index", cur),
zap.Int("subconn-size", n),
)
doneFunc := func(info balancer.DoneInfo) {
// TODO: error handling?
fss := []zapcore.Field{
zap.Error(info.Err),
zap.String("picker", rb.p.String()),
zap.String("address", picked),
zap.Bool("success", info.Err == nil),
zap.Bool("bytes-sent", info.BytesSent),
zap.Bool("bytes-received", info.BytesReceived),
}
if info.Err == nil {
rb.lg.Debug("balancer done", fss...)
} else {
rb.lg.Warn("balancer failed", fss...)
}
}
return sc, doneFunc, nil
}


@ -0,0 +1,240 @@
// Copyright 2018 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package endpoint resolves etcd endpoints using grpc targets of the form 'endpoint://<id>/<endpoint>'.
package endpoint
import (
"fmt"
"net/url"
"strings"
"sync"
"google.golang.org/grpc/resolver"
)
const scheme = "endpoint"
var (
targetPrefix = fmt.Sprintf("%s://", scheme)
bldr *builder
)
func init() {
bldr = &builder{
resolverGroups: make(map[string]*ResolverGroup),
}
resolver.Register(bldr)
}
type builder struct {
mu sync.RWMutex
resolverGroups map[string]*ResolverGroup
}
// NewResolverGroup creates a new ResolverGroup with the given id.
func NewResolverGroup(id string) (*ResolverGroup, error) {
return bldr.newResolverGroup(id)
}
// ResolverGroup keeps all endpoints of resolvers using a common endpoint://<id>/ target
// up-to-date.
type ResolverGroup struct {
mu sync.RWMutex
id string
endpoints []string
resolvers []*Resolver
}
func (e *ResolverGroup) addResolver(r *Resolver) {
e.mu.Lock()
addrs := epsToAddrs(e.endpoints...)
e.resolvers = append(e.resolvers, r)
e.mu.Unlock()
r.cc.NewAddress(addrs)
}
func (e *ResolverGroup) removeResolver(r *Resolver) {
e.mu.Lock()
for i, er := range e.resolvers {
if er == r {
e.resolvers = append(e.resolvers[:i], e.resolvers[i+1:]...)
break
}
}
e.mu.Unlock()
}
// SetEndpoints updates the endpoints for ResolverGroup. All registered resolvers are updated
// immediately with the new endpoints.
func (e *ResolverGroup) SetEndpoints(endpoints []string) {
addrs := epsToAddrs(endpoints...)
e.mu.Lock()
e.endpoints = endpoints
for _, r := range e.resolvers {
r.cc.NewAddress(addrs)
}
e.mu.Unlock()
}
// Target constructs an endpoint target using the endpoint id of the ResolverGroup.
func (e *ResolverGroup) Target(endpoint string) string {
return Target(e.id, endpoint)
}
// Target constructs an endpoint resolver target.
func Target(id, endpoint string) string {
return fmt.Sprintf("%s://%s/%s", scheme, id, endpoint)
}
// IsTarget checks if a given target string is an endpoint resolver target.
func IsTarget(target string) bool {
return strings.HasPrefix(target, "endpoint://")
}
func (e *ResolverGroup) Close() {
bldr.close(e.id)
}
// Build creates or reuses an etcd resolver for the etcd cluster name identified by the authority part of the target.
func (b *builder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOption) (resolver.Resolver, error) {
if len(target.Authority) < 1 {
return nil, fmt.Errorf("'etcd' target scheme requires non-empty authority identifying etcd cluster being routed to")
}
id := target.Authority
es, err := b.getResolverGroup(id)
if err != nil {
return nil, fmt.Errorf("failed to build resolver: %v", err)
}
r := &Resolver{
endpointID: id,
cc: cc,
}
es.addResolver(r)
return r, nil
}
func (b *builder) newResolverGroup(id string) (*ResolverGroup, error) {
b.mu.RLock()
_, ok := b.resolverGroups[id]
b.mu.RUnlock()
if ok {
return nil, fmt.Errorf("Endpoint already exists for id: %s", id)
}
es := &ResolverGroup{id: id}
b.mu.Lock()
b.resolverGroups[id] = es
b.mu.Unlock()
return es, nil
}
func (b *builder) getResolverGroup(id string) (*ResolverGroup, error) {
b.mu.RLock()
es, ok := b.resolverGroups[id]
b.mu.RUnlock()
if !ok {
return nil, fmt.Errorf("ResolverGroup not found for id: %s", id)
}
return es, nil
}
func (b *builder) close(id string) {
b.mu.Lock()
delete(b.resolverGroups, id)
b.mu.Unlock()
}
func (b *builder) Scheme() string {
return scheme
}
// Resolver provides a resolver for a single etcd cluster, identified by name.
type Resolver struct {
endpointID string
cc resolver.ClientConn
sync.RWMutex
}
// TODO: use balancer.epsToAddrs
func epsToAddrs(eps ...string) (addrs []resolver.Address) {
addrs = make([]resolver.Address, 0, len(eps))
for _, ep := range eps {
addrs = append(addrs, resolver.Address{Addr: ep})
}
return addrs
}
func (*Resolver) ResolveNow(o resolver.ResolveNowOption) {}
func (r *Resolver) Close() {
es, err := bldr.getResolverGroup(r.endpointID)
if err != nil {
return
}
es.removeResolver(r)
}
// ParseEndpoint parses an endpoint of the form
// (http|https)://<host> or (unix|unixs)://<path>
// and returns a protocol ('tcp' or 'unix'),
// a host (or file path if a unix socket),
// and a scheme (http, https, unix, unixs).
func ParseEndpoint(endpoint string) (proto string, host string, scheme string) {
proto = "tcp"
host = endpoint
url, uerr := url.Parse(endpoint)
if uerr != nil || !strings.Contains(endpoint, "://") {
return proto, host, scheme
}
scheme = url.Scheme
// strip scheme:// prefix since grpc dials by host
host = url.Host
switch url.Scheme {
case "http", "https":
case "unix", "unixs":
proto = "unix"
host = url.Host + url.Path
default:
proto, host = "", ""
}
return proto, host, scheme
}
// ParseTarget parses an endpoint://<id>/<endpoint> string and returns the parsed id and endpoint.
// If the target is malformed, an error is returned.
func ParseTarget(target string) (string, string, error) {
noPrefix := strings.TrimPrefix(target, targetPrefix)
if noPrefix == target {
return "", "", fmt.Errorf("malformed target, %s prefix is required: %s", targetPrefix, target)
}
parts := strings.SplitN(noPrefix, "/", 2)
if len(parts) != 2 {
return "", "", fmt.Errorf("malformed target, expected %s://<id>/<endpoint>, but got %s", scheme, target)
}
return parts[0], parts[1], nil
}
// ParseHostPort splits a "<host>:<port>" string into the host and port parts.
// The port part is optional.
func ParseHostPort(hostPort string) (host string, port string) {
parts := strings.SplitN(hostPort, ":", 2)
host = parts[0]
if len(parts) > 1 {
port = parts[1]
}
return host, port
}
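As a rough illustration of how the target helpers above compose (a sketch with made-up id and endpoint values, not part of this change):
```go
// build an endpoint://<id>/<endpoint> target, then take it apart again
target := endpoint.Target("client-1", "unixs://localhost:2379")
// target == "endpoint://client-1/unixs://localhost:2379"

id, ep, err := endpoint.ParseTarget(target)
if err != nil {
	// handle error!
}
// id == "client-1", ep == "unixs://localhost:2379"

// map the endpoint scheme onto a net dialer protocol and host
proto, host, scheme := endpoint.ParseEndpoint(ep)
// proto == "unix", host == "localhost:2379", scheme == "unixs"
```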


@ -0,0 +1,68 @@
// Copyright 2018 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package balancer
import (
"fmt"
"net/url"
"sort"
"sync/atomic"
"time"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/resolver"
)
func scToString(sc balancer.SubConn) string {
return fmt.Sprintf("%p", sc)
}
func scsToStrings(scs map[balancer.SubConn]resolver.Address) (ss []string) {
ss = make([]string, 0, len(scs))
for sc, a := range scs {
ss = append(ss, fmt.Sprintf("%s (%s)", a.Addr, scToString(sc)))
}
sort.Strings(ss)
return ss
}
func addrsToStrings(addrs []resolver.Address) (ss []string) {
ss = make([]string, len(addrs))
for i := range addrs {
ss[i] = addrs[i].Addr
}
sort.Strings(ss)
return ss
}
func epsToAddrs(eps ...string) (addrs []resolver.Address) {
addrs = make([]resolver.Address, 0, len(eps))
for _, ep := range eps {
u, err := url.Parse(ep)
if err != nil {
addrs = append(addrs, resolver.Address{Addr: ep, Type: resolver.Backend})
continue
}
addrs = append(addrs, resolver.Address{Addr: u.Host, Type: resolver.Backend})
}
return addrs
}
var genN = new(uint32)
func genName() string {
now := time.Now().UnixNano()
return fmt.Sprintf("%X%X", now, atomic.AddUint32(genN, 1))
}


@ -0,0 +1,34 @@
// Copyright 2018 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package balancer
import (
"reflect"
"testing"
"google.golang.org/grpc/resolver"
)
func Test_epsToAddrs(t *testing.T) {
eps := []string{"https://example.com:2379", "127.0.0.1:2379"}
exp := []resolver.Address{
{Addr: "example.com:2379", Type: resolver.Backend},
{Addr: "127.0.0.1:2379", Type: resolver.Backend},
}
rs := epsToAddrs(eps...)
if !reflect.DeepEqual(rs, exp) {
t.Fatalf("expected %v, got %v", exp, rs)
}
}


@ -16,21 +16,26 @@ package clientv3
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/url"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/coreos/etcd/clientv3/balancer"
"github.com/coreos/etcd/clientv3/balancer/picker"
"github.com/coreos/etcd/clientv3/balancer/resolver/endpoint"
"github.com/coreos/etcd/clientv3/credentials"
"github.com/coreos/etcd/etcdserver/api/v3rpc/rpctypes"
"github.com/coreos/etcd/pkg/logutil"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
grpccredentials "google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
@ -39,8 +44,31 @@ import (
var (
ErrNoAvailableEndpoints = errors.New("etcdclient: no available endpoints")
ErrOldCluster = errors.New("etcdclient: old cluster version")
roundRobinBalancerName = fmt.Sprintf("etcd-%s", picker.RoundrobinBalanced.String())
)
func init() {
lg := zap.NewNop()
if os.Getenv("ETCD_CLIENT_DEBUG") != "" {
lcfg := logutil.DefaultZapLoggerConfig
lcfg.Level = zap.NewAtomicLevelAt(zap.DebugLevel)
var err error
lg, err = lcfg.Build() // debug level logging
if err != nil {
panic(err)
}
}
// TODO: support custom balancer
balancer.RegisterBuilder(balancer.Config{
Policy: picker.RoundrobinBalanced,
Name: roundRobinBalancerName,
Logger: lg,
})
}
// Client provides and manages an etcd v3 client session.
type Client struct {
Cluster
@ -50,13 +78,12 @@ type Client struct {
Auth
Maintenance
conn *grpc.ClientConn
dialerrc chan error
conn *grpc.ClientConn
cfg Config
creds *credentials.TransportCredentials
balancer *healthBalancer
mu *sync.RWMutex
cfg Config
creds grpccredentials.TransportCredentials
resolverGroup *endpoint.ResolverGroup
mu *sync.RWMutex
ctx context.Context
cancel context.CancelFunc
@ -64,11 +91,12 @@ type Client struct {
// Username is a user name for authentication.
Username string
// Password is a password for authentication.
Password string
// tokenCred is an instance of WithPerRPCCredentials()'s argument
tokenCred *authTokenCredential
Password string
authTokenBundle credentials.Bundle
callOpts []grpc.CallOption
lg *zap.Logger
}
// New creates a new etcdv3 client from a given configuration.
@ -93,11 +121,19 @@ func NewFromURL(url string) (*Client, error) {
return New(Config{Endpoints: []string{url}})
}
// NewFromURLs creates a new etcdv3 client from URLs.
func NewFromURLs(urls []string) (*Client, error) {
return New(Config{Endpoints: urls})
}
// Close shuts down the client's etcd connections.
func (c *Client) Close() error {
c.cancel()
c.Watcher.Close()
c.Lease.Close()
if c.resolverGroup != nil {
c.resolverGroup.Close()
}
if c.conn != nil {
return toErr(c.ctx, c.conn.Close())
}
@ -111,9 +147,9 @@ func (c *Client) Ctx() context.Context { return c.ctx }
// Endpoints lists the registered endpoints for the client.
func (c *Client) Endpoints() []string {
// copy the slice; protect original endpoints from being changed
c.mu.RLock()
defer c.mu.RUnlock()
// copy the slice; protect original endpoints from being changed
eps := make([]string, len(c.cfg.Endpoints))
copy(eps, c.cfg.Endpoints)
return eps
@ -122,22 +158,9 @@ func (c *Client) Endpoints() []string {
// SetEndpoints updates client's endpoints.
func (c *Client) SetEndpoints(eps ...string) {
c.mu.Lock()
defer c.mu.Unlock()
c.cfg.Endpoints = eps
c.mu.Unlock()
c.balancer.updateAddrs(eps...)
// updating notifyCh can trigger new connections,
// need update addrs if all connections are down
// or addrs does not include pinAddr.
c.balancer.mu.RLock()
update := !hasAddr(c.balancer.addrs, c.balancer.pinAddr)
c.balancer.mu.RUnlock()
if update {
select {
case c.balancer.updateAddrsC <- notifyNext:
case <-c.balancer.stopc:
}
}
c.resolverGroup.SetEndpoints(eps)
}
// Sync synchronizes client's endpoints with the known endpoints from the etcd membership.
@ -168,52 +191,13 @@ func (c *Client) autoSync() {
err := c.Sync(ctx)
cancel()
if err != nil && err != c.ctx.Err() {
logger.Println("Auto sync endpoints failed:", err)
lg.Lvl(4).Infof("Auto sync endpoints failed: %v", err)
}
}
}
}
type authTokenCredential struct {
token string
tokenMu *sync.RWMutex
}
func (cred authTokenCredential) RequireTransportSecurity() bool {
return false
}
func (cred authTokenCredential) GetRequestMetadata(ctx context.Context, s ...string) (map[string]string, error) {
cred.tokenMu.RLock()
defer cred.tokenMu.RUnlock()
return map[string]string{
"token": cred.token,
}, nil
}
func parseEndpoint(endpoint string) (proto string, host string, scheme string) {
proto = "tcp"
host = endpoint
url, uerr := url.Parse(endpoint)
if uerr != nil || !strings.Contains(endpoint, "://") {
return proto, host, scheme
}
scheme = url.Scheme
// strip scheme:// prefix since grpc dials by host
host = url.Host
switch url.Scheme {
case "http", "https":
case "unix", "unixs":
proto = "unix"
host = url.Host + url.Path
default:
proto, host = "", ""
}
return proto, host, scheme
}
func (c *Client) processCreds(scheme string) (creds *credentials.TransportCredentials) {
func (c *Client) processCreds(scheme string) (creds grpccredentials.TransportCredentials) {
creds = c.creds
switch scheme {
case "unix":
@ -223,83 +207,87 @@ func (c *Client) processCreds(scheme string) (creds *credentials.TransportCreden
if creds != nil {
break
}
tlsconfig := &tls.Config{}
emptyCreds := credentials.NewTLS(tlsconfig)
creds = &emptyCreds
creds = credentials.NewBundle(credentials.Config{}).TransportCredentials()
default:
creds = nil
}
return creds
}
// dialSetupOpts gives the dial opts prior to any authentication
func (c *Client) dialSetupOpts(endpoint string, dopts ...grpc.DialOption) (opts []grpc.DialOption) {
if c.cfg.DialTimeout > 0 {
opts = []grpc.DialOption{grpc.WithTimeout(c.cfg.DialTimeout)}
}
// dialSetupOpts gives the dial opts prior to any authentication.
func (c *Client) dialSetupOpts(creds grpccredentials.TransportCredentials, dopts ...grpc.DialOption) (opts []grpc.DialOption, err error) {
if c.cfg.DialKeepAliveTime > 0 {
params := keepalive.ClientParameters{
Time: c.cfg.DialKeepAliveTime,
Timeout: c.cfg.DialKeepAliveTimeout,
Time: c.cfg.DialKeepAliveTime,
Timeout: c.cfg.DialKeepAliveTimeout,
PermitWithoutStream: c.cfg.PermitWithoutStream,
}
opts = append(opts, grpc.WithKeepaliveParams(params))
}
opts = append(opts, dopts...)
f := func(host string, t time.Duration) (net.Conn, error) {
proto, host, _ := parseEndpoint(c.balancer.endpoint(host))
if host == "" && endpoint != "" {
// dialing an endpoint not in the balancer; use
// endpoint passed into dial
proto, host, _ = parseEndpoint(endpoint)
}
if proto == "" {
return nil, fmt.Errorf("unknown scheme for %q", host)
}
// Provide a net dialer that supports cancelation and timeout.
f := func(dialEp string, t time.Duration) (net.Conn, error) {
proto, host, _ := endpoint.ParseEndpoint(dialEp)
select {
case <-c.ctx.Done():
return nil, c.ctx.Err()
default:
}
dialer := &net.Dialer{Timeout: t}
conn, err := dialer.DialContext(c.ctx, proto, host)
if err != nil {
select {
case c.dialerrc <- err:
default:
}
}
return conn, err
return dialer.DialContext(c.ctx, proto, host)
}
opts = append(opts, grpc.WithDialer(f))
creds := c.creds
if _, _, scheme := parseEndpoint(endpoint); len(scheme) != 0 {
creds = c.processCreds(scheme)
}
if creds != nil {
opts = append(opts, grpc.WithTransportCredentials(*creds))
opts = append(opts, grpc.WithTransportCredentials(creds))
} else {
opts = append(opts, grpc.WithInsecure())
}
return opts
// Interceptor retry and backoff.
// TODO: Replace all of clientv3/retry.go with interceptor based retry, or with
// https://github.com/grpc/proposal/blob/master/A6-client-retries.md#retry-policy
// once it is available.
rrBackoff := withBackoff(c.roundRobinQuorumBackoff(defaultBackoffWaitBetween, defaultBackoffJitterFraction))
opts = append(opts,
// Disable stream retry by default since go-grpc-middleware/retry does not support client streams.
// Streams that are safe to retry are enabled individually.
grpc.WithStreamInterceptor(c.streamClientInterceptor(c.lg, withMax(0), rrBackoff)),
grpc.WithUnaryInterceptor(c.unaryClientInterceptor(c.lg, withMax(defaultUnaryMaxRetries), rrBackoff)),
)
return opts, nil
}
// Dial connects to a single endpoint using the client's config.
func (c *Client) Dial(endpoint string) (*grpc.ClientConn, error) {
return c.dial(endpoint)
func (c *Client) Dial(ep string) (*grpc.ClientConn, error) {
creds := c.directDialCreds(ep)
// Use the grpc passthrough resolver to directly dial a single endpoint.
// This resolver passes through the 'unix' and 'unixs' endpoint schemes used
// by etcd without modification, allowing us to directly dial endpoints and
// to use the same dial functions that we use for load balancer dialing.
return c.dial(fmt.Sprintf("passthrough:///%s", ep), creds)
}
func (c *Client) getToken(ctx context.Context) error {
var err error // return last error in a case of fail
var auth *authenticator
for i := 0; i < len(c.cfg.Endpoints); i++ {
endpoint := c.cfg.Endpoints[i]
host := getHost(endpoint)
eps := c.Endpoints()
for _, ep := range eps {
// use dial options without dopts to avoid reusing the client balancer
auth, err = newAuthenticator(host, c.dialSetupOpts(endpoint), c)
var dOpts []grpc.DialOption
_, host, _ := endpoint.ParseEndpoint(ep)
target := c.resolverGroup.Target(host)
creds := c.dialWithBalancerCreds(ep)
dOpts, err = c.dialSetupOpts(creds, c.cfg.DialOptions...)
if err != nil {
err = fmt.Errorf("failed to configure auth dialer: %v", err)
continue
}
dOpts = append(dOpts, grpc.WithBalancerName(roundRobinBalancerName))
auth, err = newAuthenticator(ctx, target, dOpts, c)
if err != nil {
continue
}
@ -308,56 +296,102 @@ func (c *Client) getToken(ctx context.Context) error {
var resp *AuthenticateResponse
resp, err = auth.authenticate(ctx, c.Username, c.Password)
if err != nil {
// return err without retrying other endpoints
if err == rpctypes.ErrAuthNotEnabled {
return err
}
continue
}
c.tokenCred.tokenMu.Lock()
c.tokenCred.token = resp.Token
c.tokenCred.tokenMu.Unlock()
c.authTokenBundle.UpdateAuthToken(resp.Token)
return nil
}
return err
}
func (c *Client) dial(endpoint string, dopts ...grpc.DialOption) (*grpc.ClientConn, error) {
opts := c.dialSetupOpts(endpoint, dopts...)
host := getHost(endpoint)
// dialWithBalancer dials the client's current load balanced resolver group. The scheme of the host
// of the provided endpoint determines the scheme used for all endpoints of the client connection.
func (c *Client) dialWithBalancer(ep string, dopts ...grpc.DialOption) (*grpc.ClientConn, error) {
_, host, _ := endpoint.ParseEndpoint(ep)
target := c.resolverGroup.Target(host)
creds := c.dialWithBalancerCreds(ep)
return c.dial(target, creds, dopts...)
}
// dial configures and dials any grpc balancer target.
func (c *Client) dial(target string, creds grpccredentials.TransportCredentials, dopts ...grpc.DialOption) (*grpc.ClientConn, error) {
opts, err := c.dialSetupOpts(creds, dopts...)
if err != nil {
return nil, fmt.Errorf("failed to configure dialer: %v", err)
}
if c.Username != "" && c.Password != "" {
c.tokenCred = &authTokenCredential{
tokenMu: &sync.RWMutex{},
}
c.authTokenBundle = credentials.NewBundle(credentials.Config{})
ctx := c.ctx
ctx, cancel := c.ctx, func() {}
if c.cfg.DialTimeout > 0 {
cctx, cancel := context.WithTimeout(ctx, c.cfg.DialTimeout)
defer cancel()
ctx = cctx
ctx, cancel = context.WithTimeout(ctx, c.cfg.DialTimeout)
}
err := c.getToken(ctx)
err = c.getToken(ctx)
if err != nil {
if toErr(ctx, err) != rpctypes.ErrAuthNotEnabled {
if err == ctx.Err() && ctx.Err() != c.ctx.Err() {
err = context.DeadlineExceeded
}
cancel()
return nil, err
}
} else {
opts = append(opts, grpc.WithPerRPCCredentials(c.tokenCred))
opts = append(opts, grpc.WithPerRPCCredentials(c.authTokenBundle.PerRPCCredentials()))
}
cancel()
}
opts = append(opts, c.cfg.DialOptions...)
conn, err := grpc.DialContext(c.ctx, host, opts...)
dctx := c.ctx
if c.cfg.DialTimeout > 0 {
var cancel context.CancelFunc
dctx, cancel = context.WithTimeout(c.ctx, c.cfg.DialTimeout)
defer cancel() // TODO: Is this right for cases where grpc.WithBlock() is not set on the dial options?
}
conn, err := grpc.DialContext(dctx, target, opts...)
if err != nil {
return nil, err
}
return conn, nil
}
func (c *Client) directDialCreds(ep string) grpccredentials.TransportCredentials {
_, hostPort, scheme := endpoint.ParseEndpoint(ep)
creds := c.creds
if len(scheme) != 0 {
creds = c.processCreds(scheme)
if creds != nil {
clone := creds.Clone()
// The server name must be set to the endpoint hostname without the port, since grpc
// otherwise attempts to check whether the x509 cert is valid for the full endpoint
// including the scheme and port, which fails.
host, _ := endpoint.ParseHostPort(hostPort)
clone.OverrideServerName(host)
creds = clone
}
}
return creds
}
func (c *Client) dialWithBalancerCreds(ep string) grpccredentials.TransportCredentials {
_, _, scheme := endpoint.ParseEndpoint(ep)
creds := c.creds
if len(scheme) != 0 {
creds = c.processCreds(scheme)
}
return creds
}
// WithRequireLeader requires client requests to only succeed
// when the cluster has a leader.
func WithRequireLeader(ctx context.Context) context.Context {
@ -369,10 +403,9 @@ func newClient(cfg *Config) (*Client, error) {
if cfg == nil {
cfg = &Config{}
}
var creds *credentials.TransportCredentials
var creds grpccredentials.TransportCredentials
if cfg.TLS != nil {
c := credentials.NewTLS(cfg.TLS)
creds = &c
creds = credentials.NewBundle(credentials.Config{TLSConfig: cfg.TLS}).TransportCredentials()
}
// use a temporary skeleton client to bootstrap first connection
@ -384,7 +417,6 @@ func newClient(cfg *Config) (*Client, error) {
ctx, cancel := context.WithCancel(baseCtx)
client := &Client{
conn: nil,
dialerrc: make(chan error, 1),
cfg: *cfg,
creds: creds,
ctx: ctx,
@ -392,6 +424,17 @@ func newClient(cfg *Config) (*Client, error) {
mu: new(sync.RWMutex),
callOpts: defaultCallOpts,
}
lcfg := logutil.DefaultZapLoggerConfig
if cfg.LogConfig != nil {
lcfg = *cfg.LogConfig
}
var err error
client.lg, err = lcfg.Build()
if err != nil {
return nil, err
}
if cfg.Username != "" && cfg.Password != "" {
client.Username = cfg.Username
client.Password = cfg.Password
@ -414,42 +457,31 @@ func newClient(cfg *Config) (*Client, error) {
client.callOpts = callOpts
}
client.balancer = newHealthBalancer(cfg.Endpoints, cfg.DialTimeout, func(ep string) (bool, error) {
return grpcHealthCheck(client, ep)
})
// use Endpoints[0] so that for https:// without any tls config given, then
// grpc will assume the certificate server name is the endpoint host.
conn, err := client.dial(cfg.Endpoints[0], grpc.WithBalancer(client.balancer))
// Prepare a 'endpoint://<unique-client-id>/' resolver for the client and create a endpoint target to pass
// to dial so the client knows to use this resolver.
client.resolverGroup, err = endpoint.NewResolverGroup(fmt.Sprintf("client-%s", uuid.New().String()))
if err != nil {
client.cancel()
client.balancer.Close()
return nil, err
}
client.conn = conn
client.resolverGroup.SetEndpoints(cfg.Endpoints)
// wait for a connection
if cfg.DialTimeout > 0 {
hasConn := false
waitc := time.After(cfg.DialTimeout)
select {
case <-client.balancer.ready():
hasConn = true
case <-ctx.Done():
case <-waitc:
}
if !hasConn {
err := context.DeadlineExceeded
select {
case err = <-client.dialerrc:
default:
}
client.cancel()
client.balancer.Close()
conn.Close()
return nil, err
}
if len(cfg.Endpoints) < 1 {
return nil, fmt.Errorf("at least one Endpoint must is required in client config")
}
dialEndpoint := cfg.Endpoints[0]
// Use a provided endpoint target so that, for https:// endpoints without any tls config given,
// grpc will assume the certificate server name is the endpoint host.
conn, err := client.dialWithBalancer(dialEndpoint, grpc.WithBalancerName(roundRobinBalancerName))
if err != nil {
client.cancel()
client.resolverGroup.Close()
return nil, err
}
// TODO: With the old grpc balancer interface, we waited until the dial timeout
// for the balancer to be ready. Is there an equivalent wait we should do with the new grpc balancer interface?
client.conn = conn
client.Cluster = NewCluster(client)
client.KV = NewKV(client)
@ -469,15 +501,35 @@ func newClient(cfg *Config) (*Client, error) {
return client, nil
}
// roundRobinQuorumBackoff retries against quorum between each backoff.
// This is intended for use with a round robin load balancer.
func (c *Client) roundRobinQuorumBackoff(waitBetween time.Duration, jitterFraction float64) backoffFunc {
return func(attempt uint) time.Duration {
// after each round robin across quorum, backoff for our wait between duration
n := uint(len(c.Endpoints()))
quorum := (n/2 + 1)
if attempt%quorum == 0 {
c.lg.Debug("backoff", zap.Uint("attempt", attempt), zap.Uint("quorum", quorum), zap.Duration("waitBetween", waitBetween), zap.Float64("jitterFraction", jitterFraction))
return jitterUp(waitBetween, jitterFraction)
}
c.lg.Debug("backoff skipped", zap.Uint("attempt", attempt), zap.Uint("quorum", quorum))
return 0
}
}
func (c *Client) checkVersion() (err error) {
var wg sync.WaitGroup
errc := make(chan error, len(c.cfg.Endpoints))
eps := c.Endpoints()
errc := make(chan error, len(eps))
ctx, cancel := context.WithCancel(c.ctx)
if c.cfg.DialTimeout > 0 {
ctx, cancel = context.WithTimeout(ctx, c.cfg.DialTimeout)
cancel()
ctx, cancel = context.WithTimeout(c.ctx, c.cfg.DialTimeout)
}
wg.Add(len(c.cfg.Endpoints))
for _, ep := range c.cfg.Endpoints {
wg.Add(len(eps))
for _, ep := range eps {
// if cluster is current, any endpoint gives a recent version
go func(e string) {
defer wg.Done()
@ -489,8 +541,15 @@ func (c *Client) checkVersion() (err error) {
vs := strings.Split(resp.Version, ".")
maj, min := 0, 0
if len(vs) >= 2 {
maj, _ = strconv.Atoi(vs[0])
min, rerr = strconv.Atoi(vs[1])
var serr error
if maj, serr = strconv.Atoi(vs[0]); serr != nil {
errc <- serr
return
}
if min, serr = strconv.Atoi(vs[1]); serr != nil {
errc <- serr
return
}
}
if maj < 3 || (maj == 3 && min < 2) {
rerr = ErrOldCluster
@ -499,7 +558,7 @@ func (c *Client) checkVersion() (err error) {
}(ep)
}
// wait for success
for i := 0; i < len(c.cfg.Endpoints); i++ {
for range eps {
if err = <-errc; err == nil {
break
}
@ -539,10 +598,13 @@ func isUnavailableErr(ctx context.Context, err error) bool {
if err == nil {
return false
}
ev, _ := status.FromError(err)
// Unavailable codes mean the system will be right back.
// (e.g., can't connect, lost leader)
return ev.Code() == codes.Unavailable
ev, ok := status.FromError(err)
if ok {
// Unavailable codes mean the system will be right back.
// (e.g., can't connect, lost leader)
return ev.Code() == codes.Unavailable
}
return false
}
func toErr(ctx context.Context, err error) error {
@ -553,18 +615,16 @@ func toErr(ctx context.Context, err error) error {
if _, ok := err.(rpctypes.EtcdError); ok {
return err
}
ev, _ := status.FromError(err)
code := ev.Code()
switch code {
case codes.DeadlineExceeded:
fallthrough
case codes.Canceled:
if ctx.Err() != nil {
err = ctx.Err()
if ev, ok := status.FromError(err); ok {
code := ev.Code()
switch code {
case codes.DeadlineExceeded:
fallthrough
case codes.Canceled:
if ctx.Err() != nil {
err = ctx.Err()
}
}
case codes.Unavailable:
case codes.FailedPrecondition:
err = grpc.ErrClientConnClosing
}
return err
}
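toErr maps gRPC status codes back to context errors; below is a caller-level sketch of the same classification (illustrative only, not part of this change).

package main

import (
	"context"
	"fmt"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// classify mirrors the idea in toErr: when the RPC failed with Canceled or
// DeadlineExceeded and the caller's context has expired, surface the context error.
func classify(ctx context.Context, err error) error {
	if err == nil {
		return nil
	}
	if s, ok := status.FromError(err); ok {
		switch s.Code() {
		case codes.Canceled, codes.DeadlineExceeded:
			if ctx.Err() != nil {
				return ctx.Err()
			}
		}
	}
	return err
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	fmt.Println(classify(ctx, status.Error(codes.Canceled, "rpc canceled"))) // prints "context canceled"
}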
@ -576,3 +636,26 @@ func canceledByCaller(stopCtx context.Context, err error) bool {
return err == context.Canceled || err == context.DeadlineExceeded
}
// IsConnCanceled returns true if the error is from a closed gRPC connection.
// ref. https://github.com/grpc/grpc-go/pull/1854
func IsConnCanceled(err error) bool {
if err == nil {
return false
}
// >= gRPC v1.23.x
s, ok := status.FromError(err)
if ok {
// connection is canceled or server has already closed the connection
return s.Code() == codes.Canceled || s.Message() == "transport is closing"
}
// >= gRPC v1.10.x
if err == context.Canceled {
return true
}
// <= gRPC v1.7.x returns 'errors.New("grpc: the client connection is closing")'
return strings.Contains(err.Error(), "grpc: the client connection is closing")
}
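A short usage sketch for IsConnCanceled (the endpoint and key are placeholders): callers sharing a client across goroutines can use it to tell a closed client apart from other failures.

package main

import (
	"context"
	"log"

	"github.com/coreos/etcd/clientv3"
)

func main() {
	cli, err := clientv3.New(clientv3.Config{Endpoints: []string{"localhost:2379"}})
	if err != nil {
		log.Fatal(err)
	}
	defer cli.Close()

	if _, err := cli.Get(context.Background(), "foo"); err != nil {
		if clientv3.IsConnCanceled(err) {
			// the client (and its gRPC connection) was closed; do not retry
			log.Println("client closed:", err)
			return
		}
		log.Println("request failed:", err)
	}
}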


@ -23,6 +23,7 @@ import (
"github.com/coreos/etcd/etcdserver/api/v3rpc/rpctypes"
"github.com/coreos/etcd/pkg/testutil"
"google.golang.org/grpc"
)
func TestDialCancel(t *testing.T) {
@ -83,11 +84,13 @@ func TestDialTimeout(t *testing.T) {
testCfgs := []Config{
{
Endpoints: []string{"http://254.0.0.1:12345"},
DialOptions: []grpc.DialOption{grpc.WithBlock()},
DialTimeout: 2 * time.Second,
},
{
Endpoints: []string{"http://254.0.0.1:12345"},
DialTimeout: time.Second,
DialOptions: []grpc.DialOption{grpc.WithBlock()},
Username: "abc",
Password: "def",
},


@ -49,6 +49,7 @@ func NewElection(s *Session, pfx string) *Election {
func ResumeElection(s *Session, pfx string, leaderKey string, leaderRev int64) *Election {
return &Election{
session: s,
keyPrefix: pfx,
leaderKey: leaderKey,
leaderRev: leaderRev,
leaderSession: s,


@ -0,0 +1,114 @@
// Copyright 2018 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package concurrency_test
import (
"context"
"log"
"testing"
"time"
"strings"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency"
)
func TestResumeElection(t *testing.T) {
const prefix = "/resume-election/"
cli, err := clientv3.New(clientv3.Config{Endpoints: endpoints})
if err != nil {
log.Fatal(err)
}
defer cli.Close()
s, err := concurrency.NewSession(cli)
if err != nil {
log.Fatal(err)
}
defer s.Close()
e := concurrency.NewElection(s, prefix)
// Entire test should never take more than 10 seconds
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
// Become leader
if err := e.Campaign(ctx, "candidate1"); err != nil {
t.Fatalf("Campaign() returned non nil err: %s", err)
}
// Get the leadership details of the current election
leader, err := e.Leader(ctx)
if err != nil {
t.Fatalf("Leader() returned non nil err: %s", err)
}
// Recreate the election
e = concurrency.ResumeElection(s, prefix,
string(leader.Kvs[0].Key), leader.Kvs[0].CreateRevision)
respChan := make(chan *clientv3.GetResponse)
go func() {
o := e.Observe(ctx)
respChan <- nil
for {
select {
case resp, ok := <-o:
if !ok {
t.Fatal("Observe() channel closed prematurely")
}
// Ignore any observations that candidate1 was elected
if string(resp.Kvs[0].Value) == "candidate1" {
continue
}
respChan <- &resp
return
}
}
}()
// Wait until observe goroutine is running
<-respChan
// Put some random data to generate a change event; this put should be
// ignored by Observe() because it is not under the election prefix.
_, err = cli.Put(ctx, "foo", "bar")
if err != nil {
t.Fatalf("Put('foo') returned non nil err: %s", err)
}
// Resign as leader
if err := e.Resign(ctx); err != nil {
t.Fatalf("Resign() returned non nil err: %s", err)
}
// Elect a different candidate
if err := e.Campaign(ctx, "candidate2"); err != nil {
t.Fatalf("Campaign() returned non nil err: %s", err)
}
// Wait for observed leader change
resp := <-respChan
kv := resp.Kvs[0]
if !strings.HasPrefix(string(kv.Key), prefix) {
t.Errorf("expected observed election to have prefix '%s' got '%s'", prefix, string(kv.Key))
}
if string(kv.Value) != "candidate2" {
t.Errorf("expected new leader to be 'candidate1' got '%s'", string(kv.Value))
}
}


@ -19,6 +19,7 @@ import (
"crypto/tls"
"time"
"go.uber.org/zap"
"google.golang.org/grpc"
)
@ -67,9 +68,19 @@ type Config struct {
RejectOldCluster bool `json:"reject-old-cluster"`
// DialOptions is a list of dial options for the grpc client (e.g., for interceptors).
// For example, pass "grpc.WithBlock()" to block until the underlying connection is up.
// Without this, Dial returns immediately and connecting to the server happens in the background.
DialOptions []grpc.DialOption
// LogConfig configures client-side logger.
// If nil, use the default logger.
// TODO: configure gRPC logger
LogConfig *zap.Config
// Context is the default client context; it can be used to cancel grpc
// dial-out and other operations that do not have an explicit context.
Context context.Context
// PermitWithoutStream, when set, allows the client to send keepalive pings
// to the server without any active streams (RPCs).
PermitWithoutStream bool `json:"permit-without-stream"`
}
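A minimal configuration sketch exercising the new fields above (the endpoint and timeout values are placeholders, not recommendations):

package main

import (
	"context"
	"log"
	"time"

	"github.com/coreos/etcd/clientv3"
	"google.golang.org/grpc"
)

func main() {
	cli, err := clientv3.New(clientv3.Config{
		Endpoints:   []string{"localhost:2379"},
		DialTimeout: 5 * time.Second,
		// block until the underlying connection is up instead of dialing in the background
		DialOptions: []grpc.DialOption{grpc.WithBlock()},
		// default context for dial-out and other operations without an explicit context
		Context: context.Background(),
		// allow keepalive pings even when there are no active streams (RPCs)
		PermitWithoutStream: true,
	})
	if err != nil {
		log.Fatal(err)
	}
	defer cli.Close()
}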


@ -0,0 +1,155 @@
// Copyright 2019 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package credentials implements the gRPC credentials interface with etcd-specific logic.
// e.g., client handshake with custom authority parameter
package credentials
import (
"context"
"crypto/tls"
"net"
"sync"
"github.com/coreos/etcd/etcdserver/api/v3rpc/rpctypes"
grpccredentials "google.golang.org/grpc/credentials"
)
// Config defines gRPC credential configuration.
type Config struct {
TLSConfig *tls.Config
}
// Bundle defines gRPC credential interface.
type Bundle interface {
grpccredentials.Bundle
UpdateAuthToken(token string)
}
// NewBundle constructs a new gRPC credential bundle.
func NewBundle(cfg Config) Bundle {
return &bundle{
tc: newTransportCredential(cfg.TLSConfig),
rc: newPerRPCCredential(),
}
}
// bundle implements "grpccredentials.Bundle" interface.
type bundle struct {
tc *transportCredential
rc *perRPCCredential
}
func (b *bundle) TransportCredentials() grpccredentials.TransportCredentials {
return b.tc
}
func (b *bundle) PerRPCCredentials() grpccredentials.PerRPCCredentials {
return b.rc
}
func (b *bundle) NewWithMode(mode string) (grpccredentials.Bundle, error) {
// no-op
return nil, nil
}
// transportCredential implements "grpccredentials.TransportCredentials" interface.
type transportCredential struct {
gtc grpccredentials.TransportCredentials
}
func newTransportCredential(cfg *tls.Config) *transportCredential {
return &transportCredential{
gtc: grpccredentials.NewTLS(cfg),
}
}
func (tc *transportCredential) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, grpccredentials.AuthInfo, error) {
// Only overwrite the authority when it is an IP address!
// Say a server publishes SRV records on "etcd.local" that resolve
// to "m1.etcd.local", and its SAN field also includes "m1.etcd.local".
// But what if the SAN does not include its resolved IP address (e.g. 127.0.0.1)?
// Then the server should only be authenticated by its DNS hostname "m1.etcd.local",
// not by overwriting it with its IP address.
// We do not overwrite "localhost" either; only IP addresses are overwritten.
if isIP(authority) {
target := rawConn.RemoteAddr().String()
if authority != target {
// When a user dials with "grpc.WithDialer" or "grpc.DialContext", the
// "cc.parsedTarget" update only happens once. This is problematic because,
// when TLS is enabled, retries go through "grpc.WithDialer" with the static
// "cc.parsedTarget" from the initial dial call.
// If the server authenticates by IP address, we want to set the new endpoint as
// the new authority. Otherwise the handshake fails with
// "transport: authentication handshake failed: x509: certificate is valid for 127.0.0.1, 192.168.121.180, not 192.168.223.156"
// when the new dial target is "192.168.121.180", whose certificate host name is also "192.168.121.180",
// but the client tries to authenticate with the previously set "cc.parsedTarget" "192.168.223.156".
authority = target
}
}
return tc.gtc.ClientHandshake(ctx, authority, rawConn)
}
// isIP returns true if the given string is an IP address.
func isIP(ep string) bool {
return net.ParseIP(ep) != nil
}
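A quick illustration of the gate above, using addresses taken from the comment: only IP-literal authorities are candidates for the overwrite, while DNS names and "localhost" are left alone.

package main

import (
	"fmt"
	"net"
)

func main() {
	for _, authority := range []string{"127.0.0.1", "192.168.223.156", "m1.etcd.local", "localhost"} {
		// mirrors isIP: only IP literals may have their authority rewritten
		fmt.Printf("%-20s overwrite candidate: %v\n", authority, net.ParseIP(authority) != nil)
	}
}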
func (tc *transportCredential) ServerHandshake(rawConn net.Conn) (net.Conn, grpccredentials.AuthInfo, error) {
return tc.gtc.ServerHandshake(rawConn)
}
func (tc *transportCredential) Info() grpccredentials.ProtocolInfo {
return tc.gtc.Info()
}
func (tc *transportCredential) Clone() grpccredentials.TransportCredentials {
return &transportCredential{
gtc: tc.gtc.Clone(),
}
}
func (tc *transportCredential) OverrideServerName(serverNameOverride string) error {
return tc.gtc.OverrideServerName(serverNameOverride)
}
// perRPCCredential implements "grpccredentials.PerRPCCredentials" interface.
type perRPCCredential struct {
authToken string
authTokenMu sync.RWMutex
}
func newPerRPCCredential() *perRPCCredential { return &perRPCCredential{} }
func (rc *perRPCCredential) RequireTransportSecurity() bool { return false }
func (rc *perRPCCredential) GetRequestMetadata(ctx context.Context, s ...string) (map[string]string, error) {
rc.authTokenMu.RLock()
authToken := rc.authToken
rc.authTokenMu.RUnlock()
return map[string]string{rpctypes.TokenFieldNameGRPC: authToken}, nil
}
func (b *bundle) UpdateAuthToken(token string) {
if b.rc == nil {
return
}
b.rc.UpdateAuthToken(token)
}
func (rc *perRPCCredential) UpdateAuthToken(token string) {
rc.authTokenMu.Lock()
rc.authToken = token
rc.authTokenMu.Unlock()
}
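To show how the bundle is meant to plug into gRPC, here is a hedged sketch; the explicit dial wiring and the import path below are assumptions for illustration, since the client normally wires the bundle internally.

package main

import (
	"crypto/tls"
	"log"

	"github.com/coreos/etcd/clientv3/credentials"
	"google.golang.org/grpc"
)

func main() {
	// both the transport and per-RPC halves come from the same bundle
	b := credentials.NewBundle(credentials.Config{TLSConfig: &tls.Config{}})

	opts := []grpc.DialOption{
		grpc.WithTransportCredentials(b.TransportCredentials()),
		grpc.WithPerRPCCredentials(b.PerRPCCredentials()),
	}

	conn, err := grpc.Dial("localhost:2379", opts...)
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	// after authenticating, attach the token to every subsequent RPC
	b.UpdateAuthToken("example-token")
}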


@ -1,609 +0,0 @@
// Copyright 2017 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package clientv3
import (
"context"
"errors"
"net/url"
"strings"
"sync"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
healthpb "google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/status"
)
const (
minHealthRetryDuration = 3 * time.Second
unknownService = "unknown service grpc.health.v1.Health"
)
// ErrNoAddrAvilable is returned by Get() when the balancer does not have
// any active connection to endpoints at the time.
// This error is returned only when opts.BlockingWait is true.
var ErrNoAddrAvilable = status.Error(codes.Unavailable, "there is no address available")
type healthCheckFunc func(ep string) (bool, error)
type notifyMsg int
const (
notifyReset notifyMsg = iota
notifyNext
)
// healthBalancer does the bare minimum to expose multiple eps
// to the grpc reconnection code path
type healthBalancer struct {
// addrs are the client's endpoint addresses for grpc
addrs []grpc.Address
// eps holds the raw endpoints from the client
eps []string
// notifyCh notifies grpc of the set of addresses for connecting
notifyCh chan []grpc.Address
// readyc closes once the first connection is up
readyc chan struct{}
readyOnce sync.Once
// healthCheck checks an endpoint's health.
healthCheck healthCheckFunc
healthCheckTimeout time.Duration
unhealthyMu sync.RWMutex
unhealthyHostPorts map[string]time.Time
// mu protects all fields below.
mu sync.RWMutex
// upc closes when pinAddr transitions from empty to non-empty or the balancer closes.
upc chan struct{}
// downc closes when grpc calls down() on pinAddr
downc chan struct{}
// stopc is closed to signal updateNotifyLoop should stop.
stopc chan struct{}
stopOnce sync.Once
wg sync.WaitGroup
// donec closes when all goroutines are exited
donec chan struct{}
// updateAddrsC notifies updateNotifyLoop to update addrs.
updateAddrsC chan notifyMsg
// grpc issues TLS cert checks using the string passed into dial so
// that string must be the host. To recover the full scheme://host URL,
// have a map from hosts to the original endpoint.
hostPort2ep map[string]string
// pinAddr is the currently pinned address; set to the empty string on
// initialization and shutdown.
pinAddr string
closed bool
}
func newHealthBalancer(eps []string, timeout time.Duration, hc healthCheckFunc) *healthBalancer {
notifyCh := make(chan []grpc.Address)
addrs := eps2addrs(eps)
hb := &healthBalancer{
addrs: addrs,
eps: eps,
notifyCh: notifyCh,
readyc: make(chan struct{}),
healthCheck: hc,
unhealthyHostPorts: make(map[string]time.Time),
upc: make(chan struct{}),
stopc: make(chan struct{}),
downc: make(chan struct{}),
donec: make(chan struct{}),
updateAddrsC: make(chan notifyMsg),
hostPort2ep: getHostPort2ep(eps),
}
if timeout < minHealthRetryDuration {
timeout = minHealthRetryDuration
}
hb.healthCheckTimeout = timeout
close(hb.downc)
go hb.updateNotifyLoop()
hb.wg.Add(1)
go func() {
defer hb.wg.Done()
hb.updateUnhealthy()
}()
return hb
}
func (b *healthBalancer) Start(target string, config grpc.BalancerConfig) error { return nil }
func (b *healthBalancer) ConnectNotify() <-chan struct{} {
b.mu.Lock()
defer b.mu.Unlock()
return b.upc
}
func (b *healthBalancer) ready() <-chan struct{} { return b.readyc }
func (b *healthBalancer) endpoint(hostPort string) string {
b.mu.RLock()
defer b.mu.RUnlock()
return b.hostPort2ep[hostPort]
}
func (b *healthBalancer) pinned() string {
b.mu.RLock()
defer b.mu.RUnlock()
return b.pinAddr
}
func (b *healthBalancer) hostPortError(hostPort string, err error) {
if b.endpoint(hostPort) == "" {
logger.Lvl(4).Infof("clientv3/balancer: %q is stale (skip marking as unhealthy on %q)", hostPort, err.Error())
return
}
b.unhealthyMu.Lock()
b.unhealthyHostPorts[hostPort] = time.Now()
b.unhealthyMu.Unlock()
logger.Lvl(4).Infof("clientv3/balancer: %q is marked unhealthy (%q)", hostPort, err.Error())
}
func (b *healthBalancer) removeUnhealthy(hostPort, msg string) {
if b.endpoint(hostPort) == "" {
logger.Lvl(4).Infof("clientv3/balancer: %q was not in unhealthy (%q)", hostPort, msg)
return
}
b.unhealthyMu.Lock()
delete(b.unhealthyHostPorts, hostPort)
b.unhealthyMu.Unlock()
logger.Lvl(4).Infof("clientv3/balancer: %q is removed from unhealthy (%q)", hostPort, msg)
}
func (b *healthBalancer) countUnhealthy() (count int) {
b.unhealthyMu.RLock()
count = len(b.unhealthyHostPorts)
b.unhealthyMu.RUnlock()
return count
}
func (b *healthBalancer) isUnhealthy(hostPort string) (unhealthy bool) {
b.unhealthyMu.RLock()
_, unhealthy = b.unhealthyHostPorts[hostPort]
b.unhealthyMu.RUnlock()
return unhealthy
}
func (b *healthBalancer) cleanupUnhealthy() {
b.unhealthyMu.Lock()
for k, v := range b.unhealthyHostPorts {
if time.Since(v) > b.healthCheckTimeout {
delete(b.unhealthyHostPorts, k)
logger.Lvl(4).Infof("clientv3/balancer: removed %q from unhealthy after %v", k, b.healthCheckTimeout)
}
}
b.unhealthyMu.Unlock()
}
func (b *healthBalancer) liveAddrs() ([]grpc.Address, map[string]struct{}) {
unhealthyCnt := b.countUnhealthy()
b.mu.RLock()
defer b.mu.RUnlock()
hbAddrs := b.addrs
if len(b.addrs) == 1 || unhealthyCnt == 0 || unhealthyCnt == len(b.addrs) {
liveHostPorts := make(map[string]struct{}, len(b.hostPort2ep))
for k := range b.hostPort2ep {
liveHostPorts[k] = struct{}{}
}
return hbAddrs, liveHostPorts
}
addrs := make([]grpc.Address, 0, len(b.addrs)-unhealthyCnt)
liveHostPorts := make(map[string]struct{}, len(addrs))
for _, addr := range b.addrs {
if !b.isUnhealthy(addr.Addr) {
addrs = append(addrs, addr)
liveHostPorts[addr.Addr] = struct{}{}
}
}
return addrs, liveHostPorts
}
func (b *healthBalancer) updateUnhealthy() {
for {
select {
case <-time.After(b.healthCheckTimeout):
b.cleanupUnhealthy()
pinned := b.pinned()
if pinned == "" || b.isUnhealthy(pinned) {
select {
case b.updateAddrsC <- notifyNext:
case <-b.stopc:
return
}
}
case <-b.stopc:
return
}
}
}
func (b *healthBalancer) updateAddrs(eps ...string) {
np := getHostPort2ep(eps)
b.mu.Lock()
defer b.mu.Unlock()
match := len(np) == len(b.hostPort2ep)
if match {
for k, v := range np {
if b.hostPort2ep[k] != v {
match = false
break
}
}
}
if match {
// same endpoints, so no need to update address
return
}
b.hostPort2ep = np
b.addrs, b.eps = eps2addrs(eps), eps
b.unhealthyMu.Lock()
b.unhealthyHostPorts = make(map[string]time.Time)
b.unhealthyMu.Unlock()
}
func (b *healthBalancer) next() {
b.mu.RLock()
downc := b.downc
b.mu.RUnlock()
select {
case b.updateAddrsC <- notifyNext:
case <-b.stopc:
}
// wait until disconnect so new RPCs are not issued on old connection
select {
case <-downc:
case <-b.stopc:
}
}
func (b *healthBalancer) updateNotifyLoop() {
defer close(b.donec)
for {
b.mu.RLock()
upc, downc, addr := b.upc, b.downc, b.pinAddr
b.mu.RUnlock()
// downc or upc should be closed
select {
case <-downc:
downc = nil
default:
}
select {
case <-upc:
upc = nil
default:
}
switch {
case downc == nil && upc == nil:
// stale
select {
case <-b.stopc:
return
default:
}
case downc == nil:
b.notifyAddrs(notifyReset)
select {
case <-upc:
case msg := <-b.updateAddrsC:
b.notifyAddrs(msg)
case <-b.stopc:
return
}
case upc == nil:
select {
// close connections that are not the pinned address
case b.notifyCh <- []grpc.Address{{Addr: addr}}:
case <-downc:
case <-b.stopc:
return
}
select {
case <-downc:
b.notifyAddrs(notifyReset)
case msg := <-b.updateAddrsC:
b.notifyAddrs(msg)
case <-b.stopc:
return
}
}
}
}
func (b *healthBalancer) notifyAddrs(msg notifyMsg) {
if msg == notifyNext {
select {
case b.notifyCh <- []grpc.Address{}:
case <-b.stopc:
return
}
}
b.mu.RLock()
pinAddr := b.pinAddr
downc := b.downc
b.mu.RUnlock()
addrs, hostPorts := b.liveAddrs()
var waitDown bool
if pinAddr != "" {
_, ok := hostPorts[pinAddr]
waitDown = !ok
}
select {
case b.notifyCh <- addrs:
if waitDown {
select {
case <-downc:
case <-b.stopc:
}
}
case <-b.stopc:
}
}
func (b *healthBalancer) Up(addr grpc.Address) func(error) {
if !b.mayPin(addr) {
return func(err error) {}
}
b.mu.Lock()
defer b.mu.Unlock()
// gRPC might call Up after it called Close. We add this check
// to "fix" it up at application layer. Otherwise, will panic
// if b.upc is already closed.
if b.closed {
return func(err error) {}
}
// gRPC might call Up on a stale address.
// Prevent updating pinAddr with a stale address.
if !hasAddr(b.addrs, addr.Addr) {
return func(err error) {}
}
if b.pinAddr != "" {
logger.Lvl(4).Infof("clientv3/balancer: %q is up but not pinned (already pinned %q)", addr.Addr, b.pinAddr)
return func(err error) {}
}
// notify waiting Get()s and pin first connected address
close(b.upc)
b.downc = make(chan struct{})
b.pinAddr = addr.Addr
logger.Lvl(4).Infof("clientv3/balancer: pin %q", addr.Addr)
// notify client that a connection is up
b.readyOnce.Do(func() { close(b.readyc) })
return func(err error) {
// If connected to a black hole endpoint or a killed server, the gRPC ping
// timeout will induce a network I/O error, and retrying until success;
// finding healthy endpoint on retry could take several timeouts and redials.
// To avoid wasting retries, gray-list unhealthy endpoints.
b.hostPortError(addr.Addr, err)
b.mu.Lock()
b.upc = make(chan struct{})
close(b.downc)
b.pinAddr = ""
b.mu.Unlock()
logger.Lvl(4).Infof("clientv3/balancer: unpin %q (%q)", addr.Addr, err.Error())
}
}
func (b *healthBalancer) mayPin(addr grpc.Address) bool {
if b.endpoint(addr.Addr) == "" { // stale host:port
return false
}
b.unhealthyMu.RLock()
unhealthyCnt := len(b.unhealthyHostPorts)
failedTime, bad := b.unhealthyHostPorts[addr.Addr]
b.unhealthyMu.RUnlock()
b.mu.RLock()
skip := len(b.addrs) == 1 || unhealthyCnt == 0 || len(b.addrs) == unhealthyCnt
b.mu.RUnlock()
if skip || !bad {
return true
}
// prevent isolated member's endpoint from being infinitely retried, as follows:
// 1. keepalive pings detects GoAway with http2.ErrCodeEnhanceYourCalm
// 2. balancer 'Up' unpins with grpc: failed with network I/O error
// 3. grpc-healthcheck still SERVING, thus retry to pin
// instead, return before grpc-healthcheck if failed within healthcheck timeout
if elapsed := time.Since(failedTime); elapsed < b.healthCheckTimeout {
logger.Lvl(4).Infof("clientv3/balancer: %q is up but not pinned (failed %v ago, require minimum %v after failure)", addr.Addr, elapsed, b.healthCheckTimeout)
return false
}
if ok, _ := b.healthCheck(addr.Addr); ok {
b.removeUnhealthy(addr.Addr, "health check success")
return true
}
b.hostPortError(addr.Addr, errors.New("health check failed"))
return false
}
func (b *healthBalancer) Get(ctx context.Context, opts grpc.BalancerGetOptions) (grpc.Address, func(), error) {
var (
addr string
closed bool
)
// If opts.BlockingWait is false (for fail-fast RPCs), it should return
// an address it has notified via Notify immediately instead of blocking.
if !opts.BlockingWait {
b.mu.RLock()
closed = b.closed
addr = b.pinAddr
b.mu.RUnlock()
if closed {
return grpc.Address{Addr: ""}, nil, grpc.ErrClientConnClosing
}
if addr == "" {
return grpc.Address{Addr: ""}, nil, ErrNoAddrAvilable
}
return grpc.Address{Addr: addr}, func() {}, nil
}
for {
b.mu.RLock()
ch := b.upc
b.mu.RUnlock()
select {
case <-ch:
case <-b.donec:
return grpc.Address{Addr: ""}, nil, grpc.ErrClientConnClosing
case <-ctx.Done():
return grpc.Address{Addr: ""}, nil, ctx.Err()
}
b.mu.RLock()
closed = b.closed
addr = b.pinAddr
b.mu.RUnlock()
// Close() which sets b.closed = true can be called before Get(), Get() must exit if balancer is closed.
if closed {
return grpc.Address{Addr: ""}, nil, grpc.ErrClientConnClosing
}
if addr != "" {
break
}
}
return grpc.Address{Addr: addr}, func() {}, nil
}
func (b *healthBalancer) Notify() <-chan []grpc.Address { return b.notifyCh }
func (b *healthBalancer) Close() error {
b.mu.Lock()
// In case gRPC calls close twice. TODO: remove the checking
// when we are sure that gRPC wont call close twice.
if b.closed {
b.mu.Unlock()
<-b.donec
return nil
}
b.closed = true
b.stopOnce.Do(func() { close(b.stopc) })
b.pinAddr = ""
// In the case of following scenario:
// 1. upc is not closed; no pinned address
// 2. client issues an RPC, calling invoke(), which calls Get(), enters for loop, blocks
// 3. client.conn.Close() calls balancer.Close(); closed = true
// 4. for loop in Get() never exits since ctx is the context passed in by the client and may not be canceled
// we must close upc so Get() exits from blocking on upc
select {
case <-b.upc:
default:
// terminate all waiting Get()s
close(b.upc)
}
b.mu.Unlock()
b.wg.Wait()
// wait for updateNotifyLoop to finish
<-b.donec
close(b.notifyCh)
return nil
}
func grpcHealthCheck(client *Client, ep string) (bool, error) {
conn, err := client.dial(ep)
if err != nil {
return false, err
}
defer conn.Close()
cli := healthpb.NewHealthClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
resp, err := cli.Check(ctx, &healthpb.HealthCheckRequest{})
cancel()
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.Unavailable {
if s.Message() == unknownService { // etcd < v3.3.0
return true, nil
}
}
return false, err
}
return resp.Status == healthpb.HealthCheckResponse_SERVING, nil
}
func hasAddr(addrs []grpc.Address, targetAddr string) bool {
for _, addr := range addrs {
if targetAddr == addr.Addr {
return true
}
}
return false
}
func getHost(ep string) string {
url, uerr := url.Parse(ep)
if uerr != nil || !strings.Contains(ep, "://") {
return ep
}
return url.Host
}
func eps2addrs(eps []string) []grpc.Address {
addrs := make([]grpc.Address, len(eps))
for i := range eps {
addrs[i].Addr = getHost(eps[i])
}
return addrs
}
func getHostPort2ep(eps []string) map[string]string {
hm := make(map[string]string, len(eps))
for i := range eps {
_, host, _ := parseEndpoint(eps[i])
hm[host] = eps[i]
}
return hm
}


@ -1,298 +0,0 @@
// Copyright 2017 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package clientv3
import (
"context"
"errors"
"net"
"sync"
"testing"
"time"
pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
"github.com/coreos/etcd/pkg/testutil"
"google.golang.org/grpc"
)
var endpoints = []string{"localhost:2379", "localhost:22379", "localhost:32379"}
func TestBalancerGetUnblocking(t *testing.T) {
hb := newHealthBalancer(endpoints, minHealthRetryDuration, func(string) (bool, error) { return true, nil })
defer hb.Close()
if addrs := <-hb.Notify(); len(addrs) != len(endpoints) {
t.Errorf("Initialize newHealthBalancer should have triggered Notify() chan, but it didn't")
}
unblockingOpts := grpc.BalancerGetOptions{BlockingWait: false}
_, _, err := hb.Get(context.Background(), unblockingOpts)
if err != ErrNoAddrAvilable {
t.Errorf("Get() with no up endpoints should return ErrNoAddrAvailable, got: %v", err)
}
down1 := hb.Up(grpc.Address{Addr: endpoints[1]})
if addrs := <-hb.Notify(); len(addrs) != 1 {
t.Errorf("first Up() should have triggered balancer to send the first connected address via Notify chan so that other connections can be closed")
}
down2 := hb.Up(grpc.Address{Addr: endpoints[2]})
addrFirst, putFun, err := hb.Get(context.Background(), unblockingOpts)
if err != nil {
t.Errorf("Get() with up endpoints should success, got %v", err)
}
if addrFirst.Addr != endpoints[1] {
t.Errorf("Get() didn't return expected address, got %v", addrFirst)
}
if putFun == nil {
t.Errorf("Get() returned unexpected nil put function")
}
addrSecond, _, _ := hb.Get(context.Background(), unblockingOpts)
if addrFirst.Addr != addrSecond.Addr {
t.Errorf("Get() didn't return the same address as previous call, got %v and %v", addrFirst, addrSecond)
}
down1(errors.New("error"))
if addrs := <-hb.Notify(); len(addrs) != len(endpoints)-1 { // we call down on one endpoint
t.Errorf("closing the only connection should triggered balancer to send the %d endpoints via Notify chan so that we can establish a connection", len(endpoints)-1)
}
down2(errors.New("error"))
_, _, err = hb.Get(context.Background(), unblockingOpts)
if err != ErrNoAddrAvilable {
t.Errorf("Get() with no up endpoints should return ErrNoAddrAvailable, got: %v", err)
}
}
func TestBalancerGetBlocking(t *testing.T) {
hb := newHealthBalancer(endpoints, minHealthRetryDuration, func(string) (bool, error) { return true, nil })
defer hb.Close()
if addrs := <-hb.Notify(); len(addrs) != len(endpoints) {
t.Errorf("Initialize newHealthBalancer should have triggered Notify() chan, but it didn't")
}
blockingOpts := grpc.BalancerGetOptions{BlockingWait: true}
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
_, _, err := hb.Get(ctx, blockingOpts)
cancel()
if err != context.DeadlineExceeded {
t.Errorf("Get() with no up endpoints should timeout, got %v", err)
}
downC := make(chan func(error), 1)
go func() {
// ensure hb.Up() will be called after hb.Get() to see if Up() releases blocking Get()
time.Sleep(time.Millisecond * 100)
f := hb.Up(grpc.Address{Addr: endpoints[1]})
if addrs := <-hb.Notify(); len(addrs) != 1 {
t.Errorf("first Up() should have triggered balancer to send the first connected address via Notify chan so that other connections can be closed")
}
downC <- f
}()
addrFirst, putFun, err := hb.Get(context.Background(), blockingOpts)
if err != nil {
t.Errorf("Get() with up endpoints should success, got %v", err)
}
if addrFirst.Addr != endpoints[1] {
t.Errorf("Get() didn't return expected address, got %v", addrFirst)
}
if putFun == nil {
t.Errorf("Get() returned unexpected nil put function")
}
down1 := <-downC
down2 := hb.Up(grpc.Address{Addr: endpoints[2]})
addrSecond, _, _ := hb.Get(context.Background(), blockingOpts)
if addrFirst.Addr != addrSecond.Addr {
t.Errorf("Get() didn't return the same address as previous call, got %v and %v", addrFirst, addrSecond)
}
down1(errors.New("error"))
if addrs := <-hb.Notify(); len(addrs) != len(endpoints)-1 { // we call down on one endpoint
t.Errorf("closing the only connection should triggered balancer to send the %d endpoints via Notify chan so that we can establish a connection", len(endpoints)-1)
}
down2(errors.New("error"))
ctx, cancel = context.WithTimeout(context.Background(), time.Millisecond*100)
_, _, err = hb.Get(ctx, blockingOpts)
cancel()
if err != context.DeadlineExceeded {
t.Errorf("Get() with no up endpoints should timeout, got %v", err)
}
}
// TestHealthBalancerGraylist checks one endpoint is tried after the other
// due to gray listing.
func TestHealthBalancerGraylist(t *testing.T) {
var wg sync.WaitGroup
// Use 3 endpoints so gray list doesn't fallback to all connections
// after failing on 2 endpoints.
lns, eps := make([]net.Listener, 3), make([]string, 3)
wg.Add(3)
connc := make(chan string, 2)
for i := range eps {
ln, err := net.Listen("tcp", ":0")
testutil.AssertNil(t, err)
lns[i], eps[i] = ln, ln.Addr().String()
go func() {
defer wg.Done()
for {
conn, err := ln.Accept()
if err != nil {
return
}
_, err = conn.Read(make([]byte, 512))
conn.Close()
if err == nil {
select {
case connc <- ln.Addr().String():
// sleep some so balancer catches up
// before attempted next reconnect.
time.Sleep(50 * time.Millisecond)
default:
}
}
}
}()
}
tf := func(s string) (bool, error) { return false, nil }
hb := newHealthBalancer(eps, 5*time.Second, tf)
conn, err := grpc.Dial("", grpc.WithInsecure(), grpc.WithBalancer(hb))
testutil.AssertNil(t, err)
defer conn.Close()
kvc := pb.NewKVClient(conn)
<-hb.ready()
kvc.Range(context.TODO(), &pb.RangeRequest{})
ep1 := <-connc
kvc.Range(context.TODO(), &pb.RangeRequest{})
ep2 := <-connc
for _, ln := range lns {
ln.Close()
}
wg.Wait()
if ep1 == ep2 {
t.Fatalf("expected %q != %q", ep1, ep2)
}
}
// TestBalancerDoNotBlockOnClose ensures that balancer and grpc don't deadlock each other
// due to rapid open/close conn. The deadlock causes balancer.Close() to block forever.
// See issue: https://github.com/coreos/etcd/issues/7283 for more detail.
func TestBalancerDoNotBlockOnClose(t *testing.T) {
defer testutil.AfterTest(t)
kcl := newKillConnListener(t, 3)
defer kcl.close()
for i := 0; i < 5; i++ {
hb := newHealthBalancer(kcl.endpoints(), minHealthRetryDuration, func(string) (bool, error) { return true, nil })
conn, err := grpc.Dial("", grpc.WithInsecure(), grpc.WithBalancer(hb))
if err != nil {
t.Fatal(err)
}
kvc := pb.NewKVClient(conn)
<-hb.readyc
var wg sync.WaitGroup
wg.Add(100)
cctx, cancel := context.WithCancel(context.TODO())
for j := 0; j < 100; j++ {
go func() {
defer wg.Done()
kvc.Range(cctx, &pb.RangeRequest{}, grpc.FailFast(false))
}()
}
// balancer.Close() might block
// if balancer and grpc deadlock each other.
bclosec, cclosec := make(chan struct{}), make(chan struct{})
go func() {
defer close(bclosec)
hb.Close()
}()
go func() {
defer close(cclosec)
conn.Close()
}()
select {
case <-bclosec:
case <-time.After(3 * time.Second):
testutil.FatalStack(t, "balancer close timeout")
}
select {
case <-cclosec:
case <-time.After(3 * time.Second):
t.Fatal("grpc conn close timeout")
}
cancel()
wg.Wait()
}
}
// killConnListener listens incoming conn and kills it immediately.
type killConnListener struct {
wg sync.WaitGroup
eps []string
stopc chan struct{}
t *testing.T
}
func newKillConnListener(t *testing.T, size int) *killConnListener {
kcl := &killConnListener{stopc: make(chan struct{}), t: t}
for i := 0; i < size; i++ {
ln, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
kcl.eps = append(kcl.eps, ln.Addr().String())
kcl.wg.Add(1)
go kcl.listen(ln)
}
return kcl
}
func (kcl *killConnListener) endpoints() []string {
return kcl.eps
}
func (kcl *killConnListener) listen(l net.Listener) {
go func() {
defer kcl.wg.Done()
for {
conn, err := l.Accept()
select {
case <-kcl.stopc:
return
default:
}
if err != nil {
kcl.t.Fatal(err)
}
time.Sleep(1 * time.Millisecond)
conn.Close()
}
}()
<-kcl.stopc
l.Close()
}
func (kcl *killConnListener) close() {
close(kcl.stopc)
kcl.wg.Wait()
}


@ -25,6 +25,7 @@ import (
"github.com/coreos/etcd/etcdserver/api/v3rpc/rpctypes"
"github.com/coreos/etcd/integration"
"github.com/coreos/etcd/pkg/testutil"
"google.golang.org/grpc"
)
// TestBalancerUnderBlackholeKeepAliveWatch tests when watch discovers it cannot talk to
@ -35,7 +36,7 @@ func TestBalancerUnderBlackholeKeepAliveWatch(t *testing.T) {
clus := integration.NewClusterV3(t, &integration.ClusterConfig{
Size: 2,
GRPCKeepAliveMinTime: 1 * time.Millisecond, // avoid too_many_pings
GRPCKeepAliveMinTime: time.Millisecond, // avoid too_many_pings
})
defer clus.Terminate(t)
@ -43,8 +44,9 @@ func TestBalancerUnderBlackholeKeepAliveWatch(t *testing.T) {
ccfg := clientv3.Config{
Endpoints: []string{eps[0]},
DialTimeout: 1 * time.Second,
DialKeepAliveTime: 1 * time.Second,
DialTimeout: time.Second,
DialOptions: []grpc.DialOption{grpc.WithBlock()},
DialKeepAliveTime: time.Second,
DialKeepAliveTimeout: 500 * time.Millisecond,
}
@ -70,6 +72,9 @@ func TestBalancerUnderBlackholeKeepAliveWatch(t *testing.T) {
// endpoint can switch to eps[1] when it detects the failure of eps[0]
cli.SetEndpoints(eps...)
// give enough time for balancer resolution
time.Sleep(5 * time.Second)
clus.Members[0].Blackhole()
if _, err = clus.Client(1).Put(context.TODO(), "foo", "bar"); err != nil {
@ -106,7 +111,7 @@ func TestBalancerUnderBlackholeKeepAliveWatch(t *testing.T) {
func TestBalancerUnderBlackholeNoKeepAlivePut(t *testing.T) {
testBalancerUnderBlackholeNoKeepAlive(t, func(cli *clientv3.Client, ctx context.Context) error {
_, err := cli.Put(ctx, "foo", "bar")
if err == context.DeadlineExceeded || isServerCtxTimeout(err) || err == rpctypes.ErrTimeout {
if isClientTimeout(err) || isServerCtxTimeout(err) || err == rpctypes.ErrTimeout {
return errExpected
}
return err
@ -116,7 +121,7 @@ func TestBalancerUnderBlackholeNoKeepAlivePut(t *testing.T) {
func TestBalancerUnderBlackholeNoKeepAliveDelete(t *testing.T) {
testBalancerUnderBlackholeNoKeepAlive(t, func(cli *clientv3.Client, ctx context.Context) error {
_, err := cli.Delete(ctx, "foo")
if err == context.DeadlineExceeded || isServerCtxTimeout(err) || err == rpctypes.ErrTimeout {
if isClientTimeout(err) || isServerCtxTimeout(err) || err == rpctypes.ErrTimeout {
return errExpected
}
return err
@ -129,7 +134,7 @@ func TestBalancerUnderBlackholeNoKeepAliveTxn(t *testing.T) {
If(clientv3.Compare(clientv3.Version("foo"), "=", 0)).
Then(clientv3.OpPut("foo", "bar")).
Else(clientv3.OpPut("foo", "baz")).Commit()
if err == context.DeadlineExceeded || isServerCtxTimeout(err) || err == rpctypes.ErrTimeout {
if isClientTimeout(err) || isServerCtxTimeout(err) || err == rpctypes.ErrTimeout {
return errExpected
}
return err
@ -139,7 +144,7 @@ func TestBalancerUnderBlackholeNoKeepAliveTxn(t *testing.T) {
func TestBalancerUnderBlackholeNoKeepAliveLinearizableGet(t *testing.T) {
testBalancerUnderBlackholeNoKeepAlive(t, func(cli *clientv3.Client, ctx context.Context) error {
_, err := cli.Get(ctx, "a")
if err == context.DeadlineExceeded || isServerCtxTimeout(err) || err == rpctypes.ErrTimeout {
if isClientTimeout(err) || isServerCtxTimeout(err) || err == rpctypes.ErrTimeout {
return errExpected
}
return err
@ -149,7 +154,7 @@ func TestBalancerUnderBlackholeNoKeepAliveLinearizableGet(t *testing.T) {
func TestBalancerUnderBlackholeNoKeepAliveSerializableGet(t *testing.T) {
testBalancerUnderBlackholeNoKeepAlive(t, func(cli *clientv3.Client, ctx context.Context) error {
_, err := cli.Get(ctx, "a", clientv3.WithSerializable())
if err == context.DeadlineExceeded || isServerCtxTimeout(err) {
if isClientTimeout(err) || isServerCtxTimeout(err) {
return errExpected
}
return err
@ -172,6 +177,7 @@ func testBalancerUnderBlackholeNoKeepAlive(t *testing.T, op func(*clientv3.Clien
ccfg := clientv3.Config{
Endpoints: []string{eps[0]},
DialTimeout: 1 * time.Second,
DialOptions: []grpc.DialOption{grpc.WithBlock()},
}
cli, err := clientv3.New(ccfg)
if err != nil {
@ -189,22 +195,23 @@ func testBalancerUnderBlackholeNoKeepAlive(t *testing.T, op func(*clientv3.Clien
// blackhole eps[0]
clus.Members[0].Blackhole()
// fail first due to blackhole, retry should succeed
// With the round robin balancer, the client reaches a healthy endpoint
// within a few attempts.
// TODO: first operation can succeed
// when gRPC supports better retry on non-delivered request
for i := 0; i < 2; i++ {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
for i := 0; i < 5; i++ {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
err = op(cli, ctx)
cancel()
if err == nil {
break
}
if i == 0 {
if err != errExpected {
t.Errorf("#%d: expected %v, got %v", i, errExpected, err)
}
} else if err != nil {
} else if err == errExpected {
t.Logf("#%d: current error %v", i, err)
} else {
t.Errorf("#%d: failed with error %v", i, err)
}
}
if err != nil {
t.Fatal(err)
}
}


@ -26,6 +26,7 @@ import (
"github.com/coreos/etcd/integration"
"github.com/coreos/etcd/pkg/testutil"
"github.com/coreos/etcd/pkg/transport"
"google.golang.org/grpc"
)
var (
@ -58,10 +59,11 @@ func TestDialTLSExpired(t *testing.T) {
_, err = clientv3.New(clientv3.Config{
Endpoints: []string{clus.Members[0].GRPCAddr()},
DialTimeout: 3 * time.Second,
DialOptions: []grpc.DialOption{grpc.WithBlock()},
TLS: tls,
})
if err != context.DeadlineExceeded {
t.Fatalf("expected %v, got %v", context.DeadlineExceeded, err)
if !isClientTimeout(err) {
t.Fatalf("expected dial timeout error, got %v", err)
}
}
@ -72,12 +74,18 @@ func TestDialTLSNoConfig(t *testing.T) {
clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 1, ClientTLS: &testTLSInfo, SkipCreatingClient: true})
defer clus.Terminate(t)
// expect "signed by unknown authority"
_, err := clientv3.New(clientv3.Config{
c, err := clientv3.New(clientv3.Config{
Endpoints: []string{clus.Members[0].GRPCAddr()},
DialTimeout: time.Second,
DialOptions: []grpc.DialOption{grpc.WithBlock()},
})
if err != context.DeadlineExceeded {
t.Fatalf("expected %v, got %v", context.DeadlineExceeded, err)
defer func() {
if c != nil {
c.Close()
}
}()
if !isClientTimeout(err) {
t.Fatalf("expected dial timeout error, got %v", err)
}
}
@ -104,7 +112,11 @@ func testDialSetEndpoints(t *testing.T, setBefore bool) {
}
toKill := rand.Intn(len(eps))
cfg := clientv3.Config{Endpoints: []string{eps[toKill]}, DialTimeout: 1 * time.Second}
cfg := clientv3.Config{
Endpoints: []string{eps[toKill]},
DialTimeout: 1 * time.Second,
DialOptions: []grpc.DialOption{grpc.WithBlock()},
}
cli, err := clientv3.New(cfg)
if err != nil {
t.Fatal(err)
@ -121,6 +133,7 @@ func testDialSetEndpoints(t *testing.T, setBefore bool) {
if !setBefore {
cli.SetEndpoints(eps[toKill%3], eps[(toKill+1)%3])
}
time.Sleep(time.Second * 2)
ctx, cancel := context.WithTimeout(context.Background(), integration.RequestWaitTimeout)
if _, err = cli.Get(ctx, "foo", clientv3.WithSerializable()); err != nil {
t.Fatal(err)
@ -158,6 +171,7 @@ func TestRejectOldCluster(t *testing.T) {
cfg := clientv3.Config{
Endpoints: []string{clus.Members[0].GRPCAddr(), clus.Members[1].GRPCAddr()},
DialTimeout: 5 * time.Second,
DialOptions: []grpc.DialOption{grpc.WithBlock()},
RejectOldCluster: true,
}
cli, err := clientv3.New(cfg)


@ -884,7 +884,7 @@ func TestKVLargeRequests(t *testing.T) {
},
// without proper client-side receive size limit
// "code = ResourceExhausted desc = grpc: received message larger than max (5242929 vs. 4194304)"
// "code = ResourceExhausted desc = received message larger than max (5242929 vs. 4194304)"
{
maxRequestBytesServer: 7*1024*1024 + 512*1024,
@ -906,7 +906,7 @@ func TestKVLargeRequests(t *testing.T) {
maxCallSendBytesClient: 10 * 1024 * 1024,
maxCallRecvBytesClient: 0,
valueSize: 10 * 1024 * 1024,
expectError: grpc.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max "),
expectError: grpc.Errorf(codes.ResourceExhausted, "trying to send message larger than max "),
},
{
maxRequestBytesServer: 10 * 1024 * 1024,
@ -920,7 +920,7 @@ func TestKVLargeRequests(t *testing.T) {
maxCallSendBytesClient: 10 * 1024 * 1024,
maxCallRecvBytesClient: 0,
valueSize: 10*1024*1024 + 5,
expectError: grpc.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max "),
expectError: grpc.Errorf(codes.ResourceExhausted, "trying to send message larger than max "),
},
}
for i, test := range tests {
@ -940,7 +940,7 @@ func TestKVLargeRequests(t *testing.T) {
t.Errorf("#%d: expected %v, got %v", i, test.expectError, err)
}
} else if err != nil && !strings.HasPrefix(err.Error(), test.expectError.Error()) {
t.Errorf("#%d: expected %v, got %v", i, test.expectError, err)
t.Errorf("#%d: expected error starting with '%s', got '%s'", i, test.expectError.Error(), err.Error())
}
// put request went through, now expects large response back


@ -19,11 +19,9 @@ import (
"github.com/coreos/etcd/clientv3"
"github.com/coreos/pkg/capnslog"
"google.golang.org/grpc/grpclog"
)
func init() {
capnslog.SetGlobalLogLevel(capnslog.CRITICAL)
clientv3.SetLogger(grpclog.NewLoggerV2(ioutil.Discard, ioutil.Discard, ioutil.Discard))
}


@ -24,8 +24,10 @@ import (
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/etcdserver/api/v3rpc/rpctypes"
pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
"github.com/coreos/etcd/integration"
"github.com/coreos/etcd/pkg/testutil"
"google.golang.org/grpc"
)
var errExpected = errors.New("expected error")
@ -36,7 +38,7 @@ var errExpected = errors.New("expected error")
func TestBalancerUnderNetworkPartitionPut(t *testing.T) {
testBalancerUnderNetworkPartition(t, func(cli *clientv3.Client, ctx context.Context) error {
_, err := cli.Put(ctx, "a", "b")
if err == context.DeadlineExceeded || isServerCtxTimeout(err) || err == rpctypes.ErrTimeout {
if isClientTimeout(err) || isServerCtxTimeout(err) || err == rpctypes.ErrTimeout {
return errExpected
}
return err
@ -46,7 +48,7 @@ func TestBalancerUnderNetworkPartitionPut(t *testing.T) {
func TestBalancerUnderNetworkPartitionDelete(t *testing.T) {
testBalancerUnderNetworkPartition(t, func(cli *clientv3.Client, ctx context.Context) error {
_, err := cli.Delete(ctx, "a")
if err == context.DeadlineExceeded || isServerCtxTimeout(err) || err == rpctypes.ErrTimeout {
if isClientTimeout(err) || isServerCtxTimeout(err) || err == rpctypes.ErrTimeout {
return errExpected
}
return err
@ -59,7 +61,7 @@ func TestBalancerUnderNetworkPartitionTxn(t *testing.T) {
If(clientv3.Compare(clientv3.Version("foo"), "=", 0)).
Then(clientv3.OpPut("foo", "bar")).
Else(clientv3.OpPut("foo", "baz")).Commit()
if err == context.DeadlineExceeded || isServerCtxTimeout(err) || err == rpctypes.ErrTimeout {
if isClientTimeout(err) || isServerCtxTimeout(err) || err == rpctypes.ErrTimeout {
return errExpected
}
return err
@ -72,6 +74,9 @@ func TestBalancerUnderNetworkPartitionTxn(t *testing.T) {
func TestBalancerUnderNetworkPartitionLinearizableGetWithLongTimeout(t *testing.T) {
testBalancerUnderNetworkPartition(t, func(cli *clientv3.Client, ctx context.Context) error {
_, err := cli.Get(ctx, "a")
if err == rpctypes.ErrTimeout {
return errExpected
}
return err
}, 7*time.Second)
}
@ -82,7 +87,7 @@ func TestBalancerUnderNetworkPartitionLinearizableGetWithLongTimeout(t *testing.
func TestBalancerUnderNetworkPartitionLinearizableGetWithShortTimeout(t *testing.T) {
testBalancerUnderNetworkPartition(t, func(cli *clientv3.Client, ctx context.Context) error {
_, err := cli.Get(ctx, "a")
if err == context.DeadlineExceeded || isServerCtxTimeout(err) {
if isClientTimeout(err) || isServerCtxTimeout(err) {
return errExpected
}
return err
@ -111,6 +116,7 @@ func testBalancerUnderNetworkPartition(t *testing.T, op func(*clientv3.Client, c
ccfg := clientv3.Config{
Endpoints: []string{eps[0]},
DialTimeout: 3 * time.Second,
DialOptions: []grpc.DialOption{grpc.WithBlock()},
}
cli, err := clientv3.New(ccfg)
if err != nil {
@ -123,9 +129,10 @@ func testBalancerUnderNetworkPartition(t *testing.T, op func(*clientv3.Client, c
// add other endpoints for later endpoint switch
cli.SetEndpoints(eps...)
time.Sleep(time.Second * 2)
clus.Members[0].InjectPartition(t, clus.Members[1:]...)
for i := 0; i < 2; i++ {
for i := 0; i < 5; i++ {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
err = op(cli, ctx)
cancel()
@ -133,7 +140,7 @@ func testBalancerUnderNetworkPartition(t *testing.T, op func(*clientv3.Client, c
break
}
if err != errExpected {
t.Errorf("#%d: expected %v, got %v", i, errExpected, err)
t.Errorf("#%d: expected '%v', got '%v'", i, errExpected, err)
}
// give enough time for endpoint switch
// TODO: remove random sleep by syncing directly with balancer
@ -165,16 +172,14 @@ func TestBalancerUnderNetworkPartitionLinearizableGetLeaderElection(t *testing.T
cli, err := clientv3.New(clientv3.Config{
Endpoints: []string{eps[(lead+1)%2]},
DialTimeout: 1 * time.Second,
DialTimeout: 2 * time.Second,
DialOptions: []grpc.DialOption{grpc.WithBlock()},
})
if err != nil {
t.Fatal(err)
}
defer cli.Close()
// wait for non-leader to be pinned
mustWaitPinReady(t, cli)
// add all eps to the list, so that when the originally pinned one fails
// the client can switch to the other available eps
cli.SetEndpoints(eps[lead], eps[(lead+1)%2])
@ -182,10 +187,15 @@ func TestBalancerUnderNetworkPartitionLinearizableGetLeaderElection(t *testing.T
// isolate leader
clus.Members[lead].InjectPartition(t, clus.Members[(lead+1)%3], clus.Members[(lead+2)%3])
// expects balancer endpoint switch while ongoing leader election
ctx, cancel := context.WithTimeout(context.TODO(), timeout)
_, err = cli.Get(ctx, "a")
cancel()
// expects balancer to round robin to leader within two attempts
for i := 0; i < 2; i++ {
ctx, cancel := context.WithTimeout(context.TODO(), timeout)
_, err = cli.Get(ctx, "a")
cancel()
if err == nil {
break
}
}
if err != nil {
t.Fatal(err)
}
@ -256,3 +266,63 @@ func testBalancerUnderNetworkPartitionWatch(t *testing.T, isolateLeader bool) {
t.Fatal("took too long to detect leader lost")
}
}
func TestDropReadUnderNetworkPartition(t *testing.T) {
defer testutil.AfterTest(t)
clus := integration.NewClusterV3(t, &integration.ClusterConfig{
Size: 3,
SkipCreatingClient: true,
})
defer clus.Terminate(t)
leaderIndex := clus.WaitLeader(t)
// get a follower endpoint
eps := []string{clus.Members[(leaderIndex+1)%3].GRPCAddr()}
ccfg := clientv3.Config{
Endpoints: eps,
DialTimeout: 10 * time.Second,
DialOptions: []grpc.DialOption{grpc.WithBlock()},
}
cli, err := clientv3.New(ccfg)
if err != nil {
t.Fatal(err)
}
defer cli.Close()
// wait for eps[0] to be pinned
mustWaitPinReady(t, cli)
// add other endpoints for later endpoint switch
cli.SetEndpoints(eps...)
time.Sleep(time.Second * 2)
conn, err := cli.Dial(clus.Members[(leaderIndex+1)%3].GRPCAddr())
if err != nil {
t.Fatal(err)
}
defer conn.Close()
clus.Members[leaderIndex].InjectPartition(t, clus.Members[(leaderIndex+1)%3], clus.Members[(leaderIndex+2)%3])
kvc := clientv3.NewKVFromKVClient(pb.NewKVClient(conn), nil)
ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
_, err = kvc.Get(ctx, "a")
cancel()
if err.Error() != rpctypes.ErrLeaderChanged.Error() {
t.Fatalf("expected %v, got %v", rpctypes.ErrLeaderChanged, err)
}
for i := 0; i < 5; i++ {
ctx, cancel = context.WithTimeout(context.TODO(), 10*time.Second)
_, err = kvc.Get(ctx, "a")
cancel()
if err != nil {
if err == rpctypes.ErrTimeout {
<-time.After(time.Second)
i++
continue
}
t.Fatalf("expected nil or timeout, got %v", err)
}
// No error returned and no retry required
break
}
}


@ -75,16 +75,16 @@ func TestBalancerUnderServerShutdownWatch(t *testing.T) {
select {
case ev := <-wch:
if werr := ev.Err(); werr != nil {
t.Fatal(werr)
t.Error(werr)
}
if len(ev.Events) != 1 {
t.Fatalf("expected one event, got %+v", ev)
t.Errorf("expected one event, got %+v", ev)
}
if !bytes.Equal(ev.Events[0].Kv.Value, []byte(val)) {
t.Fatalf("expected %q, got %+v", val, ev.Events[0].Kv)
t.Errorf("expected %q, got %+v", val, ev.Events[0].Kv)
}
case <-time.After(7 * time.Second):
t.Fatal("took too long to receive events")
t.Error("took too long to receive events")
}
}()
@ -104,7 +104,7 @@ func TestBalancerUnderServerShutdownWatch(t *testing.T) {
if err == nil {
break
}
if err == context.DeadlineExceeded || isServerCtxTimeout(err) || err == rpctypes.ErrTimeout || err == rpctypes.ErrTimeoutDueToLeaderFail {
if isClientTimeout(err) || isServerCtxTimeout(err) || err == rpctypes.ErrTimeout || err == rpctypes.ErrTimeoutDueToLeaderFail {
continue
}
t.Fatal(err)
@ -337,10 +337,20 @@ func testBalancerUnderServerStopInflightRangeOnRestart(t *testing.T, linearizabl
defer close(donec)
ctx, cancel := context.WithTimeout(context.TODO(), clientTimeout)
readyc <- struct{}{}
_, err := cli.Get(ctx, "abc", gops...)
// TODO: The new grpc load balancer will not pin to an endpoint
// as intended by this test, but it will round robin to a member within
// two attempts.
// Remove the retry loop once the new grpc load balancer provides retry.
for i := 0; i < 2; i++ {
_, err = cli.Get(ctx, "abc", gops...)
if err == nil {
break
}
}
cancel()
if err != nil {
t.Fatal(err)
t.Errorf("unexpected error: %v", err)
}
}()
@ -361,7 +371,57 @@ func isServerCtxTimeout(err error) bool {
if err == nil {
return false
}
ev, _ := status.FromError(err)
ev, ok := status.FromError(err)
if !ok {
return false
}
code := ev.Code()
return code == codes.DeadlineExceeded && strings.Contains(err.Error(), "context deadline exceeded")
}
// In grpc v1.11.3+ dial timeouts can error out with transport.ErrConnClosing. Previously dial timeouts
// would always error out with context.DeadlineExceeded.
func isClientTimeout(err error) bool {
if err == nil {
return false
}
if err == context.DeadlineExceeded {
return true
}
ev, ok := status.FromError(err)
if !ok {
return false
}
code := ev.Code()
return code == codes.DeadlineExceeded
}
func isCanceled(err error) bool {
if err == nil {
return false
}
if err == context.Canceled {
return true
}
ev, ok := status.FromError(err)
if !ok {
return false
}
code := ev.Code()
return code == codes.Canceled
}
func isUnavailable(err error) bool {
if err == nil {
return false
}
if err == context.Canceled {
return true
}
ev, ok := status.FromError(err)
if !ok {
return false
}
code := ev.Code()
return code == codes.Unavailable
}


@ -25,7 +25,7 @@ import (
// mustWaitPinReady waits up to 10 seconds until a connection is up (an endpoint is pinned).
// Fatal on time-out.
func mustWaitPinReady(t *testing.T, cli *clientv3.Client) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
_, err := cli.Get(ctx, "foo")
cancel()
if err != nil {


@ -0,0 +1,125 @@
// Copyright 2018 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build !cluster_proxy
package integration
import (
"context"
"fmt"
"strings"
"testing"
"time"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/integration"
"github.com/coreos/etcd/pkg/testutil"
)
// TestWatchFragmentDisable ensures that large watch
// response exceeding server-side request limit can
// arrive even without watch response fragmentation.
func TestWatchFragmentDisable(t *testing.T) {
testWatchFragment(t, false, false)
}
// TestWatchFragmentDisableWithGRPCLimit verifies
// large watch response exceeding server-side request
// limit and client-side gRPC response receive limit
// cannot arrive without watch events fragmentation,
// because multiple events exceed client-side gRPC
// response receive limit.
func TestWatchFragmentDisableWithGRPCLimit(t *testing.T) {
testWatchFragment(t, false, true)
}
// TestWatchFragmentEnable ensures that large watch
// response exceeding server-side request limit arrive
// with watch response fragmentation.
func TestWatchFragmentEnable(t *testing.T) {
testWatchFragment(t, true, false)
}
// TestWatchFragmentEnableWithGRPCLimit verifies
// large watch response exceeding server-side request
// limit and client-side gRPC response receive limit
// can arrive only when watch events are fragmented.
func TestWatchFragmentEnableWithGRPCLimit(t *testing.T) {
testWatchFragment(t, true, true)
}
// testWatchFragment triggers a watch response that spans multiple
// revisions and, when combined, exceeds server request limits.
func testWatchFragment(t *testing.T, fragment, exceedRecvLimit bool) {
cfg := &integration.ClusterConfig{
Size: 1,
MaxRequestBytes: 1.5 * 1024 * 1024,
}
if exceedRecvLimit {
cfg.ClientMaxCallRecvMsgSize = 1.5 * 1024 * 1024
}
clus := integration.NewClusterV3(t, cfg)
defer clus.Terminate(t)
cli := clus.Client(0)
errc := make(chan error)
for i := 0; i < 10; i++ {
go func(i int) {
_, err := cli.Put(context.TODO(),
fmt.Sprint("foo", i),
strings.Repeat("a", 1024*1024),
)
errc <- err
}(i)
}
for i := 0; i < 10; i++ {
if err := <-errc; err != nil {
t.Fatalf("failed to put: %v", err)
}
}
opts := []clientv3.OpOption{clientv3.WithPrefix(), clientv3.WithRev(1)}
if fragment {
opts = append(opts, clientv3.WithFragment())
}
wch := cli.Watch(context.TODO(), "foo", opts...)
// expect 10 MiB watch response
select {
case ws := <-wch:
// without fragment, should exceed gRPC client receive limit
if !fragment && exceedRecvLimit {
if len(ws.Events) != 0 {
t.Fatalf("expected 0 events with watch fragmentation, got %d", len(ws.Events))
}
exp := "code = ResourceExhausted desc = grpc: received message larger than max ("
if !strings.Contains(ws.Err().Error(), exp) {
t.Fatalf("expected 'ResourceExhausted' error, got %v", ws.Err())
}
return
}
// still expect merged watch events
if len(ws.Events) != 10 {
t.Fatalf("expected 10 events with watch fragmentation, got %d", len(ws.Events))
}
if ws.Err() != nil {
t.Fatalf("unexpected error %v", ws.Err())
}
case <-time.After(testutil.RequestTimeout):
t.Fatalf("took too long to receive events")
}
}


@ -30,7 +30,6 @@ import (
mvccpb "github.com/coreos/etcd/mvcc/mvccpb"
"github.com/coreos/etcd/pkg/testutil"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
@ -583,6 +582,78 @@ func testWatchWithProgressNotify(t *testing.T, watchOnPut bool) {
}
}
func TestWatchRequestProgress(t *testing.T) {
testCases := []struct {
name string
watchers []string
}{
{"0-watcher", []string{}},
{"1-watcher", []string{"/"}},
{"2-watcher", []string{"/", "/"}},
}
for _, c := range testCases {
t.Run(c.name, func(t *testing.T) {
defer testutil.AfterTest(t)
watchTimeout := 3 * time.Second
clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 3})
defer clus.Terminate(t)
wc := clus.RandClient()
var watchChans []clientv3.WatchChan
for _, prefix := range c.watchers {
watchChans = append(watchChans, wc.Watch(context.Background(), prefix, clientv3.WithPrefix()))
}
_, err := wc.Put(context.Background(), "/a", "1")
if err != nil {
t.Fatal(err)
}
for _, rch := range watchChans {
select {
case resp := <-rch: // wait for notification
if len(resp.Events) != 1 {
t.Fatalf("resp.Events expected 1, got %d", len(resp.Events))
}
case <-time.After(watchTimeout):
t.Fatalf("watch response expected in %v, but timed out", watchTimeout)
}
}
// put a value not being watched to increment revision
_, err = wc.Put(context.Background(), "x", "1")
if err != nil {
t.Fatal(err)
}
err = wc.RequestProgress(context.Background())
if err != nil {
t.Fatal(err)
}
// verify all watch channels receive a progress notify
for _, rch := range watchChans {
select {
case resp := <-rch:
if !resp.IsProgressNotify() {
t.Fatalf("expected resp.IsProgressNotify() == true")
}
if resp.Header.Revision != 3 {
t.Fatalf("resp.Header.Revision expected 3, got %d", resp.Header.Revision)
}
case <-time.After(watchTimeout):
t.Fatalf("progress response expected in %v, but timed out", watchTimeout)
}
}
})
}
}
func TestWatchEventType(t *testing.T) {
cluster := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 1})
defer cluster.Terminate(t)
@ -667,8 +738,9 @@ func TestWatchErrConnClosed(t *testing.T) {
go func() {
defer close(donec)
ch := cli.Watch(context.TODO(), "foo")
if wr := <-ch; grpc.ErrorDesc(wr.Err()) != grpc.ErrClientConnClosing.Error() {
t.Fatalf("expected %v, got %v", grpc.ErrClientConnClosing, grpc.ErrorDesc(wr.Err()))
if wr := <-ch; !isCanceled(wr.Err()) {
t.Errorf("expected context canceled, got %v", wr.Err())
}
}()
@ -699,8 +771,8 @@ func TestWatchAfterClose(t *testing.T) {
donec := make(chan struct{})
go func() {
cli.Watch(context.TODO(), "foo")
if err := cli.Close(); err != nil && err != grpc.ErrClientConnClosing {
t.Fatalf("expected %v, got %v", grpc.ErrClientConnClosing, err)
if err := cli.Close(); err != nil && err != context.Canceled {
t.Errorf("expected %v, got %v", context.Canceled, err)
}
close(donec)
}()
@ -1061,3 +1133,24 @@ func TestWatchCancelDisconnected(t *testing.T) {
t.Fatal("took too long to cancel disconnected watcher")
}
}
// TestWatchClose ensures that Close does not return an error
func TestWatchClose(t *testing.T) {
runWatchTest(t, testWatchClose)
}
func testWatchClose(t *testing.T, wctx *watchctx) {
ctx, cancel := context.WithCancel(context.Background())
wch := wctx.w.Watch(ctx, "a")
cancel()
if wch == nil {
t.Fatalf("expected watcher channel, got nil")
}
if wctx.w.Close() != nil {
t.Fatalf("watch did not close successfully")
}
wresp, ok := <-wch
if ok {
t.Fatalf("read wch got %v; expected closed channel", wresp)
}
}

View File

@ -85,7 +85,7 @@ func (txn *txnLeasing) eval() (*v3.TxnResponse, error) {
if !ok {
return nil, nil
}
return &v3.TxnResponse{copyHeader(txn.lkv.leases.header), succeeded, resps}, nil
return &v3.TxnResponse{Header: copyHeader(txn.lkv.leases.header), Succeeded: succeeded, Responses: resps}, nil
}
// fallback computes the ops to fetch all possible conflicting

View File

@ -18,28 +18,14 @@ import (
"io/ioutil"
"sync"
"github.com/coreos/etcd/pkg/logutil"
"google.golang.org/grpc/grpclog"
)
// Logger is the logger used by client library.
// It implements grpclog.LoggerV2 interface.
type Logger interface {
grpclog.LoggerV2
// Lvl returns logger if logger's verbosity level >= "lvl".
// Otherwise, logger that discards all logs.
Lvl(lvl int) Logger
// to satisfy capnslog
Print(args ...interface{})
Printf(format string, args ...interface{})
Println(args ...interface{})
}
var (
loggerMu sync.RWMutex
logger Logger
lgMu sync.RWMutex
lg logutil.Logger
)
type settableLogger struct {
@ -49,29 +35,29 @@ type settableLogger struct {
func init() {
// disable client side logs by default
logger = &settableLogger{}
lg = &settableLogger{}
SetLogger(grpclog.NewLoggerV2(ioutil.Discard, ioutil.Discard, ioutil.Discard))
}
// SetLogger sets client-side Logger.
func SetLogger(l grpclog.LoggerV2) {
loggerMu.Lock()
logger = NewLogger(l)
lgMu.Lock()
lg = logutil.NewLogger(l)
// override grpclog so that any changes happen with locking
grpclog.SetLoggerV2(logger)
loggerMu.Unlock()
grpclog.SetLoggerV2(lg)
lgMu.Unlock()
}
// GetLogger returns the current logger.
func GetLogger() Logger {
loggerMu.RLock()
l := logger
loggerMu.RUnlock()
// GetLogger returns the current logutil.Logger.
func GetLogger() logutil.Logger {
lgMu.RLock()
l := lg
lgMu.RUnlock()
return l
}
// NewLogger returns a new Logger with grpclog.LoggerV2.
func NewLogger(gl grpclog.LoggerV2) Logger {
// NewLogger returns a new Logger with logutil.Logger.
func NewLogger(gl grpclog.LoggerV2) logutil.Logger {
return &settableLogger{l: gl}
}
@ -104,32 +90,12 @@ func (s *settableLogger) Print(args ...interface{}) { s.get().In
func (s *settableLogger) Printf(format string, args ...interface{}) { s.get().Infof(format, args...) }
func (s *settableLogger) Println(args ...interface{}) { s.get().Infoln(args...) }
func (s *settableLogger) V(l int) bool { return s.get().V(l) }
func (s *settableLogger) Lvl(lvl int) Logger {
func (s *settableLogger) Lvl(lvl int) grpclog.LoggerV2 {
s.mu.RLock()
l := s.l
s.mu.RUnlock()
if l.V(lvl) {
return s
}
return &noLogger{}
return logutil.NewDiscardLogger()
}
type noLogger struct{}
func (*noLogger) Info(args ...interface{}) {}
func (*noLogger) Infof(format string, args ...interface{}) {}
func (*noLogger) Infoln(args ...interface{}) {}
func (*noLogger) Warning(args ...interface{}) {}
func (*noLogger) Warningf(format string, args ...interface{}) {}
func (*noLogger) Warningln(args ...interface{}) {}
func (*noLogger) Error(args ...interface{}) {}
func (*noLogger) Errorf(format string, args ...interface{}) {}
func (*noLogger) Errorln(args ...interface{}) {}
func (*noLogger) Fatal(args ...interface{}) {}
func (*noLogger) Fatalf(format string, args ...interface{}) {}
func (*noLogger) Fatalln(args ...interface{}) {}
func (*noLogger) Print(args ...interface{}) {}
func (*noLogger) Printf(format string, args ...interface{}) {}
func (*noLogger) Println(args ...interface{}) {}
func (*noLogger) V(l int) bool { return false }
func (ng *noLogger) Lvl(lvl int) Logger { return ng }
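For illustration, a minimal sketch of routing client-side logs through the hook above; the writer and verbosity level are arbitrary choices, not part of this change:

```go
import (
	"os"

	"github.com/coreos/etcd/clientv3"
	"google.golang.org/grpc/grpclog"
)

// enableClientLogs routes clientv3 logs to stderr at verbosity level 4.
func enableClientLogs() {
	clientv3.SetLogger(grpclog.NewLoggerV2WithVerbosity(os.Stderr, os.Stderr, os.Stderr, 4))
}
```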

View File

@ -1,51 +0,0 @@
// Copyright 2017 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package clientv3
import (
"bytes"
"io/ioutil"
"strings"
"testing"
"google.golang.org/grpc/grpclog"
)
func TestLogger(t *testing.T) {
buf := new(bytes.Buffer)
l := NewLogger(grpclog.NewLoggerV2WithVerbosity(buf, buf, buf, 10))
l.Infof("hello world!")
if !strings.Contains(buf.String(), "hello world!") {
t.Fatalf("expected 'hello world!', got %q", buf.String())
}
buf.Reset()
l.Lvl(10).Infof("Level 10")
l.Lvl(30).Infof("Level 30")
if !strings.Contains(buf.String(), "Level 10") {
t.Fatalf("expected 'Level 10', got %q", buf.String())
}
if strings.Contains(buf.String(), "Level 30") {
t.Fatalf("unexpected 'Level 30', got %q", buf.String())
}
buf.Reset()
l = NewLogger(grpclog.NewLoggerV2(ioutil.Discard, ioutil.Discard, ioutil.Discard))
l.Infof("ignore this")
if len(buf.Bytes()) > 0 {
t.Fatalf("unexpected logs %q", buf.String())
}
}

View File

@ -16,6 +16,7 @@ package clientv3
import (
"context"
"fmt"
"io"
pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
@ -57,6 +58,8 @@ type Maintenance interface {
HashKV(ctx context.Context, endpoint string, rev int64) (*HashKVResponse, error)
// Snapshot provides a reader for a point-in-time snapshot of etcd.
// If the context "ctx" is canceled or timed out, reading from returned
// "io.ReadCloser" would error out (e.g. context.Canceled, context.DeadlineExceeded).
Snapshot(ctx context.Context) (io.ReadCloser, error)
// MoveLeader requests current leader to transfer its leadership to the transferee.
@ -73,9 +76,9 @@ type maintenance struct {
func NewMaintenance(c *Client) Maintenance {
api := &maintenance{
dial: func(endpoint string) (pb.MaintenanceClient, func(), error) {
conn, err := c.dial(endpoint)
conn, err := c.Dial(endpoint)
if err != nil {
return nil, nil, err
return nil, nil, fmt.Errorf("failed to dial endpoint %s with maintenance client: %v", endpoint, err)
}
cancel := func() { conn.Close() }
return RetryMaintenanceClient(c, conn), cancel, nil
@ -173,6 +176,7 @@ func (m *maintenance) Status(ctx context.Context, endpoint string) (*StatusRespo
func (m *maintenance) HashKV(ctx context.Context, endpoint string, rev int64) (*HashKVResponse, error) {
remote, cancel, err := m.dial(endpoint)
if err != nil {
return nil, toErr(ctx, err)
}
defer cancel()
@ -184,7 +188,7 @@ func (m *maintenance) HashKV(ctx context.Context, endpoint string, rev int64) (*
}
func (m *maintenance) Snapshot(ctx context.Context) (io.ReadCloser, error) {
ss, err := m.remote.Snapshot(ctx, &pb.SnapshotRequest{}, m.callOpts...)
ss, err := m.remote.Snapshot(ctx, &pb.SnapshotRequest{}, append(m.callOpts, withMax(defaultStreamMaxRetries))...)
if err != nil {
return nil, toErr(ctx, err)
}

View File

@ -53,6 +53,12 @@ type Op struct {
// for watch, put, delete
prevKV bool
// for watch
// fragmentation should be disabled by default
// if true, split watch events when total exceeds
// "--max-request-bytes" flag value + 512-byte
fragment bool
// for put
ignoreValue bool
ignoreLease bool
@ -511,3 +517,14 @@ func toLeaseTimeToLiveRequest(id LeaseID, opts ...LeaseOption) *pb.LeaseTimeToLi
ret.applyOpts(opts)
return &pb.LeaseTimeToLiveRequest{ID: int64(id), Keys: ret.attachedKeys}
}
// WithFragment to receive raw watch response with fragmentation.
// Fragmentation is disabled by default. If fragmentation is enabled,
// etcd watch server will split watch response before sending to clients
// when the total size of watch events exceed server-side request limit.
// The default server-side request limit is 1.5 MiB, which can be configured
// as "--max-request-bytes" flag value + gRPC-overhead 512 bytes.
// See "etcdserver/api/v3rpc/watch.go" for more details.
func WithFragment() OpOption {
return func(op *Op) { op.fragment = true }
}
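A hedged usage sketch of the option above; it assumes a connected `*clientv3.Client` named `cli` created elsewhere, with `context`, `fmt`, and `clientv3` imported:

```go
// Watch a prefix with fragmentation enabled; responses larger than the
// server-side request limit are split across multiple watch responses.
wch := cli.Watch(context.Background(), "foo",
	clientv3.WithPrefix(),
	clientv3.WithFragment(),
)
for wresp := range wch {
	if err := wresp.Err(); err != nil {
		// e.g. compaction or cancellation; handle and stop
		break
	}
	for _, ev := range wresp.Events {
		fmt.Printf("%s %q : %q\n", ev.Type, ev.Kv.Key, ev.Kv.Value)
	}
}
```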

View File

@ -16,17 +16,17 @@ package clientv3
import (
"math"
"time"
"google.golang.org/grpc"
)
var (
// Disable gRPC internal retrial logic
// TODO: enable when gRPC retry is stable (FailFast=false)
// Reference:
// - https://github.com/grpc/grpc-go/issues/1532
// - https://github.com/grpc/proposal/blob/master/A6-client-retries.md
defaultFailFast = grpc.FailFast(true)
// client-side retry handling of request failures where data was not written to the wire or
// where the server indicates it did not process the data. The gRPC default is "FailFast(true)",
// but for etcd we default to "FailFast(false)" to minimize client request error responses due to
// transient failures.
defaultFailFast = grpc.FailFast(false)
// client-side request send limit, gRPC default is math.MaxInt32
// Make sure that "client-side send limit < server-side default send/recv limit"
@ -38,6 +38,22 @@ var (
// because range response can easily exceed request send limits
// Default to math.MaxInt32; writes exceeding server-side send limit fails anyway
defaultMaxCallRecvMsgSize = grpc.MaxCallRecvMsgSize(math.MaxInt32)
// client-side non-streaming retry limit, only applied to requests where the server responds with
// an error code clearly indicating it was unable to process the request, such as codes.Unavailable.
// If set to 0, retry is disabled.
defaultUnaryMaxRetries uint = 100
// client-side streaming retry limit, only applied to requests where the server responds with
// an error code clearly indicating it was unable to process the request, such as codes.Unavailable.
// If set to 0, retry is disabled.
defaultStreamMaxRetries = ^uint(0) // max uint
// client-side retry backoff wait between requests.
defaultBackoffWaitBetween = 25 * time.Millisecond
// client-side retry backoff default jitter fraction.
defaultBackoffJitterFraction = 0.10
)
// defaultCallOpts defines a list of default "gRPC.CallOption".
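A hedged sketch of overriding these limits per client; `MaxCallSendMsgSize` and `MaxCallRecvMsgSize` are assumed to be `clientv3.Config` fields defined outside this hunk:

```go
// Raise the client-side message size limits for workloads that expect
// large range or snapshot responses (values are illustrative).
cfg := clientv3.Config{
	Endpoints:          []string{"localhost:2379"},
	MaxCallSendMsgSize: 2 * 1024 * 1024,  // replaces defaultMaxCallSendMsgSize
	MaxCallRecvMsgSize: 16 * 1024 * 1024, // replaces defaultMaxCallRecvMsgSize
}
```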

View File

@ -32,465 +32,263 @@ const (
nonRepeatable
)
type rpcFunc func(ctx context.Context) error
type retryRPCFunc func(context.Context, rpcFunc, retryPolicy) error
type retryStopErrFunc func(error) bool
func (rp retryPolicy) String() string {
switch rp {
case repeatable:
return "repeatable"
case nonRepeatable:
return "nonRepeatable"
default:
return "UNKNOWN"
}
}
// isSafeRetryImmutableRPC returns "true" when an immutable request is safe for retry.
//
// immutable requests (e.g. Get) should be retried unless it's
// an obvious server-side error (e.g. rpctypes.ErrRequestTooLarge).
//
// "isRepeatableStopError" returns "true" when an immutable request
// is interrupted by server-side or gRPC-side error and its status
// code is not transient (!= codes.Unavailable).
//
// Returning "true" means retry should stop, since client cannot
// Returning "false" means retry should stop, since client cannot
// handle itself even with retries.
func isRepeatableStopError(err error) bool {
func isSafeRetryImmutableRPC(err error) bool {
eErr := rpctypes.Error(err)
// always stop retry on etcd errors
if serverErr, ok := eErr.(rpctypes.EtcdError); ok && serverErr.Code() != codes.Unavailable {
return true
// interrupted by non-transient server-side or gRPC-side error
// client cannot handle itself (e.g. rpctypes.ErrCompacted)
return false
}
// only retry if unavailable
ev, _ := status.FromError(err)
return ev.Code() != codes.Unavailable
ev, ok := status.FromError(err)
if !ok {
// all errors from RPCs are typed "grpc/status.(*statusError)"
// (ref. https://github.com/grpc/grpc-go/pull/1782)
//
// if the error type is not "grpc/status.(*statusError)",
// it could be from "Dial"
// TODO: do not retry for now
// ref. https://github.com/grpc/grpc-go/issues/1581
return false
}
return ev.Code() == codes.Unavailable
}
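For illustration, a hedged in-package sketch of the distinction this function draws; the error values and helper name are contrived, not part of this change:

```go
func retryDecisionExamples() {
	// A transient gRPC error (codes.Unavailable) is safe to retry for read-only RPCs.
	fmt.Println(isSafeRetryImmutableRPC(status.Error(codes.Unavailable, "connection closed"))) // true
	// A non-transient etcd server error (e.g. compaction) is not; the caller must handle it.
	fmt.Println(isSafeRetryImmutableRPC(rpctypes.ErrGRPCCompacted)) // false
}
```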
// isSafeRetryMutableRPC returns "true" when a mutable request is safe for retry.
//
// mutable requests (e.g. Put, Delete, Txn) should only be retried
// when the status code is codes.Unavailable when initial connection
// has not been established (no pinned endpoint).
// has not been established (no endpoint is up).
//
// "isNonRepeatableStopError" returns "true" when a mutable request
// is interrupted by non-transient error that client cannot handle itself,
// or transient error while the connection has already been established
// (pinned endpoint exists).
//
// Returning "true" means retry should stop, otherwise it violates
// Returning "false" means retry should stop, otherwise it violates
// write-at-most-once semantics.
func isNonRepeatableStopError(err error) bool {
ev, _ := status.FromError(err)
if ev.Code() != codes.Unavailable {
return true
func isSafeRetryMutableRPC(err error) bool {
if ev, ok := status.FromError(err); ok && ev.Code() != codes.Unavailable {
// not safe for mutable RPCs
// e.g. interrupted by non-transient error that client cannot handle itself,
// or transient error while the connection has already been established
return false
}
desc := rpctypes.ErrorDesc(err)
return desc != "there is no address available" && desc != "there is no connection available"
}
func (c *Client) newRetryWrapper() retryRPCFunc {
return func(rpcCtx context.Context, f rpcFunc, rp retryPolicy) error {
var isStop retryStopErrFunc
switch rp {
case repeatable:
isStop = isRepeatableStopError
case nonRepeatable:
isStop = isNonRepeatableStopError
}
for {
if err := readyWait(rpcCtx, c.ctx, c.balancer.ConnectNotify()); err != nil {
return err
}
pinned := c.balancer.pinned()
err := f(rpcCtx)
if err == nil {
return nil
}
logger.Lvl(4).Infof("clientv3/retry: error %q on pinned endpoint %q", err.Error(), pinned)
if s, ok := status.FromError(err); ok && (s.Code() == codes.Unavailable || s.Code() == codes.DeadlineExceeded || s.Code() == codes.Internal) {
// mark this before endpoint switch is triggered
c.balancer.hostPortError(pinned, err)
c.balancer.next()
logger.Lvl(4).Infof("clientv3/retry: switching from %q due to error %q", pinned, err.Error())
}
if isStop(err) {
return err
}
}
}
}
func (c *Client) newAuthRetryWrapper(retryf retryRPCFunc) retryRPCFunc {
return func(rpcCtx context.Context, f rpcFunc, rp retryPolicy) error {
for {
pinned := c.balancer.pinned()
err := retryf(rpcCtx, f, rp)
if err == nil {
return nil
}
logger.Lvl(4).Infof("clientv3/auth-retry: error %q on pinned endpoint %q", err.Error(), pinned)
// always stop retry on etcd errors other than invalid auth token
if rpctypes.Error(err) == rpctypes.ErrInvalidAuthToken {
gterr := c.getToken(rpcCtx)
if gterr != nil {
logger.Lvl(4).Infof("clientv3/auth-retry: cannot retry due to error %q(%q) on pinned endpoint %q", err.Error(), gterr.Error(), pinned)
return err // return the original error for simplicity
}
continue
}
return err
}
}
return desc == "there is no address available" || desc == "there is no connection available"
}
type retryKVClient struct {
kc pb.KVClient
retryf retryRPCFunc
kc pb.KVClient
}
// RetryKVClient implements a KVClient.
func RetryKVClient(c *Client) pb.KVClient {
return &retryKVClient{
kc: pb.NewKVClient(c.conn),
retryf: c.newAuthRetryWrapper(c.newRetryWrapper()),
kc: pb.NewKVClient(c.conn),
}
}
func (rkv *retryKVClient) Range(ctx context.Context, in *pb.RangeRequest, opts ...grpc.CallOption) (resp *pb.RangeResponse, err error) {
err = rkv.retryf(ctx, func(rctx context.Context) error {
resp, err = rkv.kc.Range(rctx, in, opts...)
return err
}, repeatable)
return resp, err
return rkv.kc.Range(ctx, in, append(opts, withRetryPolicy(repeatable))...)
}
func (rkv *retryKVClient) Put(ctx context.Context, in *pb.PutRequest, opts ...grpc.CallOption) (resp *pb.PutResponse, err error) {
err = rkv.retryf(ctx, func(rctx context.Context) error {
resp, err = rkv.kc.Put(rctx, in, opts...)
return err
}, nonRepeatable)
return resp, err
return rkv.kc.Put(ctx, in, opts...)
}
func (rkv *retryKVClient) DeleteRange(ctx context.Context, in *pb.DeleteRangeRequest, opts ...grpc.CallOption) (resp *pb.DeleteRangeResponse, err error) {
err = rkv.retryf(ctx, func(rctx context.Context) error {
resp, err = rkv.kc.DeleteRange(rctx, in, opts...)
return err
}, nonRepeatable)
return resp, err
return rkv.kc.DeleteRange(ctx, in, opts...)
}
func (rkv *retryKVClient) Txn(ctx context.Context, in *pb.TxnRequest, opts ...grpc.CallOption) (resp *pb.TxnResponse, err error) {
// TODO: "repeatable" for read-only txn
err = rkv.retryf(ctx, func(rctx context.Context) error {
resp, err = rkv.kc.Txn(rctx, in, opts...)
return err
}, nonRepeatable)
return resp, err
return rkv.kc.Txn(ctx, in, opts...)
}
func (rkv *retryKVClient) Compact(ctx context.Context, in *pb.CompactionRequest, opts ...grpc.CallOption) (resp *pb.CompactionResponse, err error) {
err = rkv.retryf(ctx, func(rctx context.Context) error {
resp, err = rkv.kc.Compact(rctx, in, opts...)
return err
}, nonRepeatable)
return resp, err
return rkv.kc.Compact(ctx, in, opts...)
}
type retryLeaseClient struct {
lc pb.LeaseClient
retryf retryRPCFunc
lc pb.LeaseClient
}
// RetryLeaseClient implements a LeaseClient.
func RetryLeaseClient(c *Client) pb.LeaseClient {
return &retryLeaseClient{
lc: pb.NewLeaseClient(c.conn),
retryf: c.newAuthRetryWrapper(c.newRetryWrapper()),
lc: pb.NewLeaseClient(c.conn),
}
}
func (rlc *retryLeaseClient) LeaseTimeToLive(ctx context.Context, in *pb.LeaseTimeToLiveRequest, opts ...grpc.CallOption) (resp *pb.LeaseTimeToLiveResponse, err error) {
err = rlc.retryf(ctx, func(rctx context.Context) error {
resp, err = rlc.lc.LeaseTimeToLive(rctx, in, opts...)
return err
}, repeatable)
return resp, err
return rlc.lc.LeaseTimeToLive(ctx, in, append(opts, withRetryPolicy(repeatable))...)
}
func (rlc *retryLeaseClient) LeaseLeases(ctx context.Context, in *pb.LeaseLeasesRequest, opts ...grpc.CallOption) (resp *pb.LeaseLeasesResponse, err error) {
err = rlc.retryf(ctx, func(rctx context.Context) error {
resp, err = rlc.lc.LeaseLeases(rctx, in, opts...)
return err
}, repeatable)
return resp, err
return rlc.lc.LeaseLeases(ctx, in, append(opts, withRetryPolicy(repeatable))...)
}
func (rlc *retryLeaseClient) LeaseGrant(ctx context.Context, in *pb.LeaseGrantRequest, opts ...grpc.CallOption) (resp *pb.LeaseGrantResponse, err error) {
err = rlc.retryf(ctx, func(rctx context.Context) error {
resp, err = rlc.lc.LeaseGrant(rctx, in, opts...)
return err
}, repeatable)
return resp, err
return rlc.lc.LeaseGrant(ctx, in, append(opts, withRetryPolicy(repeatable))...)
}
func (rlc *retryLeaseClient) LeaseRevoke(ctx context.Context, in *pb.LeaseRevokeRequest, opts ...grpc.CallOption) (resp *pb.LeaseRevokeResponse, err error) {
err = rlc.retryf(ctx, func(rctx context.Context) error {
resp, err = rlc.lc.LeaseRevoke(rctx, in, opts...)
return err
}, repeatable)
return resp, err
return rlc.lc.LeaseRevoke(ctx, in, append(opts, withRetryPolicy(repeatable))...)
}
func (rlc *retryLeaseClient) LeaseKeepAlive(ctx context.Context, opts ...grpc.CallOption) (stream pb.Lease_LeaseKeepAliveClient, err error) {
err = rlc.retryf(ctx, func(rctx context.Context) error {
stream, err = rlc.lc.LeaseKeepAlive(rctx, opts...)
return err
}, repeatable)
return stream, err
return rlc.lc.LeaseKeepAlive(ctx, append(opts, withRetryPolicy(repeatable))...)
}
type retryClusterClient struct {
cc pb.ClusterClient
retryf retryRPCFunc
cc pb.ClusterClient
}
// RetryClusterClient implements a ClusterClient.
func RetryClusterClient(c *Client) pb.ClusterClient {
return &retryClusterClient{
cc: pb.NewClusterClient(c.conn),
retryf: c.newRetryWrapper(),
cc: pb.NewClusterClient(c.conn),
}
}
func (rcc *retryClusterClient) MemberList(ctx context.Context, in *pb.MemberListRequest, opts ...grpc.CallOption) (resp *pb.MemberListResponse, err error) {
err = rcc.retryf(ctx, func(rctx context.Context) error {
resp, err = rcc.cc.MemberList(rctx, in, opts...)
return err
}, repeatable)
return resp, err
return rcc.cc.MemberList(ctx, in, append(opts, withRetryPolicy(repeatable))...)
}
func (rcc *retryClusterClient) MemberAdd(ctx context.Context, in *pb.MemberAddRequest, opts ...grpc.CallOption) (resp *pb.MemberAddResponse, err error) {
err = rcc.retryf(ctx, func(rctx context.Context) error {
resp, err = rcc.cc.MemberAdd(rctx, in, opts...)
return err
}, nonRepeatable)
return resp, err
return rcc.cc.MemberAdd(ctx, in, opts...)
}
func (rcc *retryClusterClient) MemberRemove(ctx context.Context, in *pb.MemberRemoveRequest, opts ...grpc.CallOption) (resp *pb.MemberRemoveResponse, err error) {
err = rcc.retryf(ctx, func(rctx context.Context) error {
resp, err = rcc.cc.MemberRemove(rctx, in, opts...)
return err
}, nonRepeatable)
return resp, err
return rcc.cc.MemberRemove(ctx, in, opts...)
}
func (rcc *retryClusterClient) MemberUpdate(ctx context.Context, in *pb.MemberUpdateRequest, opts ...grpc.CallOption) (resp *pb.MemberUpdateResponse, err error) {
err = rcc.retryf(ctx, func(rctx context.Context) error {
resp, err = rcc.cc.MemberUpdate(rctx, in, opts...)
return err
}, nonRepeatable)
return resp, err
return rcc.cc.MemberUpdate(ctx, in, opts...)
}
type retryMaintenanceClient struct {
mc pb.MaintenanceClient
retryf retryRPCFunc
mc pb.MaintenanceClient
}
// RetryMaintenanceClient implements a Maintenance.
func RetryMaintenanceClient(c *Client, conn *grpc.ClientConn) pb.MaintenanceClient {
return &retryMaintenanceClient{
mc: pb.NewMaintenanceClient(conn),
retryf: c.newRetryWrapper(),
mc: pb.NewMaintenanceClient(conn),
}
}
func (rmc *retryMaintenanceClient) Alarm(ctx context.Context, in *pb.AlarmRequest, opts ...grpc.CallOption) (resp *pb.AlarmResponse, err error) {
err = rmc.retryf(ctx, func(rctx context.Context) error {
resp, err = rmc.mc.Alarm(rctx, in, opts...)
return err
}, repeatable)
return resp, err
return rmc.mc.Alarm(ctx, in, append(opts, withRetryPolicy(repeatable))...)
}
func (rmc *retryMaintenanceClient) Status(ctx context.Context, in *pb.StatusRequest, opts ...grpc.CallOption) (resp *pb.StatusResponse, err error) {
err = rmc.retryf(ctx, func(rctx context.Context) error {
resp, err = rmc.mc.Status(rctx, in, opts...)
return err
}, repeatable)
return resp, err
return rmc.mc.Status(ctx, in, append(opts, withRetryPolicy(repeatable))...)
}
func (rmc *retryMaintenanceClient) Hash(ctx context.Context, in *pb.HashRequest, opts ...grpc.CallOption) (resp *pb.HashResponse, err error) {
err = rmc.retryf(ctx, func(rctx context.Context) error {
resp, err = rmc.mc.Hash(rctx, in, opts...)
return err
}, repeatable)
return resp, err
return rmc.mc.Hash(ctx, in, append(opts, withRetryPolicy(repeatable))...)
}
func (rmc *retryMaintenanceClient) HashKV(ctx context.Context, in *pb.HashKVRequest, opts ...grpc.CallOption) (resp *pb.HashKVResponse, err error) {
err = rmc.retryf(ctx, func(rctx context.Context) error {
resp, err = rmc.mc.HashKV(rctx, in, opts...)
return err
}, repeatable)
return resp, err
return rmc.mc.HashKV(ctx, in, append(opts, withRetryPolicy(repeatable))...)
}
func (rmc *retryMaintenanceClient) Snapshot(ctx context.Context, in *pb.SnapshotRequest, opts ...grpc.CallOption) (stream pb.Maintenance_SnapshotClient, err error) {
err = rmc.retryf(ctx, func(rctx context.Context) error {
stream, err = rmc.mc.Snapshot(rctx, in, opts...)
return err
}, repeatable)
return stream, err
return rmc.mc.Snapshot(ctx, in, append(opts, withRetryPolicy(repeatable))...)
}
func (rmc *retryMaintenanceClient) MoveLeader(ctx context.Context, in *pb.MoveLeaderRequest, opts ...grpc.CallOption) (resp *pb.MoveLeaderResponse, err error) {
err = rmc.retryf(ctx, func(rctx context.Context) error {
resp, err = rmc.mc.MoveLeader(rctx, in, opts...)
return err
}, repeatable)
return resp, err
return rmc.mc.MoveLeader(ctx, in, append(opts, withRetryPolicy(repeatable))...)
}
func (rmc *retryMaintenanceClient) Defragment(ctx context.Context, in *pb.DefragmentRequest, opts ...grpc.CallOption) (resp *pb.DefragmentResponse, err error) {
err = rmc.retryf(ctx, func(rctx context.Context) error {
resp, err = rmc.mc.Defragment(rctx, in, opts...)
return err
}, nonRepeatable)
return resp, err
return rmc.mc.Defragment(ctx, in, opts...)
}
type retryAuthClient struct {
ac pb.AuthClient
retryf retryRPCFunc
ac pb.AuthClient
}
// RetryAuthClient implements a AuthClient.
func RetryAuthClient(c *Client) pb.AuthClient {
return &retryAuthClient{
ac: pb.NewAuthClient(c.conn),
retryf: c.newRetryWrapper(),
ac: pb.NewAuthClient(c.conn),
}
}
func (rac *retryAuthClient) UserList(ctx context.Context, in *pb.AuthUserListRequest, opts ...grpc.CallOption) (resp *pb.AuthUserListResponse, err error) {
err = rac.retryf(ctx, func(rctx context.Context) error {
resp, err = rac.ac.UserList(rctx, in, opts...)
return err
}, repeatable)
return resp, err
return rac.ac.UserList(ctx, in, append(opts, withRetryPolicy(repeatable))...)
}
func (rac *retryAuthClient) UserGet(ctx context.Context, in *pb.AuthUserGetRequest, opts ...grpc.CallOption) (resp *pb.AuthUserGetResponse, err error) {
err = rac.retryf(ctx, func(rctx context.Context) error {
resp, err = rac.ac.UserGet(rctx, in, opts...)
return err
}, repeatable)
return resp, err
return rac.ac.UserGet(ctx, in, append(opts, withRetryPolicy(repeatable))...)
}
func (rac *retryAuthClient) RoleGet(ctx context.Context, in *pb.AuthRoleGetRequest, opts ...grpc.CallOption) (resp *pb.AuthRoleGetResponse, err error) {
err = rac.retryf(ctx, func(rctx context.Context) error {
resp, err = rac.ac.RoleGet(rctx, in, opts...)
return err
}, repeatable)
return resp, err
return rac.ac.RoleGet(ctx, in, append(opts, withRetryPolicy(repeatable))...)
}
func (rac *retryAuthClient) RoleList(ctx context.Context, in *pb.AuthRoleListRequest, opts ...grpc.CallOption) (resp *pb.AuthRoleListResponse, err error) {
err = rac.retryf(ctx, func(rctx context.Context) error {
resp, err = rac.ac.RoleList(rctx, in, opts...)
return err
}, repeatable)
return resp, err
return rac.ac.RoleList(ctx, in, append(opts, withRetryPolicy(repeatable))...)
}
func (rac *retryAuthClient) AuthEnable(ctx context.Context, in *pb.AuthEnableRequest, opts ...grpc.CallOption) (resp *pb.AuthEnableResponse, err error) {
err = rac.retryf(ctx, func(rctx context.Context) error {
resp, err = rac.ac.AuthEnable(rctx, in, opts...)
return err
}, nonRepeatable)
return resp, err
return rac.ac.AuthEnable(ctx, in, opts...)
}
func (rac *retryAuthClient) AuthDisable(ctx context.Context, in *pb.AuthDisableRequest, opts ...grpc.CallOption) (resp *pb.AuthDisableResponse, err error) {
err = rac.retryf(ctx, func(rctx context.Context) error {
resp, err = rac.ac.AuthDisable(rctx, in, opts...)
return err
}, nonRepeatable)
return resp, err
return rac.ac.AuthDisable(ctx, in, opts...)
}
func (rac *retryAuthClient) UserAdd(ctx context.Context, in *pb.AuthUserAddRequest, opts ...grpc.CallOption) (resp *pb.AuthUserAddResponse, err error) {
err = rac.retryf(ctx, func(rctx context.Context) error {
resp, err = rac.ac.UserAdd(rctx, in, opts...)
return err
}, nonRepeatable)
return resp, err
return rac.ac.UserAdd(ctx, in, opts...)
}
func (rac *retryAuthClient) UserDelete(ctx context.Context, in *pb.AuthUserDeleteRequest, opts ...grpc.CallOption) (resp *pb.AuthUserDeleteResponse, err error) {
err = rac.retryf(ctx, func(rctx context.Context) error {
resp, err = rac.ac.UserDelete(rctx, in, opts...)
return err
}, nonRepeatable)
return resp, err
return rac.ac.UserDelete(ctx, in, opts...)
}
func (rac *retryAuthClient) UserChangePassword(ctx context.Context, in *pb.AuthUserChangePasswordRequest, opts ...grpc.CallOption) (resp *pb.AuthUserChangePasswordResponse, err error) {
err = rac.retryf(ctx, func(rctx context.Context) error {
resp, err = rac.ac.UserChangePassword(rctx, in, opts...)
return err
}, nonRepeatable)
return resp, err
return rac.ac.UserChangePassword(ctx, in, opts...)
}
func (rac *retryAuthClient) UserGrantRole(ctx context.Context, in *pb.AuthUserGrantRoleRequest, opts ...grpc.CallOption) (resp *pb.AuthUserGrantRoleResponse, err error) {
err = rac.retryf(ctx, func(rctx context.Context) error {
resp, err = rac.ac.UserGrantRole(rctx, in, opts...)
return err
}, nonRepeatable)
return resp, err
return rac.ac.UserGrantRole(ctx, in, opts...)
}
func (rac *retryAuthClient) UserRevokeRole(ctx context.Context, in *pb.AuthUserRevokeRoleRequest, opts ...grpc.CallOption) (resp *pb.AuthUserRevokeRoleResponse, err error) {
err = rac.retryf(ctx, func(rctx context.Context) error {
resp, err = rac.ac.UserRevokeRole(rctx, in, opts...)
return err
}, nonRepeatable)
return resp, err
return rac.ac.UserRevokeRole(ctx, in, opts...)
}
func (rac *retryAuthClient) RoleAdd(ctx context.Context, in *pb.AuthRoleAddRequest, opts ...grpc.CallOption) (resp *pb.AuthRoleAddResponse, err error) {
err = rac.retryf(ctx, func(rctx context.Context) error {
resp, err = rac.ac.RoleAdd(rctx, in, opts...)
return err
}, nonRepeatable)
return resp, err
return rac.ac.RoleAdd(ctx, in, opts...)
}
func (rac *retryAuthClient) RoleDelete(ctx context.Context, in *pb.AuthRoleDeleteRequest, opts ...grpc.CallOption) (resp *pb.AuthRoleDeleteResponse, err error) {
err = rac.retryf(ctx, func(rctx context.Context) error {
resp, err = rac.ac.RoleDelete(rctx, in, opts...)
return err
}, nonRepeatable)
return resp, err
return rac.ac.RoleDelete(ctx, in, opts...)
}
func (rac *retryAuthClient) RoleGrantPermission(ctx context.Context, in *pb.AuthRoleGrantPermissionRequest, opts ...grpc.CallOption) (resp *pb.AuthRoleGrantPermissionResponse, err error) {
err = rac.retryf(ctx, func(rctx context.Context) error {
resp, err = rac.ac.RoleGrantPermission(rctx, in, opts...)
return err
}, nonRepeatable)
return resp, err
return rac.ac.RoleGrantPermission(ctx, in, opts...)
}
func (rac *retryAuthClient) RoleRevokePermission(ctx context.Context, in *pb.AuthRoleRevokePermissionRequest, opts ...grpc.CallOption) (resp *pb.AuthRoleRevokePermissionResponse, err error) {
err = rac.retryf(ctx, func(rctx context.Context) error {
resp, err = rac.ac.RoleRevokePermission(rctx, in, opts...)
return err
}, nonRepeatable)
return resp, err
return rac.ac.RoleRevokePermission(ctx, in, opts...)
}
func (rac *retryAuthClient) Authenticate(ctx context.Context, in *pb.AuthenticateRequest, opts ...grpc.CallOption) (resp *pb.AuthenticateResponse, err error) {
err = rac.retryf(ctx, func(rctx context.Context) error {
resp, err = rac.ac.Authenticate(rctx, in, opts...)
return err
}, nonRepeatable)
return resp, err
return rac.ac.Authenticate(ctx, in, opts...)
}

View File

@ -0,0 +1,389 @@
// Copyright 2016 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Based on github.com/grpc-ecosystem/go-grpc-middleware/retry, but modified to support the more
// fine grained error checking required by write-at-most-once retry semantics of etcd.
package clientv3
import (
"context"
"io"
"sync"
"time"
"github.com/coreos/etcd/etcdserver/api/v3rpc/rpctypes"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
// unaryClientInterceptor returns a new retrying unary client interceptor.
//
// The default configuration of the interceptor is to not retry *at all*. This behaviour can be
// changed through options (e.g. WithMax) on creation of the interceptor or on call (through grpc.CallOptions).
func (c *Client) unaryClientInterceptor(logger *zap.Logger, optFuncs ...retryOption) grpc.UnaryClientInterceptor {
intOpts := reuseOrNewWithCallOptions(defaultOptions, optFuncs)
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
grpcOpts, retryOpts := filterCallOptions(opts)
callOpts := reuseOrNewWithCallOptions(intOpts, retryOpts)
// short circuit for simplicity, and avoiding allocations.
if callOpts.max == 0 {
return invoker(ctx, method, req, reply, cc, grpcOpts...)
}
var lastErr error
for attempt := uint(0); attempt < callOpts.max; attempt++ {
if err := waitRetryBackoff(ctx, attempt, callOpts); err != nil {
return err
}
logger.Debug(
"retrying of unary invoker",
zap.String("target", cc.Target()),
zap.Uint("attempt", attempt),
)
lastErr = invoker(ctx, method, req, reply, cc, grpcOpts...)
if lastErr == nil {
return nil
}
logger.Warn(
"retrying of unary invoker failed",
zap.String("target", cc.Target()),
zap.Uint("attempt", attempt),
zap.Error(lastErr),
)
if isContextError(lastErr) {
if ctx.Err() != nil {
// it's the context deadline or cancellation.
return lastErr
}
// it's the callCtx deadline or cancellation, in which case try again.
continue
}
if callOpts.retryAuth && rpctypes.Error(lastErr) == rpctypes.ErrInvalidAuthToken {
gterr := c.getToken(ctx)
if gterr != nil {
logger.Warn(
"retrying of unary invoker failed to fetch new auth token",
zap.String("target", cc.Target()),
zap.Error(gterr),
)
return gterr // lastErr must be invalid auth token
}
continue
}
if !isSafeRetry(c.lg, lastErr, callOpts) {
return lastErr
}
}
return lastErr
}
}
// streamClientInterceptor returns a new retrying stream client interceptor for server side streaming calls.
//
// The default configuration of the interceptor is to not retry *at all*. This behaviour can be
// changed through options (e.g. WithMax) on creation of the interceptor or on call (through grpc.CallOptions).
//
// Retry logic is available *only for ServerStreams*, i.e. 1:n streams, as the internal logic needs
// to buffer the messages sent by the client. If retry is enabled on any other streams (ClientStreams,
// BidiStreams), the retry interceptor will fail the call.
func (c *Client) streamClientInterceptor(logger *zap.Logger, optFuncs ...retryOption) grpc.StreamClientInterceptor {
intOpts := reuseOrNewWithCallOptions(defaultOptions, optFuncs)
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
grpcOpts, retryOpts := filterCallOptions(opts)
callOpts := reuseOrNewWithCallOptions(intOpts, retryOpts)
// short circuit for simplicity, and avoiding allocations.
if callOpts.max == 0 {
return streamer(ctx, desc, cc, method, grpcOpts...)
}
if desc.ClientStreams {
return nil, status.Errorf(codes.Unimplemented, "clientv3/retry_interceptor: cannot retry on ClientStreams, set Disable()")
}
newStreamer, err := streamer(ctx, desc, cc, method, grpcOpts...)
logger.Warn("retry stream intercept", zap.Error(err))
if err != nil {
// TODO(mwitkow): Maybe dial and transport errors should be retriable?
return nil, err
}
retryingStreamer := &serverStreamingRetryingStream{
client: c,
ClientStream: newStreamer,
callOpts: callOpts,
ctx: ctx,
streamerCall: func(ctx context.Context) (grpc.ClientStream, error) {
return streamer(ctx, desc, cc, method, grpcOpts...)
},
}
return retryingStreamer, nil
}
}
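As a hedged sketch (the dial path is outside this hunk, and the helper name below is illustrative), the interceptors above would typically be attached as gRPC dial options using the defaults from options.go:

```go
// retryDialOptions wires the retrying interceptors into a dial-option list.
func (c *Client) retryDialOptions() []grpc.DialOption {
	bf := withBackoff(backoffLinearWithJitter(defaultBackoffWaitBetween, defaultBackoffJitterFraction))
	return []grpc.DialOption{
		grpc.WithUnaryInterceptor(c.unaryClientInterceptor(c.lg, withMax(defaultUnaryMaxRetries), bf)),
		grpc.WithStreamInterceptor(c.streamClientInterceptor(c.lg, withMax(defaultStreamMaxRetries), bf)),
	}
}
```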
// type serverStreamingRetryingStream is the implementation of grpc.ClientStream that acts as a
// proxy to the underlying call. If any of the RecvMsg() calls fail, it will try to reestablish
// a new ClientStream according to the retry policy.
type serverStreamingRetryingStream struct {
grpc.ClientStream
client *Client
bufferedSends []interface{} // single message that the client can send
receivedGood bool // indicates whether any prior receives were successful
wasClosedSend bool // indicates that CloseSend was called
ctx context.Context
callOpts *options
streamerCall func(ctx context.Context) (grpc.ClientStream, error)
mu sync.RWMutex
}
func (s *serverStreamingRetryingStream) setStream(clientStream grpc.ClientStream) {
s.mu.Lock()
s.ClientStream = clientStream
s.mu.Unlock()
}
func (s *serverStreamingRetryingStream) getStream() grpc.ClientStream {
s.mu.RLock()
defer s.mu.RUnlock()
return s.ClientStream
}
func (s *serverStreamingRetryingStream) SendMsg(m interface{}) error {
s.mu.Lock()
s.bufferedSends = append(s.bufferedSends, m)
s.mu.Unlock()
return s.getStream().SendMsg(m)
}
func (s *serverStreamingRetryingStream) CloseSend() error {
s.mu.Lock()
s.wasClosedSend = true
s.mu.Unlock()
return s.getStream().CloseSend()
}
func (s *serverStreamingRetryingStream) Header() (metadata.MD, error) {
return s.getStream().Header()
}
func (s *serverStreamingRetryingStream) Trailer() metadata.MD {
return s.getStream().Trailer()
}
func (s *serverStreamingRetryingStream) RecvMsg(m interface{}) error {
attemptRetry, lastErr := s.receiveMsgAndIndicateRetry(m)
if !attemptRetry {
return lastErr // success or hard failure
}
// We start off from attempt 1, because zeroth was already made on normal SendMsg().
for attempt := uint(1); attempt < s.callOpts.max; attempt++ {
if err := waitRetryBackoff(s.ctx, attempt, s.callOpts); err != nil {
return err
}
newStream, err := s.reestablishStreamAndResendBuffer(s.ctx)
if err != nil {
// TODO(mwitkow): Maybe dial and transport errors should be retriable?
return err
}
s.setStream(newStream)
attemptRetry, lastErr = s.receiveMsgAndIndicateRetry(m)
//fmt.Printf("Received message and indicate: %v %v\n", attemptRetry, lastErr)
if !attemptRetry {
return lastErr
}
}
return lastErr
}
func (s *serverStreamingRetryingStream) receiveMsgAndIndicateRetry(m interface{}) (bool, error) {
s.mu.RLock()
wasGood := s.receivedGood
s.mu.RUnlock()
err := s.getStream().RecvMsg(m)
if err == nil || err == io.EOF {
s.mu.Lock()
s.receivedGood = true
s.mu.Unlock()
return false, err
} else if wasGood {
// previous RecvMsg in the stream succeeded, no retry logic should interfere
return false, err
}
if isContextError(err) {
if s.ctx.Err() != nil {
return false, err
}
// it's the callCtx deadline or cancellation, in which case try again.
return true, err
}
if s.callOpts.retryAuth && rpctypes.Error(err) == rpctypes.ErrInvalidAuthToken {
gterr := s.client.getToken(s.ctx)
if gterr != nil {
s.client.lg.Warn("retry failed to fetch new auth token", zap.Error(gterr))
return false, err // return the original error for simplicity
}
return true, err
}
return isSafeRetry(s.client.lg, err, s.callOpts), err
}
func (s *serverStreamingRetryingStream) reestablishStreamAndResendBuffer(callCtx context.Context) (grpc.ClientStream, error) {
s.mu.RLock()
bufferedSends := s.bufferedSends
s.mu.RUnlock()
newStream, err := s.streamerCall(callCtx)
if err != nil {
return nil, err
}
for _, msg := range bufferedSends {
if err := newStream.SendMsg(msg); err != nil {
return nil, err
}
}
if err := newStream.CloseSend(); err != nil {
return nil, err
}
return newStream, nil
}
func waitRetryBackoff(ctx context.Context, attempt uint, callOpts *options) error {
waitTime := time.Duration(0)
if attempt > 0 {
waitTime = callOpts.backoffFunc(attempt)
}
if waitTime > 0 {
timer := time.NewTimer(waitTime)
select {
case <-ctx.Done():
timer.Stop()
return contextErrToGrpcErr(ctx.Err())
case <-timer.C:
}
}
return nil
}
// isSafeRetry returns "true" if the request is safe to retry with the given error.
func isSafeRetry(lg *zap.Logger, err error, callOpts *options) bool {
if isContextError(err) {
return false
}
switch callOpts.retryPolicy {
case repeatable:
return isSafeRetryImmutableRPC(err)
case nonRepeatable:
return isSafeRetryMutableRPC(err)
default:
lg.Warn("unrecognized retry policy", zap.String("retryPolicy", callOpts.retryPolicy.String()))
return false
}
}
func isContextError(err error) bool {
return grpc.Code(err) == codes.DeadlineExceeded || grpc.Code(err) == codes.Canceled
}
func contextErrToGrpcErr(err error) error {
switch err {
case context.DeadlineExceeded:
return status.Errorf(codes.DeadlineExceeded, err.Error())
case context.Canceled:
return status.Errorf(codes.Canceled, err.Error())
default:
return status.Errorf(codes.Unknown, err.Error())
}
}
var (
defaultOptions = &options{
retryPolicy: nonRepeatable,
max: 0, // disable
backoffFunc: backoffLinearWithJitter(50*time.Millisecond /*jitter*/, 0.10),
retryAuth: true,
}
)
// backoffFunc denotes a family of functions that control the backoff duration between call retries.
//
// They are called with an identifier of the attempt, and should return a time the system client should
// hold off for. If the time returned is longer than the `context.Context.Deadline` of the request
// the deadline of the request takes precedence and the wait will be interrupted before proceeding
// with the next iteration.
type backoffFunc func(attempt uint) time.Duration
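A hedged sketch of an alternative `backoffFunc`, a capped exponential backoff (name and values are illustrative); it can be passed through `withBackoff` just like the jittered linear default:

```go
// backoffExponentialCapped doubles the wait per attempt up to a ceiling.
func backoffExponentialCapped(base, max time.Duration) backoffFunc {
	return func(attempt uint) time.Duration {
		d := base << attempt // base * 2^attempt
		if d <= 0 || d > max {
			return max // cap the wait and guard against shift overflow
		}
		return d
	}
}

// usage: withMax(5), withBackoff(backoffExponentialCapped(25*time.Millisecond, 2*time.Second))
```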
// withRetryPolicy sets the retry policy of this call.
func withRetryPolicy(rp retryPolicy) retryOption {
return retryOption{applyFunc: func(o *options) {
o.retryPolicy = rp
}}
}
// withMax sets the maximum number of retries on this call, or this interceptor.
func withMax(maxRetries uint) retryOption {
return retryOption{applyFunc: func(o *options) {
o.max = maxRetries
}}
}
// withBackoff sets the `backoffFunc` used to control the time between retries.
func withBackoff(bf backoffFunc) retryOption {
return retryOption{applyFunc: func(o *options) {
o.backoffFunc = bf
}}
}
type options struct {
retryPolicy retryPolicy
max uint
backoffFunc backoffFunc
retryAuth bool
}
// retryOption is a grpc.CallOption that is local to clientv3's retry interceptor.
type retryOption struct {
grpc.EmptyCallOption // make sure we implement private after() and before() fields so we don't panic.
applyFunc func(opt *options)
}
func reuseOrNewWithCallOptions(opt *options, retryOptions []retryOption) *options {
if len(retryOptions) == 0 {
return opt
}
optCopy := &options{}
*optCopy = *opt
for _, f := range retryOptions {
f.applyFunc(optCopy)
}
return optCopy
}
func filterCallOptions(callOptions []grpc.CallOption) (grpcOptions []grpc.CallOption, retryOptions []retryOption) {
for _, opt := range callOptions {
if co, ok := opt.(retryOption); ok {
retryOptions = append(retryOptions, co)
} else {
grpcOptions = append(grpcOptions, opt)
}
}
return grpcOptions, retryOptions
}
// backoffLinearWithJitter waits a set period of time, allowing for jitter (fractional adjustment).
//
// For example waitBetween=1s and jitter=0.10 can generate waits between 900ms and 1100ms.
func backoffLinearWithJitter(waitBetween time.Duration, jitterFraction float64) backoffFunc {
return func(attempt uint) time.Duration {
return jitterUp(waitBetween, jitterFraction)
}
}

clientv3/snapshot/doc.go Normal file
View File

@ -0,0 +1,16 @@
// Copyright 2018 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package snapshot implements utilities around etcd snapshot.
package snapshot

Binary file not shown.

clientv3/snapshot/util.go Normal file
View File

@ -0,0 +1,35 @@
// Copyright 2018 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package snapshot
import "encoding/binary"
type revision struct {
main int64
sub int64
}
func bytesToRev(bytes []byte) revision {
return revision{
main: int64(binary.BigEndian.Uint64(bytes[0:8])),
sub: int64(binary.BigEndian.Uint64(bytes[9:])),
}
}
// initIndex implements ConsistentIndexGetter so the snapshot won't block
// the new raft instance by waiting for a future raft index.
type initIndex int
func (i *initIndex) ConsistentIndex() uint64 { return uint64(*i) }

View File

@ -0,0 +1,492 @@
// Copyright 2018 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package snapshot
import (
"context"
"crypto/sha256"
"encoding/json"
"fmt"
"hash/crc32"
"io"
"math"
"os"
"path/filepath"
"reflect"
"strings"
"time"
bolt "github.com/coreos/bbolt"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/etcdserver"
"github.com/coreos/etcd/etcdserver/etcdserverpb"
"github.com/coreos/etcd/etcdserver/membership"
"github.com/coreos/etcd/lease"
"github.com/coreos/etcd/mvcc"
"github.com/coreos/etcd/mvcc/backend"
"github.com/coreos/etcd/pkg/fileutil"
"github.com/coreos/etcd/pkg/types"
"github.com/coreos/etcd/raft"
"github.com/coreos/etcd/raft/raftpb"
"github.com/coreos/etcd/snap"
"github.com/coreos/etcd/store"
"github.com/coreos/etcd/wal"
"github.com/coreos/etcd/wal/walpb"
"go.uber.org/zap"
)
// Manager defines snapshot methods.
type Manager interface {
// Save fetches snapshot from remote etcd server and saves data
// to target path. If the context "ctx" is canceled or timed out,
// snapshot save stream will error out (e.g. context.Canceled,
// context.DeadlineExceeded). Make sure to specify only one endpoint
// in client configuration. Snapshot API must be requested to a
// selected node, and saved snapshot is the point-in-time state of
// the selected node.
Save(ctx context.Context, cfg clientv3.Config, dbPath string) error
// Status returns the snapshot file information.
Status(dbPath string) (Status, error)
// Restore restores a new etcd data directory from given snapshot
// file. It returns an error if specified data directory already
// exists, to prevent unintended data directory overwrites.
Restore(cfg RestoreConfig) error
}
// NewV3 returns a new snapshot Manager for v3.x snapshot.
func NewV3(lg *zap.Logger) Manager {
if lg == nil {
lg = zap.NewExample()
}
return &v3Manager{lg: lg}
}
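A hedged usage sketch of the Manager above; the endpoint, paths, and timeout are illustrative:

```go
package main

import (
	"context"
	"time"

	"github.com/coreos/etcd/clientv3"
	"github.com/coreos/etcd/clientv3/snapshot"
	"go.uber.org/zap"
)

func main() {
	sm := snapshot.NewV3(zap.NewExample())

	// Save streams a point-in-time snapshot from exactly one selected node.
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	if err := sm.Save(ctx, clientv3.Config{Endpoints: []string{"127.0.0.1:2379"}}, "/tmp/backup.db"); err != nil {
		panic(err)
	}

	// Status reports the hash, revision, key count, and total size of the saved file.
	st, err := sm.Status("/tmp/backup.db")
	if err != nil {
		panic(err)
	}
	_ = st
}
```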
type v3Manager struct {
lg *zap.Logger
name string
dbPath string
walDir string
snapDir string
cl *membership.RaftCluster
skipHashCheck bool
}
// Save fetches snapshot from remote etcd server and saves data to target path.
func (s *v3Manager) Save(ctx context.Context, cfg clientv3.Config, dbPath string) error {
if len(cfg.Endpoints) != 1 {
return fmt.Errorf("snapshot must be requested to one selected node, not multiple %v", cfg.Endpoints)
}
cli, err := clientv3.New(cfg)
if err != nil {
return err
}
defer cli.Close()
partpath := dbPath + ".part"
defer os.RemoveAll(partpath)
var f *os.File
f, err = os.OpenFile(partpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fileutil.PrivateFileMode)
if err != nil {
return fmt.Errorf("could not open %s (%v)", partpath, err)
}
s.lg.Info(
"created temporary db file",
zap.String("path", partpath),
)
now := time.Now()
var rd io.ReadCloser
rd, err = cli.Snapshot(ctx)
if err != nil {
return err
}
s.lg.Info(
"fetching snapshot",
zap.String("endpoint", cfg.Endpoints[0]),
)
if _, err = io.Copy(f, rd); err != nil {
return err
}
if err = fileutil.Fsync(f); err != nil {
return err
}
if err = f.Close(); err != nil {
return err
}
s.lg.Info(
"fetched snapshot",
zap.String("endpoint", cfg.Endpoints[0]),
zap.Duration("took", time.Since(now)),
)
if err = os.Rename(partpath, dbPath); err != nil {
return fmt.Errorf("could not rename %s to %s (%v)", partpath, dbPath, err)
}
s.lg.Info("saved", zap.String("path", dbPath))
return nil
}
// Status is the snapshot file status.
type Status struct {
Hash uint32 `json:"hash"`
Revision int64 `json:"revision"`
TotalKey int `json:"totalKey"`
TotalSize int64 `json:"totalSize"`
}
// Status returns the snapshot file information.
func (s *v3Manager) Status(dbPath string) (ds Status, err error) {
if _, err = os.Stat(dbPath); err != nil {
return ds, err
}
db, err := bolt.Open(dbPath, 0400, &bolt.Options{ReadOnly: true})
if err != nil {
return ds, err
}
defer db.Close()
h := crc32.New(crc32.MakeTable(crc32.Castagnoli))
if err = db.View(func(tx *bolt.Tx) error {
// check snapshot file integrity first
var dbErrStrings []string
for dbErr := range tx.Check() {
dbErrStrings = append(dbErrStrings, dbErr.Error())
}
if len(dbErrStrings) > 0 {
return fmt.Errorf("snapshot file integrity check failed. %d errors found.\n"+strings.Join(dbErrStrings, "\n"), len(dbErrStrings))
}
ds.TotalSize = tx.Size()
c := tx.Cursor()
for next, _ := c.First(); next != nil; next, _ = c.Next() {
b := tx.Bucket(next)
if b == nil {
return fmt.Errorf("cannot get hash of bucket %s", string(next))
}
h.Write(next)
iskeyb := (string(next) == "key")
b.ForEach(func(k, v []byte) error {
h.Write(k)
h.Write(v)
if iskeyb {
rev := bytesToRev(k)
ds.Revision = rev.main
}
ds.TotalKey++
return nil
})
}
return nil
}); err != nil {
return ds, err
}
ds.Hash = h.Sum32()
return ds, nil
}
// RestoreConfig configures snapshot restore operation.
type RestoreConfig struct {
// SnapshotPath is the path of snapshot file to restore from.
SnapshotPath string
// Name is the human-readable name of this member.
Name string
// OutputDataDir is the target data directory to save restored data.
// OutputDataDir should not conflict with existing etcd data directory.
// If OutputDataDir already exists, it will return an error to prevent
// unintended data directory overwrites.
// If empty, defaults to "[Name].etcd".
OutputDataDir string
// OutputWALDir is the target WAL data directory.
// If empty, defaults to "[OutputDataDir]/member/wal".
OutputWALDir string
// PeerURLs is a list of member's peer URLs to advertise to the rest of the cluster.
PeerURLs []string
// InitialCluster is the initial cluster configuration for restore bootstrap.
InitialCluster string
// InitialClusterToken is the initial cluster token for etcd cluster during restore bootstrap.
InitialClusterToken string
// SkipHashCheck is "true" to ignore snapshot integrity hash value
// (required if copied from data directory).
SkipHashCheck bool
}
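A hedged sketch of a restore driven by this configuration, reusing the imports from the sketch above; all values are illustrative:

```go
sm := snapshot.NewV3(zap.NewExample())
err := sm.Restore(snapshot.RestoreConfig{
	SnapshotPath:        "/tmp/backup.db",
	Name:                "m1",
	OutputDataDir:       "/tmp/m1.etcd",
	PeerURLs:            []string{"http://127.0.0.1:2380"},
	InitialCluster:      "m1=http://127.0.0.1:2380",
	InitialClusterToken: "etcd-cluster-1",
	SkipHashCheck:       false,
})
if err != nil {
	panic(err) // e.g. the output data directory already exists
}
```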
// Restore restores a new etcd data directory from given snapshot file.
func (s *v3Manager) Restore(cfg RestoreConfig) error {
pURLs, err := types.NewURLs(cfg.PeerURLs)
if err != nil {
return err
}
var ics types.URLsMap
ics, err = types.NewURLsMap(cfg.InitialCluster)
if err != nil {
return err
}
srv := etcdserver.ServerConfig{
Name: cfg.Name,
PeerURLs: pURLs,
InitialPeerURLsMap: ics,
InitialClusterToken: cfg.InitialClusterToken,
}
if err = srv.VerifyBootstrap(); err != nil {
return err
}
s.cl, err = membership.NewClusterFromURLsMap(cfg.InitialClusterToken, ics)
if err != nil {
return err
}
dataDir := cfg.OutputDataDir
if dataDir == "" {
dataDir = cfg.Name + ".etcd"
}
if fileutil.Exist(dataDir) {
return fmt.Errorf("data-dir %q exists", dataDir)
}
walDir := cfg.OutputWALDir
if walDir == "" {
walDir = filepath.Join(dataDir, "member", "wal")
} else if fileutil.Exist(walDir) {
return fmt.Errorf("wal-dir %q exists", walDir)
}
s.name = cfg.Name
s.dbPath = cfg.SnapshotPath
s.walDir = walDir
s.snapDir = filepath.Join(dataDir, "member", "snap")
s.skipHashCheck = cfg.SkipHashCheck
s.lg.Info(
"restoring snapshot",
zap.String("path", s.dbPath),
zap.String("wal-dir", s.walDir),
zap.String("data-dir", dataDir),
zap.String("snap-dir", s.snapDir),
)
if err = s.saveDB(); err != nil {
return err
}
if err = s.saveWALAndSnap(); err != nil {
return err
}
s.lg.Info(
"restored snapshot",
zap.String("path", s.dbPath),
zap.String("wal-dir", s.walDir),
zap.String("data-dir", dataDir),
zap.String("snap-dir", s.snapDir),
)
return nil
}
// saveDB copies the database snapshot to the snapshot directory
func (s *v3Manager) saveDB() error {
f, ferr := os.OpenFile(s.dbPath, os.O_RDONLY, 0600)
if ferr != nil {
return ferr
}
defer f.Close()
// get snapshot integrity hash
if _, err := f.Seek(-sha256.Size, io.SeekEnd); err != nil {
return err
}
sha := make([]byte, sha256.Size)
if _, err := f.Read(sha); err != nil {
return err
}
if _, err := f.Seek(0, io.SeekStart); err != nil {
return err
}
if err := fileutil.CreateDirAll(s.snapDir); err != nil {
return err
}
dbpath := filepath.Join(s.snapDir, "db")
db, dberr := os.OpenFile(dbpath, os.O_RDWR|os.O_CREATE, 0600)
if dberr != nil {
return dberr
}
if _, err := io.Copy(db, f); err != nil {
return err
}
// truncate away integrity hash, if any.
off, serr := db.Seek(0, io.SeekEnd)
if serr != nil {
return serr
}
hasHash := (off % 512) == sha256.Size
if hasHash {
if err := db.Truncate(off - sha256.Size); err != nil {
return err
}
}
if !hasHash && !s.skipHashCheck {
return fmt.Errorf("snapshot missing hash but --skip-hash-check=false")
}
if hasHash && !s.skipHashCheck {
// check for match
if _, err := db.Seek(0, io.SeekStart); err != nil {
return err
}
h := sha256.New()
if _, err := io.Copy(h, db); err != nil {
return err
}
dbsha := h.Sum(nil)
if !reflect.DeepEqual(sha, dbsha) {
return fmt.Errorf("expected sha256 %v, got %v", sha, dbsha)
}
}
// db hash is OK, can now modify DB so it can be part of a new cluster
db.Close()
commit := len(s.cl.Members())
// update consistentIndex so applies go through on etcdserver despite
// having a new raft instance
be := backend.NewDefaultBackend(dbpath)
// a lessor that never times out leases
lessor := lease.NewLessor(be, math.MaxInt64)
mvs := mvcc.NewStore(be, lessor, (*initIndex)(&commit))
txn := mvs.Write()
btx := be.BatchTx()
del := func(k, v []byte) error {
txn.DeleteRange(k, nil)
return nil
}
// delete stored members from old cluster since using new members
btx.UnsafeForEach([]byte("members"), del)
// todo: add back new members when we start to deprecate old snap file.
btx.UnsafeForEach([]byte("members_removed"), del)
// trigger write-out of new consistent index
txn.End()
mvs.Commit()
mvs.Close()
be.Close()
return nil
}
// saveWALAndSnap creates a WAL for the initial cluster
func (s *v3Manager) saveWALAndSnap() error {
if err := fileutil.CreateDirAll(s.walDir); err != nil {
return err
}
// add members again to persist them to the store we create.
st := store.New(etcdserver.StoreClusterPrefix, etcdserver.StoreKeysPrefix)
s.cl.SetStore(st)
for _, m := range s.cl.Members() {
s.cl.AddMember(m)
}
m := s.cl.MemberByName(s.name)
md := &etcdserverpb.Metadata{NodeID: uint64(m.ID), ClusterID: uint64(s.cl.ID())}
metadata, merr := md.Marshal()
if merr != nil {
return merr
}
w, walerr := wal.Create(s.walDir, metadata)
if walerr != nil {
return walerr
}
defer w.Close()
peers := make([]raft.Peer, len(s.cl.MemberIDs()))
for i, id := range s.cl.MemberIDs() {
ctx, err := json.Marshal((*s.cl).Member(id))
if err != nil {
return err
}
peers[i] = raft.Peer{ID: uint64(id), Context: ctx}
}
ents := make([]raftpb.Entry, len(peers))
nodeIDs := make([]uint64, len(peers))
for i, p := range peers {
nodeIDs[i] = p.ID
cc := raftpb.ConfChange{
Type: raftpb.ConfChangeAddNode,
NodeID: p.ID,
Context: p.Context,
}
d, err := cc.Marshal()
if err != nil {
return err
}
ents[i] = raftpb.Entry{
Type: raftpb.EntryConfChange,
Term: 1,
Index: uint64(i + 1),
Data: d,
}
}
commit, term := uint64(len(ents)), uint64(1)
if err := w.Save(raftpb.HardState{
Term: term,
Vote: peers[0].ID,
Commit: commit,
}, ents); err != nil {
return err
}
b, berr := st.Save()
if berr != nil {
return berr
}
raftSnap := raftpb.Snapshot{
Data: b,
Metadata: raftpb.SnapshotMetadata{
Index: commit,
Term: term,
ConfState: raftpb.ConfState{
Nodes: nodeIDs,
},
},
}
sn := snap.New(s.snapDir)
if err := sn.SaveSnap(raftSnap); err != nil {
return err
}
return w.SaveSnapshot(walpb.Snapshot{Index: commit, Term: term})
}

49
clientv3/utils.go Normal file
View File

@ -0,0 +1,49 @@
// Copyright 2018 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package clientv3
import (
"math/rand"
"reflect"
"runtime"
"strings"
"time"
)
// jitterUp adds random jitter to the duration.
//
// This adds or subtracts time from the duration within a given jitter fraction.
// For example, for 10s and jitter 0.1, it will return a time within [9s, 11s].
//
// Reference: https://godoc.org/github.com/grpc-ecosystem/go-grpc-middleware/util/backoffutils
func jitterUp(duration time.Duration, jitter float64) time.Duration {
multiplier := jitter * (rand.Float64()*2 - 1)
return time.Duration(float64(duration) * (1 + multiplier))
}
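// For instance, a minimal in-package sketch (the base duration and jitter
// fraction are illustrative only):
func sleepWithJitter() {
	// 10s base with 10% jitter yields a sleep duration within [9s, 11s]
	time.Sleep(jitterUp(10*time.Second, 0.1))
}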
// isOpFuncCalled reports whether the option function named "op" appears in the given op options.
func isOpFuncCalled(op string, opts []OpOption) bool {
for _, opt := range opts {
v := reflect.ValueOf(opt)
if v.Kind() == reflect.Func {
if opFunc := runtime.FuncForPC(v.Pointer()); opFunc != nil {
if strings.Contains(opFunc.Name(), op) {
return true
}
}
}
}
return false
}
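// An in-package sketch of how the helper above can be used to detect a
// specific option; it assumes the exported option constructor WithPrefix,
// which is defined elsewhere in this package.
func rangeRequested(opts []OpOption) bool {
	// true when the caller passed WithPrefix() among the options
	return isOpFuncCalled("WithPrefix", opts)
}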

View File

@ -16,6 +16,7 @@ package clientv3
import (
"context"
"errors"
"fmt"
"sync"
"time"
@ -46,8 +47,33 @@ type Watcher interface {
// through the returned channel. If revisions waiting to be sent over the
// watch are compacted, then the watch will be canceled by the server, the
// client will post a compacted error watch response, and the channel will close.
// If the context "ctx" is canceled or times out, the returned "WatchChan" is
// closed, and the "WatchResponse" from this closed channel has zero events
// and a nil "Err()".
// The context "ctx" MUST be canceled as soon as the watcher is no longer
// being used, in order to release the associated resources.
//
// If the context is "context.Background/TODO", the returned "WatchChan" will
// not be closed and will block until an event is triggered, except when the
// server returns a non-recoverable error (e.g. ErrCompacted).
// For example, when the context is wrapped with "WithRequireLeader" and the
// connected server has no leader (e.g. due to a network partition), the
// error "etcdserver: no leader" (ErrNoLeader) will be returned, and then
// "WatchChan" is closed with a non-nil "Err()".
// To prevent a watch stream from being stuck on a partitioned node, make
// sure to wrap the context with "WithRequireLeader".
//
// Otherwise, as long as the context has not been canceled or timed out,
// the watch will retry on other recoverable errors forever until reconnected.
// A minimal usage sketch follows this interface definition.
//
// TODO: explicitly set context error in the last "WatchResponse" message and close channel?
// Currently, client contexts are overwritten with "valCtx" that never closes.
// TODO(v3.4): configure watch retry policy, limit maximum retry number
// (see https://github.com/etcd-io/etcd/issues/8980)
Watch(ctx context.Context, key string, opts ...OpOption) WatchChan
// RequestProgress requests a progress notify response be sent in all watch channels.
RequestProgress(ctx context.Context) error
// Close closes the watcher and cancels all watch requests.
Close() error
}
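// A minimal usage sketch of the contract documented above; the key prefix
// and event handling are illustrative, and "cli" is assumed to be an
// established *Client.
func watchExample(cli *Client) {
	// WithRequireLeader makes the watch fail fast with ErrNoLeader instead of
	// hanging when the connected member is partitioned from the leader.
	ctx, cancel := context.WithCancel(WithRequireLeader(context.Background()))
	defer cancel() // always cancel to release the watcher resources

	for wresp := range cli.Watch(ctx, "foo/", WithPrefix()) {
		if err := wresp.Err(); err != nil {
			return // e.g. compacted revision or lost leader; the channel will close
		}
		for _, ev := range wresp.Events {
			_ = ev // process ev.Type, ev.Kv.Key, ev.Kv.Value
		}
	}
}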
@ -134,7 +160,7 @@ type watchGrpcStream struct {
resuming []*watcherStream
// reqc sends a watch request from Watch() to the main goroutine
reqc chan *watchRequest
reqc chan watchStreamRequest
// respc receives data from the watch client
respc chan *pb.WatchResponse
// donec closes to broadcast shutdown
@ -152,16 +178,27 @@ type watchGrpcStream struct {
closeErr error
}
// watchStreamRequest is a union of the supported watch request operation types
type watchStreamRequest interface {
toPB() *pb.WatchRequest
}
// watchRequest is issued by the subscriber to start a new watcher
type watchRequest struct {
ctx context.Context
key string
end string
rev int64
// send created notification event if this field is true
createdNotify bool
// progressNotify is for progress updates
progressNotify bool
// fragment is true to split watch events when the total response size
// exceeds the "--max-request-bytes" flag value plus a 512-byte overhead;
// fragmentation is disabled by default
fragment bool
// filters is the list of events to filter out
filters []pb.WatchCreateRequest_FilterType
// get the previous key-value pair before the event happens
@ -170,6 +207,10 @@ type watchRequest struct {
retc chan chan WatchResponse
}
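// A caller-side sketch of the "fragment" flag above, assuming the option
// constructor WithFragment defined elsewhere in this package; the key prefix
// is hypothetical.
func watchFragmentedExample(cli *Client) WatchChan {
	// ask the server to split watch responses that would exceed its maximum
	// request size into multiple smaller responses
	return cli.Watch(context.Background(), "big-prefix/", WithPrefix(), WithFragment())
}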
// progressRequest is issued by the subscriber to request watch progress
type progressRequest struct {
}
// watcherStream represents a registered watcher
type watcherStream struct {
// initReq is the request that initiated this stream
@ -227,7 +268,7 @@ func (w *watcher) newWatcherGrpcStream(inctx context.Context) *watchGrpcStream {
cancel: cancel,
substreams: make(map[int64]*watcherStream),
respc: make(chan *pb.WatchResponse),
reqc: make(chan *watchRequest),
reqc: make(chan watchStreamRequest),
donec: make(chan struct{}),
errc: make(chan error, 1),
closingc: make(chan *watcherStream),
@ -256,6 +297,7 @@ func (w *watcher) Watch(ctx context.Context, key string, opts ...OpOption) Watch
end: string(ow.end),
rev: ow.rev,
progressNotify: ow.progressNotify,
fragment: ow.fragment,
filters: filters,
prevKV: ow.prevKV,
retc: make(chan chan WatchResponse, 1),
@ -292,7 +334,7 @@ func (w *watcher) Watch(ctx context.Context, key string, opts ...OpOption) Watch
case <-wr.ctx.Done():
case <-donec:
if wgs.closeErr != nil {
closeCh <- WatchResponse{closeErr: wgs.closeErr}
closeCh <- WatchResponse{Canceled: true, closeErr: wgs.closeErr}
break
}
// retry; may have dropped stream from no ctxs
@ -307,7 +349,7 @@ func (w *watcher) Watch(ctx context.Context, key string, opts ...OpOption) Watch
case <-ctx.Done():
case <-donec:
if wgs.closeErr != nil {
closeCh <- WatchResponse{closeErr: wgs.closeErr}
closeCh <- WatchResponse{Canceled: true, closeErr: wgs.closeErr}
break
}
// retry; may have dropped stream from no ctxs
@ -329,9 +371,50 @@ func (w *watcher) Close() (err error) {
err = werr
}
}
// Consider context.Canceled as a successful close
if err == context.Canceled {
err = nil
}
return err
}
// RequestProgress requests a progress notify response be sent in all watch channels.
func (w *watcher) RequestProgress(ctx context.Context) (err error) {
ctxKey := streamKeyFromCtx(ctx)
w.mu.Lock()
if w.streams == nil {
w.mu.Unlock()
return fmt.Errorf("no stream found for context")
}
wgs := w.streams[ctxKey]
if wgs == nil {
wgs = w.newWatcherGrpcStream(ctx)
w.streams[ctxKey] = wgs
}
donec := wgs.donec
reqc := wgs.reqc
w.mu.Unlock()
pr := &progressRequest{}
select {
case reqc <- pr:
return nil
case <-ctx.Done():
if err == nil {
return ctx.Err()
}
return err
case <-donec:
if wgs.closeErr != nil {
return wgs.closeErr
}
// retry; may have dropped stream from no ctxs
return w.RequestProgress(ctx)
}
}
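// A small sketch of requesting a progress notification; the timeout is
// arbitrary and "cli" is assumed to be an established *Client.
func requestProgressExample(cli *Client) error {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	// the resulting progress notify responses arrive on open watch channels
	// and can be detected with WatchResponse.IsProgressNotify()
	return cli.RequestProgress(ctx)
}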
func (w *watchGrpcStream) close() (err error) {
w.cancel()
<-w.donec
@ -353,7 +436,9 @@ func (w *watcher) closeStream(wgs *watchGrpcStream) {
}
func (w *watchGrpcStream) addSubstream(resp *pb.WatchResponse, ws *watcherStream) {
if resp.WatchId == -1 {
// check watch ID for backward compatibility (<= v3.3)
if resp.WatchId == -1 || (resp.Canceled && resp.CancelReason != "") {
w.closeErr = v3rpc.Error(errors.New(resp.CancelReason))
// failed; no channel
close(ws.recvc)
return
@ -379,7 +464,7 @@ func (w *watchGrpcStream) closeSubstream(ws *watcherStream) {
}
// close subscriber's channel
if closeErr := w.closeErr; closeErr != nil && ws.initReq.ctx.Err() == nil {
go w.sendCloseSubstream(ws, &WatchResponse{closeErr: w.closeErr})
go w.sendCloseSubstream(ws, &WatchResponse{Canceled: true, closeErr: w.closeErr})
} else if ws.outc != nil {
close(ws.outc)
}
@ -434,31 +519,48 @@ func (w *watchGrpcStream) run() {
cancelSet := make(map[int64]struct{})
var cur *pb.WatchResponse
for {
select {
// Watch() requested
case wreq := <-w.reqc:
outc := make(chan WatchResponse, 1)
ws := &watcherStream{
initReq: *wreq,
id: -1,
outc: outc,
// unbuffered so resumes won't cause repeat events
recvc: make(chan *WatchResponse),
case req := <-w.reqc:
switch wreq := req.(type) {
case *watchRequest:
outc := make(chan WatchResponse, 1)
// TODO: pass custom watch ID?
ws := &watcherStream{
initReq: *wreq,
id: -1,
outc: outc,
// unbuffered so resumes won't cause repeat events
recvc: make(chan *WatchResponse),
}
ws.donec = make(chan struct{})
w.wg.Add(1)
go w.serveSubstream(ws, w.resumec)
// queue up for watcher creation/resume
w.resuming = append(w.resuming, ws)
if len(w.resuming) == 1 {
// head of resume queue, can register a new watcher
wc.Send(ws.initReq.toPB())
}
case *progressRequest:
wc.Send(wreq.toPB())
}
ws.donec = make(chan struct{})
w.wg.Add(1)
go w.serveSubstream(ws, w.resumec)
// queue up for watcher creation/resume
w.resuming = append(w.resuming, ws)
if len(w.resuming) == 1 {
// head of resume queue, can register a new watcher
wc.Send(ws.initReq.toPB())
}
// New events from the watch client
// new events from the watch client
case pbresp := <-w.respc:
if cur == nil || pbresp.Created || pbresp.Canceled {
cur = pbresp
} else if cur != nil && cur.WatchId == pbresp.WatchId {
// merge new events
cur.Events = append(cur.Events, pbresp.Events...)
// update "Fragment" field; last response with "Fragment" == false
cur.Fragment = pbresp.Fragment
}
switch {
case pbresp.Created:
// response to head of queue creation
@ -467,9 +569,14 @@ func (w *watchGrpcStream) run() {
w.dispatchEvent(pbresp)
w.resuming[0] = nil
}
if ws := w.nextResume(); ws != nil {
wc.Send(ws.initReq.toPB())
}
// reset for next iteration
cur = nil
case pbresp.Canceled && pbresp.CompactRevision == 0:
delete(cancelSet, pbresp.WatchId)
if ws, ok := w.substreams[pbresp.WatchId]; ok {
@ -477,15 +584,31 @@ func (w *watchGrpcStream) run() {
close(ws.recvc)
closing[ws] = struct{}{}
}
// reset for next iteration
cur = nil
case cur.Fragment:
// watch response events are still fragmented
// continue to fetch next fragmented event arrival
continue
default:
// dispatch to appropriate watch stream
if ok := w.dispatchEvent(pbresp); ok {
ok := w.dispatchEvent(cur)
// reset for next iteration
cur = nil
if ok {
break
}
// watch response on unexpected watch id; cancel id
if _, ok := cancelSet[pbresp.WatchId]; ok {
break
}
cancelSet[pbresp.WatchId] = struct{}{}
cr := &pb.WatchRequest_CancelRequest{
CancelRequest: &pb.WatchCancelRequest{
@ -495,6 +618,7 @@ func (w *watchGrpcStream) run() {
req := &pb.WatchRequest{RequestUnion: cr}
wc.Send(req)
}
// watch client failed on Recv; spawn another if possible
case err := <-w.errc:
if isHaltErr(w.ctx, err) || toErr(w.ctx, err) == v3rpc.ErrNoLeader {
@ -508,13 +632,15 @@ func (w *watchGrpcStream) run() {
wc.Send(ws.initReq.toPB())
}
cancelSet = make(map[int64]struct{})
case <-w.ctx.Done():
return
case ws := <-w.closingc:
w.closeSubstream(ws)
delete(closing, ws)
// no more watchers on this stream, shutdown
if len(w.substreams)+len(w.resuming) == 0 {
// no more watchers on this stream, shutdown
return
}
}
@ -539,6 +665,7 @@ func (w *watchGrpcStream) dispatchEvent(pbresp *pb.WatchResponse) bool {
for i, ev := range pbresp.Events {
events[i] = (*Event)(ev)
}
// TODO: return watch ID?
wr := &WatchResponse{
Header: *pbresp.Header,
Events: events,
@ -547,7 +674,31 @@ func (w *watchGrpcStream) dispatchEvent(pbresp *pb.WatchResponse) bool {
Canceled: pbresp.Canceled,
cancelReason: pbresp.CancelReason,
}
ws, ok := w.substreams[pbresp.WatchId]
// watch IDs are zero indexed, so request notify watch responses are assigned a watch ID of -1 to
// indicate they should be broadcast.
if wr.IsProgressNotify() && pbresp.WatchId == -1 {
return w.broadcastResponse(wr)
}
return w.unicastResponse(wr, pbresp.WatchId)
}
// broadcastResponse sends a watch response to all watch substreams.
func (w *watchGrpcStream) broadcastResponse(wr *WatchResponse) bool {
for _, ws := range w.substreams {
select {
case ws.recvc <- wr:
case <-ws.donec:
}
}
return true
}
// unicastResponse sends a watch response to a specific watch substream.
func (w *watchGrpcStream) unicastResponse(wr *WatchResponse, watchId int64) bool {
ws, ok := w.substreams[watchId]
if !ok {
return false
}
@ -815,11 +966,19 @@ func (wr *watchRequest) toPB() *pb.WatchRequest {
ProgressNotify: wr.progressNotify,
Filters: wr.filters,
PrevKv: wr.prevKV,
Fragment: wr.fragment,
}
cr := &pb.WatchRequest_CreateRequest{CreateRequest: req}
return &pb.WatchRequest{RequestUnion: cr}
}
// toPB converts an internal progress request structure to its protobuf WatchRequest structure.
func (pr *progressRequest) toPB() *pb.WatchRequest {
req := &pb.WatchProgressRequest{}
cr := &pb.WatchRequest_ProgressRequest{ProgressRequest: req}
return &pb.WatchRequest{RequestUnion: cr}
}
func streamKeyFromCtx(ctx context.Context) string {
if md, ok := metadata.FromOutgoingContext(ctx); ok {
return fmt.Sprintf("%+v", md)