From 9b334e07a613c6baa7265bec8f03bb1022e0a60f Mon Sep 17 00:00:00 2001 From: Brian Waldon Date: Wed, 28 Jan 2015 15:09:00 -0800 Subject: [PATCH] client: allow caller to decide HTTP redirect policy --- client/client.go | 47 +++++++++++++++++++++++++++++++++++-------- client/client_test.go | 43 ++++++++++++++++++++++++++------------- 2 files changed, 68 insertions(+), 22 deletions(-) diff --git a/client/client.go b/client/client.go index 73780a290..3c4368488 100644 --- a/client/client.go +++ b/client/client.go @@ -39,7 +39,6 @@ var ( ErrKeyExists = errors.New("client: key already exists") DefaultRequestTimeout = 5 * time.Second - DefaultMaxRedirects = 10 ) var DefaultTransport CancelableTransport = &http.Transport{ @@ -72,6 +71,17 @@ type Config struct { // Transport is used by the Client to drive HTTP requests. If not // provided, DefaultTransport will be used. Transport CancelableTransport + + // CheckRedirect specifies the policy for handling HTTP redirects. + // If CheckRedirect is not nil, the Client calls it before + // following an HTTP redirect. The sole argument is the number of + // requests that have alrady been made. If CheckRedirect returns + // an error, Client.Do will not make any further requests and return + // the error back it to the caller. + // + // If CheckRedirect is nil, the Client uses its default policy, + // which is to stop after 10 consecutive requests. + CheckRedirect CheckRedirectFunc } func (cfg *Config) transport() CancelableTransport { @@ -81,6 +91,13 @@ func (cfg *Config) transport() CancelableTransport { return cfg.Transport } +func (cfg *Config) checkRedirect() CheckRedirectFunc { + if cfg.CheckRedirect == nil { + return DefaultCheckRedirect + } + return cfg.CheckRedirect +} + // CancelableTransport mimics net/http.Transport, but requires that // the object also support request cancellation. type CancelableTransport interface { @@ -88,6 +105,16 @@ type CancelableTransport interface { CancelRequest(req *http.Request) } +type CheckRedirectFunc func(via int) error + +// DefaultCheckRedirect follows up to 10 redirects, but no more. +var DefaultCheckRedirect CheckRedirectFunc = func(via int) error { + if via > 10 { + return ErrTooManyRedirects + } + return nil +} + type Client interface { // Sync updates the internal cache of the etcd cluster's membership. Sync(context.Context) error @@ -101,7 +128,7 @@ type Client interface { } func New(cfg Config) (Client, error) { - c := &httpClusterClient{clientFactory: newHTTPClientFactory(cfg.transport())} + c := &httpClusterClient{clientFactory: newHTTPClientFactory(cfg.transport(), cfg.checkRedirect())} if err := c.reset(cfg.Endpoints); err != nil { return nil, err } @@ -112,10 +139,10 @@ type httpClient interface { Do(context.Context, httpAction) (*http.Response, []byte, error) } -func newHTTPClientFactory(tr CancelableTransport) httpClientFactory { +func newHTTPClientFactory(tr CancelableTransport, cr CheckRedirectFunc) httpClientFactory { return func(ep url.URL) httpClient { return &redirectFollowingHTTPClient{ - max: DefaultMaxRedirects, + checkRedirect: cr, client: &simpleHTTPClient{ transport: tr, endpoint: ep, @@ -270,12 +297,17 @@ func (c *simpleHTTPClient) Do(ctx context.Context, act httpAction) (*http.Respon } type redirectFollowingHTTPClient struct { - client httpClient - max int + client httpClient + checkRedirect CheckRedirectFunc } func (r *redirectFollowingHTTPClient) Do(ctx context.Context, act httpAction) (*http.Response, []byte, error) { - for i := 0; i <= r.max; i++ { + for i := 0; ; i++ { + if i > 0 { + if err := r.checkRedirect(i); err != nil { + return nil, nil, err + } + } resp, body, err := r.client.Do(ctx, act) if err != nil { return nil, nil, err @@ -297,7 +329,6 @@ func (r *redirectFollowingHTTPClient) Do(ctx context.Context, act httpAction) (* } return resp, body, nil } - return nil, nil, ErrTooManyRedirects } type redirectedHTTPAction struct { diff --git a/client/client_test.go b/client/client_test.go index ab6be99ef..f57ffb94c 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -258,7 +258,7 @@ func TestHTTPClusterClientDo(t *testing.T) { { client: &httpClusterClient{ endpoints: []url.URL{}, - clientFactory: newHTTPClientFactory(nil), + clientFactory: newHTTPClientFactory(nil, nil), }, wantErr: ErrNoEndpoints, }, @@ -349,14 +349,14 @@ func TestRedirectedHTTPAction(t *testing.T) { func TestRedirectFollowingHTTPClient(t *testing.T) { tests := []struct { - max int - client httpClient - wantCode int - wantErr error + checkRedirect CheckRedirectFunc + client httpClient + wantCode int + wantErr error }{ // errors bubbled up { - max: 2, + checkRedirect: func(int) error { return ErrTooManyRedirects }, client: &multiStaticHTTPClient{ responses: []staticHTTPResponse{ staticHTTPResponse{ @@ -369,7 +369,7 @@ func TestRedirectFollowingHTTPClient(t *testing.T) { // no need to follow redirect if none given { - max: 2, + checkRedirect: func(int) error { return ErrTooManyRedirects }, client: &multiStaticHTTPClient{ responses: []staticHTTPResponse{ staticHTTPResponse{ @@ -384,7 +384,12 @@ func TestRedirectFollowingHTTPClient(t *testing.T) { // redirects if less than max { - max: 2, + checkRedirect: func(via int) error { + if via >= 2 { + return ErrTooManyRedirects + } + return nil + }, client: &multiStaticHTTPClient{ responses: []staticHTTPResponse{ staticHTTPResponse{ @@ -405,7 +410,12 @@ func TestRedirectFollowingHTTPClient(t *testing.T) { // succeed after reaching max redirects { - max: 2, + checkRedirect: func(via int) error { + if via >= 3 { + return ErrTooManyRedirects + } + return nil + }, client: &multiStaticHTTPClient{ responses: []staticHTTPResponse{ staticHTTPResponse{ @@ -430,9 +440,14 @@ func TestRedirectFollowingHTTPClient(t *testing.T) { wantCode: http.StatusTeapot, }, - // fail at max+1 redirects + // fail if too many redirects { - max: 1, + checkRedirect: func(via int) error { + if via >= 2 { + return ErrTooManyRedirects + } + return nil + }, client: &multiStaticHTTPClient{ responses: []staticHTTPResponse{ staticHTTPResponse{ @@ -459,7 +474,7 @@ func TestRedirectFollowingHTTPClient(t *testing.T) { // fail if Location header not set { - max: 1, + checkRedirect: func(int) error { return ErrTooManyRedirects }, client: &multiStaticHTTPClient{ responses: []staticHTTPResponse{ staticHTTPResponse{ @@ -474,7 +489,7 @@ func TestRedirectFollowingHTTPClient(t *testing.T) { // fail if Location header is invalid { - max: 1, + checkRedirect: func(int) error { return ErrTooManyRedirects }, client: &multiStaticHTTPClient{ responses: []staticHTTPResponse{ staticHTTPResponse{ @@ -490,7 +505,7 @@ func TestRedirectFollowingHTTPClient(t *testing.T) { } for i, tt := range tests { - client := &redirectFollowingHTTPClient{client: tt.client, max: tt.max} + client := &redirectFollowingHTTPClient{client: tt.client, checkRedirect: tt.checkRedirect} resp, _, err := client.Do(context.Background(), nil) if !reflect.DeepEqual(tt.wantErr, err) { t.Errorf("#%d: got err=%v, want=%v", i, err, tt.wantErr)