diff --git a/discovery/discovery_test.go b/discovery/discovery_test.go index 611346531..d50f20f81 100644 --- a/discovery/discovery_test.go +++ b/discovery/discovery_test.go @@ -108,7 +108,7 @@ func TestCheckCluster(t *testing.T) { c := &clientWithResp{rs: rs} d := discovery{cluster: cluster, id: 1, c: c} - cRetry := &clientWithRetry{} + cRetry := &clientWithRetry{failTimes: 2} cRetry.rs = rs dRetry := discovery{cluster: cluster, id: 1, c: cRetry, timeoutTimescale: time.Millisecond * 2} @@ -199,9 +199,10 @@ func TestWaitNodes(t *testing.T) { }, }) } - cRetry := &clientWithRetry{} - cRetry.rs = retryScanResp - cRetry.w = &watcherWithRetry{tt.rs, false} + cRetry := &clientWithResp{ + rs: retryScanResp, + w: &watcherWithRetry{rs: tt.rs, failTimes: 2}, + } dRetry := &discovery{ cluster: "1000", c: cRetry, @@ -312,7 +313,7 @@ func (c *clientWithResp) Get(key string) (*client.Response, error) { return &client.Response{}, client.ErrKeyNoExist } r := c.rs[0] - c.rs = c.rs[1:] + c.rs = append(c.rs[1:], r) return r, nil } @@ -369,12 +370,13 @@ func (w *watcherWithErr) Next() (*client.Response, error) { // Fails every other time type clientWithRetry struct { clientWithResp - haveFailed bool + failCount int + failTimes int } func (c *clientWithRetry) Create(key string, value string, ttl time.Duration) (*client.Response, error) { - if !c.haveFailed { - c.haveFailed = true + if c.failCount < c.failTimes { + c.failCount++ return nil, client.ErrTimeout } if len(c.rs) == 0 { @@ -386,26 +388,22 @@ func (c *clientWithRetry) Create(key string, value string, ttl time.Duration) (* } func (c *clientWithRetry) Get(key string) (*client.Response, error) { - if !c.haveFailed { - c.haveFailed = true + if c.failCount < c.failTimes { + c.failCount++ return nil, client.ErrTimeout } - if len(c.rs) == 0 { - return &client.Response{}, client.ErrKeyNoExist - } - r := c.rs[0] - c.rs = c.rs[1:] - return r, nil + return c.clientWithResp.Get(key) } type watcherWithRetry struct { - rs []*client.Response - haveFailed bool + rs []*client.Response + failCount int + failTimes int } func (w *watcherWithRetry) Next() (*client.Response, error) { - if !w.haveFailed { - w.haveFailed = true + if w.failCount < w.failTimes { + w.failCount++ return nil, client.ErrTimeout } if len(w.rs) == 0 {