diff --git a/client/http.go b/client/http.go index a24103f5d..738f5ea6d 100644 --- a/client/http.go +++ b/client/http.go @@ -36,13 +36,23 @@ var ( DefaultMaxRedirects = 10 ) +func defaultHTTPClientFactory(tr CancelableTransport, ep url.URL) HTTPClient { + return &redirectFollowingHTTPClient{ + max: DefaultMaxRedirects, + client: &httpClient{ + transport: tr, + endpoint: ep, + }, + } +} + type ClientConfig struct { Endpoints []string Transport CancelableTransport } func New(cfg ClientConfig) (SyncableHTTPClient, error) { - return newHTTPClusterClient(cfg.Transport, cfg.Endpoints) + return newHTTPClusterClient(cfg.Transport, cfg.Endpoints, defaultHTTPClientFactory) } type SyncableHTTPClient interface { @@ -55,6 +65,8 @@ type HTTPClient interface { Do(context.Context, HTTPAction) (*http.Response, []byte, error) } +type httpClientFactory func(CancelableTransport, url.URL) HTTPClient + type HTTPAction interface { HTTPRequest(url.URL) *http.Request } @@ -67,8 +79,8 @@ type CancelableTransport interface { CancelRequest(req *http.Request) } -func newHTTPClusterClient(tr CancelableTransport, eps []string) (*httpClusterClient, error) { - c := &httpClusterClient{} +func newHTTPClusterClient(tr CancelableTransport, eps []string, cf httpClientFactory) (*httpClusterClient, error) { + c := &httpClusterClient{clientFactory: cf} if err := c.reset(tr, eps); err != nil { return nil, err } @@ -76,37 +88,27 @@ func newHTTPClusterClient(tr CancelableTransport, eps []string) (*httpClusterCli } type httpClusterClient struct { - transport CancelableTransport - endpoints []string - clients []HTTPClient + clientFactory httpClientFactory + transport CancelableTransport + endpoints []url.URL sync.RWMutex } func (c *httpClusterClient) reset(tr CancelableTransport, eps []string) error { - le := len(eps) - ne := make([]string, le) - if copy(ne, eps) != le { - return errors.New("copy call failed") + if len(eps) == 0 { + return ErrNoEndpoints } - nc := make([]HTTPClient, len(ne)) - for i, e := range ne { - u, err := url.Parse(e) + neps := make([]url.URL, len(eps)) + for i, ep := range eps { + u, err := url.Parse(ep) if err != nil { return err } - - nc[i] = &redirectFollowingHTTPClient{ - max: DefaultMaxRedirects, - client: &httpClient{ - transport: tr, - endpoint: *u, - }, - } + neps[i] = *u } - c.endpoints = ne - c.clients = nc + c.endpoints = neps c.transport = tr return nil @@ -114,12 +116,24 @@ func (c *httpClusterClient) reset(tr CancelableTransport, eps []string) error { func (c *httpClusterClient) Do(ctx context.Context, act HTTPAction) (resp *http.Response, body []byte, err error) { c.RLock() - defer c.RUnlock() + leps := len(c.endpoints) + eps := make([]url.URL, leps) + n := copy(eps, c.endpoints) + tr := c.transport + c.RUnlock() - if len(c.clients) == 0 { - return nil, nil, ErrNoEndpoints + if leps == 0 { + err = ErrNoEndpoints + return } - for _, hc := range c.clients { + + if leps != n { + err = errors.New("unable to pick endpoint: copy failed") + return + } + + for _, ep := range eps { + hc := c.clientFactory(tr, ep) resp, body, err = hc.Do(ctx, act) if err != nil { if err == ErrTimeout || err == ErrCanceled { @@ -132,13 +146,20 @@ func (c *httpClusterClient) Do(ctx context.Context, act HTTPAction) (resp *http. } break } + return } func (c *httpClusterClient) Endpoints() []string { c.RLock() defer c.RUnlock() - return c.endpoints + + eps := make([]string, len(c.endpoints)) + for i, ep := range c.endpoints { + eps[i] = ep.String() + } + + return eps } func (c *httpClusterClient) Sync(ctx context.Context) error { @@ -155,9 +176,6 @@ func (c *httpClusterClient) Sync(ctx context.Context) error { for _, m := range ms { eps = append(eps, m.ClientURLs...) } - if len(eps) == 0 { - return ErrNoEndpoints - } return c.reset(c.transport, eps) } diff --git a/client/http_test.go b/client/http_test.go index bbe27c569..45c27cdbc 100644 --- a/client/http_test.go +++ b/client/http_test.go @@ -60,6 +60,15 @@ func (s *multiStaticHTTPClient) Do(context.Context, HTTPAction) (*http.Response, return &r.resp, nil, r.err } +func newStaticHTTPClientFactory(responses []staticHTTPResponse) httpClientFactory { + var cur int + return func(CancelableTransport, url.URL) HTTPClient { + r := responses[cur] + cur++ + return &staticHTTPClient{resp: r.resp, err: r.err} + } +} + type fakeTransport struct { respchan chan *http.Response errchan chan error @@ -183,6 +192,7 @@ func TestHTTPClientDoCancelContextWaitForRoundTrip(t *testing.T) { func TestHTTPClusterClientDo(t *testing.T) { fakeErr := errors.New("fake!") + fakeURL := url.URL{} tests := []struct { client *httpClusterClient wantCode int @@ -191,10 +201,13 @@ func TestHTTPClusterClientDo(t *testing.T) { // first good response short-circuits Do { client: &httpClusterClient{ - clients: []HTTPClient{ - &staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}}, - &staticHTTPClient{err: fakeErr}, - }, + endpoints: []url.URL{fakeURL, fakeURL}, + clientFactory: newStaticHTTPClientFactory( + []staticHTTPResponse{ + staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}}, + staticHTTPResponse{err: fakeErr}, + }, + ), }, wantCode: http.StatusTeapot, }, @@ -202,10 +215,13 @@ func TestHTTPClusterClientDo(t *testing.T) { // fall through to good endpoint if err is arbitrary { client: &httpClusterClient{ - clients: []HTTPClient{ - &staticHTTPClient{err: fakeErr}, - &staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}}, - }, + endpoints: []url.URL{fakeURL, fakeURL}, + clientFactory: newStaticHTTPClientFactory( + []staticHTTPResponse{ + staticHTTPResponse{err: fakeErr}, + staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}}, + }, + ), }, wantCode: http.StatusTeapot, }, @@ -213,10 +229,13 @@ func TestHTTPClusterClientDo(t *testing.T) { // ErrTimeout short-circuits Do { client: &httpClusterClient{ - clients: []HTTPClient{ - &staticHTTPClient{err: ErrTimeout}, - &staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}}, - }, + endpoints: []url.URL{fakeURL, fakeURL}, + clientFactory: newStaticHTTPClientFactory( + []staticHTTPResponse{ + staticHTTPResponse{err: ErrTimeout}, + staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}}, + }, + ), }, wantErr: ErrTimeout, }, @@ -224,10 +243,13 @@ func TestHTTPClusterClientDo(t *testing.T) { // ErrCanceled short-circuits Do { client: &httpClusterClient{ - clients: []HTTPClient{ - &staticHTTPClient{err: ErrCanceled}, - &staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}}, - }, + endpoints: []url.URL{fakeURL, fakeURL}, + clientFactory: newStaticHTTPClientFactory( + []staticHTTPResponse{ + staticHTTPResponse{err: ErrCanceled}, + staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}}, + }, + ), }, wantErr: ErrCanceled, }, @@ -235,7 +257,8 @@ func TestHTTPClusterClientDo(t *testing.T) { // return err if there are no endpoints { client: &httpClusterClient{ - clients: []HTTPClient{}, + endpoints: []url.URL{}, + clientFactory: defaultHTTPClientFactory, }, wantErr: ErrNoEndpoints, }, @@ -243,10 +266,13 @@ func TestHTTPClusterClientDo(t *testing.T) { // return err if all endpoints return arbitrary errors { client: &httpClusterClient{ - clients: []HTTPClient{ - &staticHTTPClient{err: fakeErr}, - &staticHTTPClient{err: fakeErr}, - }, + endpoints: []url.URL{fakeURL, fakeURL}, + clientFactory: newStaticHTTPClientFactory( + []staticHTTPResponse{ + staticHTTPResponse{err: fakeErr}, + staticHTTPResponse{err: fakeErr}, + }, + ), }, wantErr: fakeErr, }, @@ -254,10 +280,13 @@ func TestHTTPClusterClientDo(t *testing.T) { // 500-level errors cause Do to fallthrough to next endpoint { client: &httpClusterClient{ - clients: []HTTPClient{ - &staticHTTPClient{resp: http.Response{StatusCode: http.StatusBadGateway}}, - &staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}}, - }, + endpoints: []url.URL{fakeURL, fakeURL}, + clientFactory: newStaticHTTPClientFactory( + []staticHTTPResponse{ + staticHTTPResponse{resp: http.Response{StatusCode: http.StatusBadGateway}}, + staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}}, + }, + ), }, wantCode: http.StatusTeapot, },