mirror of
https://github.com/etcd-io/etcd.git
synced 2024-09-27 06:25:44 +00:00
Merge pull request #1571 from bcwaldon/client-redirects
client: follow redirects
This commit is contained in:
commit
d1ec13210f
@ -17,6 +17,8 @@
|
|||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
@ -26,10 +28,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrTimeout = context.DeadlineExceeded
|
ErrTimeout = context.DeadlineExceeded
|
||||||
ErrCanceled = context.Canceled
|
ErrCanceled = context.Canceled
|
||||||
|
ErrTooManyRedirects = errors.New("too many redirects")
|
||||||
|
|
||||||
DefaultRequestTimeout = 5 * time.Second
|
DefaultRequestTimeout = 5 * time.Second
|
||||||
|
DefaultMaxRedirects = 10
|
||||||
)
|
)
|
||||||
|
|
||||||
type SyncableHTTPClient interface {
|
type SyncableHTTPClient interface {
|
||||||
@ -69,9 +73,12 @@ func newHTTPClusterClient(tr CancelableTransport, eps []string) (*httpClusterCli
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.endpoints[i] = &httpClient{
|
c.endpoints[i] = &redirectFollowingHTTPClient{
|
||||||
transport: tr,
|
max: DefaultMaxRedirects,
|
||||||
endpoint: *u,
|
client: &httpClient{
|
||||||
|
transport: tr,
|
||||||
|
endpoint: *u,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -168,3 +175,45 @@ func (c *httpClient) Do(ctx context.Context, act HTTPAction) (*http.Response, []
|
|||||||
body, err := ioutil.ReadAll(resp.Body)
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
return resp, body, err
|
return resp, body, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type redirectFollowingHTTPClient struct {
|
||||||
|
client HTTPClient
|
||||||
|
max int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *redirectFollowingHTTPClient) Do(ctx context.Context, act HTTPAction) (*http.Response, []byte, error) {
|
||||||
|
for i := 0; i <= r.max; i++ {
|
||||||
|
resp, body, err := r.client.Do(ctx, act)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
if resp.StatusCode/100 == 3 {
|
||||||
|
hdr := resp.Header.Get("Location")
|
||||||
|
if hdr == "" {
|
||||||
|
return nil, nil, fmt.Errorf("Location header not set")
|
||||||
|
}
|
||||||
|
loc, err := url.Parse(hdr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("Location header not valid URL: %s", hdr)
|
||||||
|
}
|
||||||
|
act = &redirectedHTTPAction{
|
||||||
|
action: act,
|
||||||
|
location: *loc,
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return resp, body, nil
|
||||||
|
}
|
||||||
|
return nil, nil, ErrTooManyRedirects
|
||||||
|
}
|
||||||
|
|
||||||
|
type redirectedHTTPAction struct {
|
||||||
|
action HTTPAction
|
||||||
|
location url.URL
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *redirectedHTTPAction) HTTPRequest(ep url.URL) *http.Request {
|
||||||
|
orig := r.action.HTTPRequest(ep)
|
||||||
|
orig.URL = &r.location
|
||||||
|
return orig
|
||||||
|
}
|
||||||
|
@ -38,6 +38,30 @@ func (s *staticHTTPClient) Do(context.Context, HTTPAction) (*http.Response, []by
|
|||||||
return &s.resp, nil, s.err
|
return &s.resp, nil, s.err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type staticHTTPAction struct {
|
||||||
|
request http.Request
|
||||||
|
}
|
||||||
|
|
||||||
|
type staticHTTPResponse struct {
|
||||||
|
resp http.Response
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *staticHTTPAction) HTTPRequest(url.URL) *http.Request {
|
||||||
|
return &s.request
|
||||||
|
}
|
||||||
|
|
||||||
|
type multiStaticHTTPClient struct {
|
||||||
|
responses []staticHTTPResponse
|
||||||
|
cur int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *multiStaticHTTPClient) Do(context.Context, HTTPAction) (*http.Response, []byte, error) {
|
||||||
|
r := s.responses[s.cur]
|
||||||
|
s.cur++
|
||||||
|
return &r.resp, nil, r.err
|
||||||
|
}
|
||||||
|
|
||||||
type fakeTransport struct {
|
type fakeTransport struct {
|
||||||
respchan chan *http.Response
|
respchan chan *http.Response
|
||||||
errchan chan error
|
errchan chan error
|
||||||
@ -253,3 +277,201 @@ func TestHTTPClusterClientDo(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRedirectedHTTPAction(t *testing.T) {
|
||||||
|
act := &redirectedHTTPAction{
|
||||||
|
action: &staticHTTPAction{
|
||||||
|
request: http.Request{
|
||||||
|
Method: "DELETE",
|
||||||
|
URL: &url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "foo.example.com",
|
||||||
|
Path: "/ping",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
location: url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "bar.example.com",
|
||||||
|
Path: "/pong",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
want := &http.Request{
|
||||||
|
Method: "DELETE",
|
||||||
|
URL: &url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "bar.example.com",
|
||||||
|
Path: "/pong",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
got := act.HTTPRequest(url.URL{Scheme: "http", Host: "baz.example.com", Path: "/pang"})
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(want, got) {
|
||||||
|
t.Fatalf("HTTPRequest is %#v, want %#v", want, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRedirectFollowingHTTPClient(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
max int
|
||||||
|
client HTTPClient
|
||||||
|
wantCode int
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
// errors bubbled up
|
||||||
|
{
|
||||||
|
max: 2,
|
||||||
|
client: &multiStaticHTTPClient{
|
||||||
|
responses: []staticHTTPResponse{
|
||||||
|
staticHTTPResponse{
|
||||||
|
err: errors.New("fail!"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: errors.New("fail!"),
|
||||||
|
},
|
||||||
|
|
||||||
|
// no need to follow redirect if none given
|
||||||
|
{
|
||||||
|
max: 2,
|
||||||
|
client: &multiStaticHTTPClient{
|
||||||
|
responses: []staticHTTPResponse{
|
||||||
|
staticHTTPResponse{
|
||||||
|
resp: http.Response{
|
||||||
|
StatusCode: http.StatusTeapot,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantCode: http.StatusTeapot,
|
||||||
|
},
|
||||||
|
|
||||||
|
// redirects if less than max
|
||||||
|
{
|
||||||
|
max: 2,
|
||||||
|
client: &multiStaticHTTPClient{
|
||||||
|
responses: []staticHTTPResponse{
|
||||||
|
staticHTTPResponse{
|
||||||
|
resp: http.Response{
|
||||||
|
StatusCode: http.StatusTemporaryRedirect,
|
||||||
|
Header: http.Header{"Location": []string{"http://example.com"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
staticHTTPResponse{
|
||||||
|
resp: http.Response{
|
||||||
|
StatusCode: http.StatusTeapot,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantCode: http.StatusTeapot,
|
||||||
|
},
|
||||||
|
|
||||||
|
// succeed after reaching max redirects
|
||||||
|
{
|
||||||
|
max: 2,
|
||||||
|
client: &multiStaticHTTPClient{
|
||||||
|
responses: []staticHTTPResponse{
|
||||||
|
staticHTTPResponse{
|
||||||
|
resp: http.Response{
|
||||||
|
StatusCode: http.StatusTemporaryRedirect,
|
||||||
|
Header: http.Header{"Location": []string{"http://example.com"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
staticHTTPResponse{
|
||||||
|
resp: http.Response{
|
||||||
|
StatusCode: http.StatusTemporaryRedirect,
|
||||||
|
Header: http.Header{"Location": []string{"http://example.com"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
staticHTTPResponse{
|
||||||
|
resp: http.Response{
|
||||||
|
StatusCode: http.StatusTeapot,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantCode: http.StatusTeapot,
|
||||||
|
},
|
||||||
|
|
||||||
|
// fail at max+1 redirects
|
||||||
|
{
|
||||||
|
max: 1,
|
||||||
|
client: &multiStaticHTTPClient{
|
||||||
|
responses: []staticHTTPResponse{
|
||||||
|
staticHTTPResponse{
|
||||||
|
resp: http.Response{
|
||||||
|
StatusCode: http.StatusTemporaryRedirect,
|
||||||
|
Header: http.Header{"Location": []string{"http://example.com"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
staticHTTPResponse{
|
||||||
|
resp: http.Response{
|
||||||
|
StatusCode: http.StatusTemporaryRedirect,
|
||||||
|
Header: http.Header{"Location": []string{"http://example.com"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
staticHTTPResponse{
|
||||||
|
resp: http.Response{
|
||||||
|
StatusCode: http.StatusTeapot,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: ErrTooManyRedirects,
|
||||||
|
},
|
||||||
|
|
||||||
|
// fail if Location header not set
|
||||||
|
{
|
||||||
|
max: 1,
|
||||||
|
client: &multiStaticHTTPClient{
|
||||||
|
responses: []staticHTTPResponse{
|
||||||
|
staticHTTPResponse{
|
||||||
|
resp: http.Response{
|
||||||
|
StatusCode: http.StatusTemporaryRedirect,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: errors.New("Location header not set"),
|
||||||
|
},
|
||||||
|
|
||||||
|
// fail if Location header is invalid
|
||||||
|
{
|
||||||
|
max: 1,
|
||||||
|
client: &multiStaticHTTPClient{
|
||||||
|
responses: []staticHTTPResponse{
|
||||||
|
staticHTTPResponse{
|
||||||
|
resp: http.Response{
|
||||||
|
StatusCode: http.StatusTemporaryRedirect,
|
||||||
|
Header: http.Header{"Location": []string{":"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: errors.New("Location header not valid URL: :"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
client := &redirectFollowingHTTPClient{client: tt.client, max: tt.max}
|
||||||
|
resp, _, err := client.Do(context.Background(), nil)
|
||||||
|
if !reflect.DeepEqual(tt.wantErr, err) {
|
||||||
|
t.Errorf("#%d: got err=%v, want=%v", i, err, tt.wantErr)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp == nil {
|
||||||
|
if tt.wantCode != 0 {
|
||||||
|
t.Errorf("#%d: resp is nil, want=%d", i, tt.wantCode)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != tt.wantCode {
|
||||||
|
t.Errorf("#%d: resp code=%d, want=%d", i, resp.StatusCode, tt.wantCode)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user