From 6dd4944e621ea03f248361a1815c4368c00e1ade Mon Sep 17 00:00:00 2001 From: Brian Waldon Date: Mon, 3 Nov 2014 12:15:16 -0800 Subject: [PATCH] client: follow redirects --- client/http.go | 59 +++++++++++- client/http_test.go | 222 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 276 insertions(+), 5 deletions(-) diff --git a/client/http.go b/client/http.go index 575c4dfb6..b6ab415e1 100644 --- a/client/http.go +++ b/client/http.go @@ -17,6 +17,8 @@ package client import ( + "errors" + "fmt" "io/ioutil" "net/http" "net/url" @@ -26,10 +28,12 @@ import ( ) var ( - ErrTimeout = context.DeadlineExceeded - ErrCanceled = context.Canceled + ErrTimeout = context.DeadlineExceeded + ErrCanceled = context.Canceled + ErrTooManyRedirects = errors.New("too many redirects") DefaultRequestTimeout = 5 * time.Second + DefaultMaxRedirects = 10 ) type SyncableHTTPClient interface { @@ -69,9 +73,12 @@ func newHTTPClusterClient(tr CancelableTransport, eps []string) (*httpClusterCli return nil, err } - c.endpoints[i] = &httpClient{ - transport: tr, - endpoint: *u, + c.endpoints[i] = &redirectFollowingHTTPClient{ + max: DefaultMaxRedirects, + client: &httpClient{ + transport: tr, + endpoint: *u, + }, } } @@ -168,3 +175,45 @@ func (c *httpClient) Do(ctx context.Context, act HTTPAction) (*http.Response, [] body, err := ioutil.ReadAll(resp.Body) return resp, body, err } + +type redirectFollowingHTTPClient struct { + client HTTPClient + max int +} + +func (r *redirectFollowingHTTPClient) Do(ctx context.Context, act HTTPAction) (*http.Response, []byte, error) { + for i := 0; i <= r.max; i++ { + resp, body, err := r.client.Do(ctx, act) + if err != nil { + return nil, nil, err + } + if resp.StatusCode/100 == 3 { + hdr := resp.Header.Get("Location") + if hdr == "" { + return nil, nil, fmt.Errorf("Location header not set") + } + loc, err := url.Parse(hdr) + if err != nil { + return nil, nil, fmt.Errorf("Location header not valid URL: %s", hdr) + } + act = &redirectedHTTPAction{ + action: act, + location: *loc, + } + continue + } + return resp, body, nil + } + return nil, nil, ErrTooManyRedirects +} + +type redirectedHTTPAction struct { + action HTTPAction + location url.URL +} + +func (r *redirectedHTTPAction) HTTPRequest(ep url.URL) *http.Request { + orig := r.action.HTTPRequest(ep) + orig.URL = &r.location + return orig +} diff --git a/client/http_test.go b/client/http_test.go index 99bbae8b2..7a8beedb1 100644 --- a/client/http_test.go +++ b/client/http_test.go @@ -38,6 +38,30 @@ func (s *staticHTTPClient) Do(context.Context, HTTPAction) (*http.Response, []by return &s.resp, nil, s.err } +type staticHTTPAction struct { + request http.Request +} + +type staticHTTPResponse struct { + resp http.Response + err error +} + +func (s *staticHTTPAction) HTTPRequest(url.URL) *http.Request { + return &s.request +} + +type multiStaticHTTPClient struct { + responses []staticHTTPResponse + cur int +} + +func (s *multiStaticHTTPClient) Do(context.Context, HTTPAction) (*http.Response, []byte, error) { + r := s.responses[s.cur] + s.cur++ + return &r.resp, nil, r.err +} + type fakeTransport struct { respchan chan *http.Response errchan chan error @@ -253,3 +277,201 @@ func TestHTTPClusterClientDo(t *testing.T) { } } } + +func TestRedirectedHTTPAction(t *testing.T) { + act := &redirectedHTTPAction{ + action: &staticHTTPAction{ + request: http.Request{ + Method: "DELETE", + URL: &url.URL{ + Scheme: "https", + Host: "foo.example.com", + Path: "/ping", + }, + }, + }, + location: url.URL{ + Scheme: "https", + Host: "bar.example.com", + Path: "/pong", + }, + } + + want := &http.Request{ + Method: "DELETE", + URL: &url.URL{ + Scheme: "https", + Host: "bar.example.com", + Path: "/pong", + }, + } + got := act.HTTPRequest(url.URL{Scheme: "http", Host: "baz.example.com", Path: "/pang"}) + + if !reflect.DeepEqual(want, got) { + t.Fatalf("HTTPRequest is %#v, want %#v", want, got) + } +} + +func TestRedirectFollowingHTTPClient(t *testing.T) { + tests := []struct { + max int + client HTTPClient + wantCode int + wantErr error + }{ + // errors bubbled up + { + max: 2, + client: &multiStaticHTTPClient{ + responses: []staticHTTPResponse{ + staticHTTPResponse{ + err: errors.New("fail!"), + }, + }, + }, + wantErr: errors.New("fail!"), + }, + + // no need to follow redirect if none given + { + max: 2, + client: &multiStaticHTTPClient{ + responses: []staticHTTPResponse{ + staticHTTPResponse{ + resp: http.Response{ + StatusCode: http.StatusTeapot, + }, + }, + }, + }, + wantCode: http.StatusTeapot, + }, + + // redirects if less than max + { + max: 2, + client: &multiStaticHTTPClient{ + responses: []staticHTTPResponse{ + staticHTTPResponse{ + resp: http.Response{ + StatusCode: http.StatusTemporaryRedirect, + Header: http.Header{"Location": []string{"http://example.com"}}, + }, + }, + staticHTTPResponse{ + resp: http.Response{ + StatusCode: http.StatusTeapot, + }, + }, + }, + }, + wantCode: http.StatusTeapot, + }, + + // succeed after reaching max redirects + { + max: 2, + client: &multiStaticHTTPClient{ + responses: []staticHTTPResponse{ + staticHTTPResponse{ + resp: http.Response{ + StatusCode: http.StatusTemporaryRedirect, + Header: http.Header{"Location": []string{"http://example.com"}}, + }, + }, + staticHTTPResponse{ + resp: http.Response{ + StatusCode: http.StatusTemporaryRedirect, + Header: http.Header{"Location": []string{"http://example.com"}}, + }, + }, + staticHTTPResponse{ + resp: http.Response{ + StatusCode: http.StatusTeapot, + }, + }, + }, + }, + wantCode: http.StatusTeapot, + }, + + // fail at max+1 redirects + { + max: 1, + client: &multiStaticHTTPClient{ + responses: []staticHTTPResponse{ + staticHTTPResponse{ + resp: http.Response{ + StatusCode: http.StatusTemporaryRedirect, + Header: http.Header{"Location": []string{"http://example.com"}}, + }, + }, + staticHTTPResponse{ + resp: http.Response{ + StatusCode: http.StatusTemporaryRedirect, + Header: http.Header{"Location": []string{"http://example.com"}}, + }, + }, + staticHTTPResponse{ + resp: http.Response{ + StatusCode: http.StatusTeapot, + }, + }, + }, + }, + wantErr: ErrTooManyRedirects, + }, + + // fail if Location header not set + { + max: 1, + client: &multiStaticHTTPClient{ + responses: []staticHTTPResponse{ + staticHTTPResponse{ + resp: http.Response{ + StatusCode: http.StatusTemporaryRedirect, + }, + }, + }, + }, + wantErr: errors.New("Location header not set"), + }, + + // fail if Location header is invalid + { + max: 1, + client: &multiStaticHTTPClient{ + responses: []staticHTTPResponse{ + staticHTTPResponse{ + resp: http.Response{ + StatusCode: http.StatusTemporaryRedirect, + Header: http.Header{"Location": []string{":"}}, + }, + }, + }, + }, + wantErr: errors.New("Location header not valid URL: :"), + }, + } + + for i, tt := range tests { + client := &redirectFollowingHTTPClient{client: tt.client, max: tt.max} + resp, _, err := client.Do(context.Background(), nil) + if !reflect.DeepEqual(tt.wantErr, err) { + t.Errorf("#%d: got err=%v, want=%v", i, err, tt.wantErr) + continue + } + + if resp == nil { + if tt.wantCode != 0 { + t.Errorf("#%d: resp is nil, want=%d", i, tt.wantCode) + } + continue + } + + if resp.StatusCode != tt.wantCode { + t.Errorf("#%d: resp code=%d, want=%d", i, resp.StatusCode, tt.wantCode) + continue + } + } +}