diff --git a/client/v3/client.go b/client/v3/client.go index 85ffd0637..23069468b 100644 --- a/client/v3/client.go +++ b/client/v3/client.go @@ -507,7 +507,7 @@ func (c *Client) checkVersion() (err error) { return } } - if maj < 3 || (maj == 3 && min < 2) { + if maj < 3 || (maj == 3 && min < 4) { rerr = ErrOldCluster } errc <- rerr @@ -515,7 +515,7 @@ func (c *Client) checkVersion() (err error) { } // wait for success for range eps { - if err = <-errc; err == nil { + if err = <-errc; err != nil { break } } diff --git a/client/v3/client_test.go b/client/v3/client_test.go index b26aa999d..00ab263ea 100644 --- a/client/v3/client_test.go +++ b/client/v3/client_test.go @@ -17,7 +17,9 @@ package clientv3 import ( "context" "fmt" + "io" "net" + "sync" "testing" "time" @@ -266,6 +268,108 @@ func TestSyncFiltersMembers(t *testing.T) { } } +func TestClientRejectOldCluster(t *testing.T) { + testutil.BeforeTest(t) + var tests = []struct { + name string + endpoints []string + versions []string + expectedError error + }{ + { + name: "all new versions with the same value", + endpoints: []string{"192.168.3.41:22379", "192.168.3.41:22479", "192.168.3.41:22579"}, + versions: []string{"3.5.4", "3.5.4", "3.5.4"}, + expectedError: nil, + }, + { + name: "all new versions with different values", + endpoints: []string{"192.168.3.41:22379", "192.168.3.41:22479", "192.168.3.41:22579"}, + versions: []string{"3.5.4", "3.5.4", "3.4.0"}, + expectedError: nil, + }, + { + name: "all old versions with different values", + endpoints: []string{"192.168.3.41:22379", "192.168.3.41:22479", "192.168.3.41:22579"}, + versions: []string{"3.3.0", "3.3.0", "3.4.0"}, + expectedError: ErrOldCluster, + }, + { + name: "all old versions with the same value", + endpoints: []string{"192.168.3.41:22379", "192.168.3.41:22479", "192.168.3.41:22579"}, + versions: []string{"3.3.0", "3.3.0", "3.3.0"}, + expectedError: ErrOldCluster, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if len(tt.endpoints) != len(tt.versions) || len(tt.endpoints) == 0 { + t.Errorf("Unexpected endpoints and versions length, len(endpoints):%d, len(versions):%d", len(tt.endpoints), len(tt.versions)) + return + } + endpointToVersion := make(map[string]string) + for j := range tt.endpoints { + endpointToVersion[tt.endpoints[j]] = tt.versions[j] + } + c := &Client{ + ctx: context.Background(), + endpoints: tt.endpoints, + epMu: new(sync.RWMutex), + Maintenance: &mockMaintenance{ + Version: endpointToVersion, + }, + } + + if err := c.checkVersion(); err != tt.expectedError { + t.Errorf("heckVersion err:%v", err) + } + }) + + } + +} + +type mockMaintenance struct { + Version map[string]string +} + +func (mm mockMaintenance) Status(ctx context.Context, endpoint string) (*StatusResponse, error) { + return &StatusResponse{Version: mm.Version[endpoint]}, nil +} + +func (mm mockMaintenance) AlarmList(ctx context.Context) (*AlarmResponse, error) { + return nil, nil +} + +func (mm mockMaintenance) AlarmDisarm(ctx context.Context, m *AlarmMember) (*AlarmResponse, error) { + return nil, nil +} + +func (mm mockMaintenance) Defragment(ctx context.Context, endpoint string) (*DefragmentResponse, error) { + return nil, nil +} + +func (mm mockMaintenance) HashKV(ctx context.Context, endpoint string, rev int64) (*HashKVResponse, error) { + return nil, nil +} + +func (mm mockMaintenance) SnapshotWithVersion(ctx context.Context) (*SnapshotResponse, error) { + return nil, nil +} + +func (mm mockMaintenance) Snapshot(ctx context.Context) (io.ReadCloser, error) { + return nil, nil +} + +func (mm mockMaintenance) MoveLeader(ctx context.Context, transfereeID uint64) (*MoveLeaderResponse, error) { + return nil, nil +} + +func (mm mockMaintenance) Downgrade(ctx context.Context, action DowngradeAction, version string) (*DowngradeResponse, error) { + return nil, nil +} + type mockAuthServer struct { *etcdserverpb.UnimplementedAuthServer }