From da6827f09e870b209473192e9c7e36cbb287116d Mon Sep 17 00:00:00 2001 From: Brian Waldon Date: Fri, 31 Oct 2014 19:56:48 -0700 Subject: [PATCH] client: use all endpoints --- client/http.go | 22 ++++++++-- client/http_test.go | 104 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 122 insertions(+), 4 deletions(-) diff --git a/client/http.go b/client/http.go index a5ae7e56f..575c4dfb6 100644 --- a/client/http.go +++ b/client/http.go @@ -26,7 +26,9 @@ import ( ) var ( - ErrTimeout = context.DeadlineExceeded + ErrTimeout = context.DeadlineExceeded + ErrCanceled = context.Canceled + DefaultRequestTimeout = 5 * time.Second ) @@ -81,9 +83,21 @@ type httpClusterClient struct { endpoints []HTTPClient } -func (c *httpClusterClient) Do(ctx context.Context, act HTTPAction) (*http.Response, []byte, error) { - //TODO(bcwaldon): introduce retry logic so all endpoints are attempted - return c.endpoints[0].Do(ctx, act) +func (c *httpClusterClient) Do(ctx context.Context, act HTTPAction) (resp *http.Response, body []byte, err error) { + for _, hc := range c.endpoints { + resp, body, err = hc.Do(ctx, act) + if err != nil { + if err == ErrTimeout || err == ErrCanceled { + return nil, nil, err + } + continue + } + if resp.StatusCode/100 == 5 { + continue + } + break + } + return } func (c *httpClusterClient) Sync(ctx context.Context) error { diff --git a/client/http_test.go b/client/http_test.go index 33062b51e..99bbae8b2 100644 --- a/client/http_test.go +++ b/client/http_test.go @@ -29,6 +29,15 @@ import ( "github.com/coreos/etcd/Godeps/_workspace/src/code.google.com/p/go.net/context" ) +type staticHTTPClient struct { + resp http.Response + err error +} + +func (s *staticHTTPClient) Do(context.Context, HTTPAction) (*http.Response, []byte, error) { + return &s.resp, nil, s.err +} + type fakeTransport struct { respchan chan *http.Response errchan chan error @@ -149,3 +158,98 @@ func TestHTTPClientDoCancelContextWaitForRoundTrip(t *testing.T) { t.Fatalf("httpClient.do did not exit within 1s") } } + +func TestHTTPClusterClientDo(t *testing.T) { + fakeErr := errors.New("fake!") + tests := []struct { + client *httpClusterClient + wantCode int + wantErr error + }{ + // first good response short-circuits Do + { + client: &httpClusterClient{ + endpoints: []HTTPClient{ + &staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}}, + &staticHTTPClient{err: fakeErr}, + }, + }, + wantCode: http.StatusTeapot, + }, + + // fall through to good endpoint if err is arbitrary + { + client: &httpClusterClient{ + endpoints: []HTTPClient{ + &staticHTTPClient{err: fakeErr}, + &staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}}, + }, + }, + wantCode: http.StatusTeapot, + }, + + // ErrTimeout short-circuits Do + { + client: &httpClusterClient{ + endpoints: []HTTPClient{ + &staticHTTPClient{err: ErrTimeout}, + &staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}}, + }, + }, + wantErr: ErrTimeout, + }, + + // ErrCanceled short-circuits Do + { + client: &httpClusterClient{ + endpoints: []HTTPClient{ + &staticHTTPClient{err: ErrCanceled}, + &staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}}, + }, + }, + wantErr: ErrCanceled, + }, + + // return err if all endpoints return arbitrary errors + { + client: &httpClusterClient{ + endpoints: []HTTPClient{ + &staticHTTPClient{err: fakeErr}, + &staticHTTPClient{err: fakeErr}, + }, + }, + wantErr: fakeErr, + }, + + // 500-level errors cause Do to fallthrough to next endpoint + { + client: &httpClusterClient{ + endpoints: []HTTPClient{ + &staticHTTPClient{resp: http.Response{StatusCode: http.StatusBadGateway}}, + &staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}}, + }, + }, + wantCode: http.StatusTeapot, + }, + } + + for i, tt := range tests { + resp, _, err := tt.client.Do(context.Background(), nil) + if !reflect.DeepEqual(tt.wantErr, err) { + t.Errorf("#%d: got err=%v, want=%v", i, err, tt.wantErr) + continue + } + + if resp == nil { + if tt.wantCode != 0 { + t.Errorf("#%d: resp is nil, want=%d", i, tt.wantCode) + } + continue + } + + if resp.StatusCode != tt.wantCode { + t.Errorf("#%d: resp code=%d, want=%d", i, resp.StatusCode, tt.wantCode) + continue + } + } +}