mirror of
https://github.com/etcd-io/etcd.git
synced 2024-09-27 06:25:44 +00:00
Merge pull request #1571 from bcwaldon/client-redirects
client: follow redirects
This commit is contained in:
commit
d1ec13210f
@ -17,6 +17,8 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@ -26,10 +28,12 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrTimeout = context.DeadlineExceeded
|
||||
ErrCanceled = context.Canceled
|
||||
ErrTimeout = context.DeadlineExceeded
|
||||
ErrCanceled = context.Canceled
|
||||
ErrTooManyRedirects = errors.New("too many redirects")
|
||||
|
||||
DefaultRequestTimeout = 5 * time.Second
|
||||
DefaultMaxRedirects = 10
|
||||
)
|
||||
|
||||
type SyncableHTTPClient interface {
|
||||
@ -69,9 +73,12 @@ func newHTTPClusterClient(tr CancelableTransport, eps []string) (*httpClusterCli
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.endpoints[i] = &httpClient{
|
||||
transport: tr,
|
||||
endpoint: *u,
|
||||
c.endpoints[i] = &redirectFollowingHTTPClient{
|
||||
max: DefaultMaxRedirects,
|
||||
client: &httpClient{
|
||||
transport: tr,
|
||||
endpoint: *u,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@ -168,3 +175,45 @@ func (c *httpClient) Do(ctx context.Context, act HTTPAction) (*http.Response, []
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
return resp, body, err
|
||||
}
|
||||
|
||||
type redirectFollowingHTTPClient struct {
|
||||
client HTTPClient
|
||||
max int
|
||||
}
|
||||
|
||||
func (r *redirectFollowingHTTPClient) Do(ctx context.Context, act HTTPAction) (*http.Response, []byte, error) {
|
||||
for i := 0; i <= r.max; i++ {
|
||||
resp, body, err := r.client.Do(ctx, act)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if resp.StatusCode/100 == 3 {
|
||||
hdr := resp.Header.Get("Location")
|
||||
if hdr == "" {
|
||||
return nil, nil, fmt.Errorf("Location header not set")
|
||||
}
|
||||
loc, err := url.Parse(hdr)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("Location header not valid URL: %s", hdr)
|
||||
}
|
||||
act = &redirectedHTTPAction{
|
||||
action: act,
|
||||
location: *loc,
|
||||
}
|
||||
continue
|
||||
}
|
||||
return resp, body, nil
|
||||
}
|
||||
return nil, nil, ErrTooManyRedirects
|
||||
}
|
||||
|
||||
type redirectedHTTPAction struct {
|
||||
action HTTPAction
|
||||
location url.URL
|
||||
}
|
||||
|
||||
func (r *redirectedHTTPAction) HTTPRequest(ep url.URL) *http.Request {
|
||||
orig := r.action.HTTPRequest(ep)
|
||||
orig.URL = &r.location
|
||||
return orig
|
||||
}
|
||||
|
@ -38,6 +38,30 @@ func (s *staticHTTPClient) Do(context.Context, HTTPAction) (*http.Response, []by
|
||||
return &s.resp, nil, s.err
|
||||
}
|
||||
|
||||
type staticHTTPAction struct {
|
||||
request http.Request
|
||||
}
|
||||
|
||||
type staticHTTPResponse struct {
|
||||
resp http.Response
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *staticHTTPAction) HTTPRequest(url.URL) *http.Request {
|
||||
return &s.request
|
||||
}
|
||||
|
||||
type multiStaticHTTPClient struct {
|
||||
responses []staticHTTPResponse
|
||||
cur int
|
||||
}
|
||||
|
||||
func (s *multiStaticHTTPClient) Do(context.Context, HTTPAction) (*http.Response, []byte, error) {
|
||||
r := s.responses[s.cur]
|
||||
s.cur++
|
||||
return &r.resp, nil, r.err
|
||||
}
|
||||
|
||||
type fakeTransport struct {
|
||||
respchan chan *http.Response
|
||||
errchan chan error
|
||||
@ -253,3 +277,201 @@ func TestHTTPClusterClientDo(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedirectedHTTPAction(t *testing.T) {
|
||||
act := &redirectedHTTPAction{
|
||||
action: &staticHTTPAction{
|
||||
request: http.Request{
|
||||
Method: "DELETE",
|
||||
URL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "foo.example.com",
|
||||
Path: "/ping",
|
||||
},
|
||||
},
|
||||
},
|
||||
location: url.URL{
|
||||
Scheme: "https",
|
||||
Host: "bar.example.com",
|
||||
Path: "/pong",
|
||||
},
|
||||
}
|
||||
|
||||
want := &http.Request{
|
||||
Method: "DELETE",
|
||||
URL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "bar.example.com",
|
||||
Path: "/pong",
|
||||
},
|
||||
}
|
||||
got := act.HTTPRequest(url.URL{Scheme: "http", Host: "baz.example.com", Path: "/pang"})
|
||||
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Fatalf("HTTPRequest is %#v, want %#v", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedirectFollowingHTTPClient(t *testing.T) {
|
||||
tests := []struct {
|
||||
max int
|
||||
client HTTPClient
|
||||
wantCode int
|
||||
wantErr error
|
||||
}{
|
||||
// errors bubbled up
|
||||
{
|
||||
max: 2,
|
||||
client: &multiStaticHTTPClient{
|
||||
responses: []staticHTTPResponse{
|
||||
staticHTTPResponse{
|
||||
err: errors.New("fail!"),
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: errors.New("fail!"),
|
||||
},
|
||||
|
||||
// no need to follow redirect if none given
|
||||
{
|
||||
max: 2,
|
||||
client: &multiStaticHTTPClient{
|
||||
responses: []staticHTTPResponse{
|
||||
staticHTTPResponse{
|
||||
resp: http.Response{
|
||||
StatusCode: http.StatusTeapot,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantCode: http.StatusTeapot,
|
||||
},
|
||||
|
||||
// redirects if less than max
|
||||
{
|
||||
max: 2,
|
||||
client: &multiStaticHTTPClient{
|
||||
responses: []staticHTTPResponse{
|
||||
staticHTTPResponse{
|
||||
resp: http.Response{
|
||||
StatusCode: http.StatusTemporaryRedirect,
|
||||
Header: http.Header{"Location": []string{"http://example.com"}},
|
||||
},
|
||||
},
|
||||
staticHTTPResponse{
|
||||
resp: http.Response{
|
||||
StatusCode: http.StatusTeapot,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantCode: http.StatusTeapot,
|
||||
},
|
||||
|
||||
// succeed after reaching max redirects
|
||||
{
|
||||
max: 2,
|
||||
client: &multiStaticHTTPClient{
|
||||
responses: []staticHTTPResponse{
|
||||
staticHTTPResponse{
|
||||
resp: http.Response{
|
||||
StatusCode: http.StatusTemporaryRedirect,
|
||||
Header: http.Header{"Location": []string{"http://example.com"}},
|
||||
},
|
||||
},
|
||||
staticHTTPResponse{
|
||||
resp: http.Response{
|
||||
StatusCode: http.StatusTemporaryRedirect,
|
||||
Header: http.Header{"Location": []string{"http://example.com"}},
|
||||
},
|
||||
},
|
||||
staticHTTPResponse{
|
||||
resp: http.Response{
|
||||
StatusCode: http.StatusTeapot,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantCode: http.StatusTeapot,
|
||||
},
|
||||
|
||||
// fail at max+1 redirects
|
||||
{
|
||||
max: 1,
|
||||
client: &multiStaticHTTPClient{
|
||||
responses: []staticHTTPResponse{
|
||||
staticHTTPResponse{
|
||||
resp: http.Response{
|
||||
StatusCode: http.StatusTemporaryRedirect,
|
||||
Header: http.Header{"Location": []string{"http://example.com"}},
|
||||
},
|
||||
},
|
||||
staticHTTPResponse{
|
||||
resp: http.Response{
|
||||
StatusCode: http.StatusTemporaryRedirect,
|
||||
Header: http.Header{"Location": []string{"http://example.com"}},
|
||||
},
|
||||
},
|
||||
staticHTTPResponse{
|
||||
resp: http.Response{
|
||||
StatusCode: http.StatusTeapot,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: ErrTooManyRedirects,
|
||||
},
|
||||
|
||||
// fail if Location header not set
|
||||
{
|
||||
max: 1,
|
||||
client: &multiStaticHTTPClient{
|
||||
responses: []staticHTTPResponse{
|
||||
staticHTTPResponse{
|
||||
resp: http.Response{
|
||||
StatusCode: http.StatusTemporaryRedirect,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: errors.New("Location header not set"),
|
||||
},
|
||||
|
||||
// fail if Location header is invalid
|
||||
{
|
||||
max: 1,
|
||||
client: &multiStaticHTTPClient{
|
||||
responses: []staticHTTPResponse{
|
||||
staticHTTPResponse{
|
||||
resp: http.Response{
|
||||
StatusCode: http.StatusTemporaryRedirect,
|
||||
Header: http.Header{"Location": []string{":"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: errors.New("Location header not valid URL: :"),
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
client := &redirectFollowingHTTPClient{client: tt.client, max: tt.max}
|
||||
resp, _, err := client.Do(context.Background(), nil)
|
||||
if !reflect.DeepEqual(tt.wantErr, err) {
|
||||
t.Errorf("#%d: got err=%v, want=%v", i, err, tt.wantErr)
|
||||
continue
|
||||
}
|
||||
|
||||
if resp == nil {
|
||||
if tt.wantCode != 0 {
|
||||
t.Errorf("#%d: resp is nil, want=%d", i, tt.wantCode)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode != tt.wantCode {
|
||||
t.Errorf("#%d: resp code=%d, want=%d", i, resp.StatusCode, tt.wantCode)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user