diff --git a/client/client.go b/client/client.go index 86cf51bed..fe41253c2 100644 --- a/client/client.go +++ b/client/client.go @@ -18,6 +18,7 @@ import ( "errors" "fmt" "io/ioutil" + "math/rand" "net" "net/http" "net/url" @@ -131,6 +132,7 @@ type Client interface { func New(cfg Config) (Client, error) { c := &httpClusterClient{ clientFactory: newHTTPClientFactory(cfg.transport(), cfg.checkRedirect()), + rand: rand.New(rand.NewSource(int64(time.Now().Nanosecond()))), } if cfg.Username != "" { c.credentials = &credentials{ @@ -174,8 +176,10 @@ type httpAction interface { type httpClusterClient struct { clientFactory httpClientFactory endpoints []url.URL + pinned int credentials *credentials sync.RWMutex + rand *rand.Rand } func (c *httpClusterClient) reset(eps []string) error { @@ -192,7 +196,9 @@ func (c *httpClusterClient) reset(eps []string) error { neps[i] = *u } - c.endpoints = neps + c.endpoints = shuffleEndpoints(c.rand, neps) + // TODO: pin old endpoint if possible, and rebalance when new endpoint appears + c.pinned = 0 return nil } @@ -203,6 +209,7 @@ func (c *httpClusterClient) Do(ctx context.Context, act httpAction) (*http.Respo leps := len(c.endpoints) eps := make([]url.URL, leps) n := copy(eps, c.endpoints) + pinned := c.pinned if c.credentials != nil { action = &authedAction{ @@ -224,8 +231,9 @@ func (c *httpClusterClient) Do(ctx context.Context, act httpAction) (*http.Respo var body []byte var err error - for _, ep := range eps { - hc := c.clientFactory(ep) + for i := pinned; i < leps+pinned; i++ { + k := i % leps + hc := c.clientFactory(eps[k]) resp, body, err = hc.Do(ctx, action) if err != nil { if err == context.DeadlineExceeded || err == context.Canceled { @@ -236,6 +244,11 @@ func (c *httpClusterClient) Do(ctx context.Context, act httpAction) (*http.Respo if resp.StatusCode/100 == 5 { continue } + if k != pinned { + c.Lock() + c.pinned = k + c.Unlock() + } break } @@ -401,3 +414,12 @@ func (r *redirectedHTTPAction) HTTPRequest(ep url.URL) *http.Request { orig.URL = &r.location return orig } + +func shuffleEndpoints(r *rand.Rand, eps []url.URL) []url.URL { + p := r.Perm(len(eps)) + neps := make([]url.URL, len(eps)) + for i, k := range p { + neps[i] = eps[k] + } + return neps +} diff --git a/client/client_test.go b/client/client_test.go index faa78ad67..c340b9f00 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -18,9 +18,11 @@ import ( "errors" "io" "io/ioutil" + "math/rand" "net/http" "net/url" "reflect" + "sort" "strings" "testing" "time" @@ -299,9 +301,10 @@ func TestHTTPClusterClientDo(t *testing.T) { fakeErr := errors.New("fake!") fakeURL := url.URL{} tests := []struct { - client *httpClusterClient - wantCode int - wantErr error + client *httpClusterClient + wantCode int + wantErr error + wantPinned int }{ // first good response short-circuits Do { @@ -313,6 +316,7 @@ func TestHTTPClusterClientDo(t *testing.T) { staticHTTPResponse{err: fakeErr}, }, ), + rand: rand.New(rand.NewSource(0)), }, wantCode: http.StatusTeapot, }, @@ -327,8 +331,10 @@ func TestHTTPClusterClientDo(t *testing.T) { staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}}, }, ), + rand: rand.New(rand.NewSource(0)), }, - wantCode: http.StatusTeapot, + wantCode: http.StatusTeapot, + wantPinned: 1, }, // context.DeadlineExceeded short-circuits Do @@ -341,6 +347,7 @@ func TestHTTPClusterClientDo(t *testing.T) { staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}}, }, ), + rand: rand.New(rand.NewSource(0)), }, wantErr: context.DeadlineExceeded, }, @@ -355,6 +362,7 @@ func TestHTTPClusterClientDo(t *testing.T) { staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}}, }, ), + rand: rand.New(rand.NewSource(0)), }, wantErr: context.Canceled, }, @@ -364,6 +372,7 @@ func TestHTTPClusterClientDo(t *testing.T) { client: &httpClusterClient{ endpoints: []url.URL{}, clientFactory: newHTTPClientFactory(nil, nil), + rand: rand.New(rand.NewSource(0)), }, wantErr: ErrNoEndpoints, }, @@ -378,6 +387,7 @@ func TestHTTPClusterClientDo(t *testing.T) { staticHTTPResponse{err: fakeErr}, }, ), + rand: rand.New(rand.NewSource(0)), }, wantErr: fakeErr, }, @@ -392,8 +402,10 @@ func TestHTTPClusterClientDo(t *testing.T) { staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}}, }, ), + rand: rand.New(rand.NewSource(0)), }, - wantCode: http.StatusTeapot, + wantCode: http.StatusTeapot, + wantPinned: 1, }, } @@ -415,6 +427,10 @@ func TestHTTPClusterClientDo(t *testing.T) { t.Errorf("#%d: resp code=%d, want=%d", i, resp.StatusCode, tt.wantCode) continue } + + if tt.client.pinned != tt.wantPinned { + t.Errorf("#%d: pinned=%d, want=%d", i, tt.client.pinned, tt.wantPinned) + } } } @@ -671,7 +687,10 @@ func TestHTTPClusterClientSync(t *testing.T) { }, }) - hc := &httpClusterClient{clientFactory: cf} + hc := &httpClusterClient{ + clientFactory: cf, + rand: rand.New(rand.NewSource(0)), + } err := hc.reset([]string{"http://127.0.0.1:2379"}) if err != nil { t.Fatalf("unexpected error during setup: %#v", err) @@ -688,8 +707,9 @@ func TestHTTPClusterClientSync(t *testing.T) { t.Fatalf("unexpected error during Sync: %#v", err) } - want = []string{"http://127.0.0.1:4003", "http://127.0.0.1:2379", "http://127.0.0.1:4001", "http://127.0.0.1:4002"} + want = []string{"http://127.0.0.1:2379", "http://127.0.0.1:4001", "http://127.0.0.1:4002", "http://127.0.0.1:4003"} got = hc.Endpoints() + sort.Sort(sort.StringSlice(got)) if !reflect.DeepEqual(want, got) { t.Fatalf("incorrect endpoints post-Sync: want=%#v got=%#v", want, got) } @@ -711,7 +731,10 @@ func TestHTTPClusterClientSyncFail(t *testing.T) { staticHTTPResponse{err: errors.New("fail!")}, }) - hc := &httpClusterClient{clientFactory: cf} + hc := &httpClusterClient{ + clientFactory: cf, + rand: rand.New(rand.NewSource(0)), + } err := hc.reset([]string{"http://127.0.0.1:2379"}) if err != nil { t.Fatalf("unexpected error during setup: %#v", err) @@ -744,10 +767,31 @@ func TestHTTPClusterClientResetFail(t *testing.T) { } for i, tt := range tests { - hc := &httpClusterClient{} + hc := &httpClusterClient{rand: rand.New(rand.NewSource(0))} err := hc.reset(tt) if err == nil { t.Errorf("#%d: expected non-nil error", i) } } } + +func TestHTTPClusterClientResetPinRandom(t *testing.T) { + round := 2000 + pinNum := 0 + for i := 0; i < round; i++ { + hc := &httpClusterClient{rand: rand.New(rand.NewSource(int64(i)))} + err := hc.reset([]string{"http://127.0.0.1:4001", "http://127.0.0.1:4002", "http://127.0.0.1:4003"}) + if err != nil { + t.Fatalf("#%d: reset error (%v)", i, err) + } + if hc.endpoints[hc.pinned].String() == "http://127.0.0.1:4001" { + pinNum++ + } + } + + min := 1.0/3.0 - 0.05 + max := 1.0/3.0 + 0.05 + if ratio := float64(pinNum) / float64(round); ratio > max || ratio < min { + t.Errorf("pinned ratio = %v, want [%v, %v]", ratio, min, max) + } +}