client: don't cache httpClients in httpClusterClient

This commit is contained in:
Brian Waldon 2015-01-26 16:50:24 -08:00 committed by Yicheng Qin
parent 99d63eb62e
commit 62054dfb5e
2 changed files with 103 additions and 56 deletions

View File

@ -36,13 +36,23 @@ var (
DefaultMaxRedirects = 10
)
func defaultHTTPClientFactory(tr CancelableTransport, ep url.URL) HTTPClient {
return &redirectFollowingHTTPClient{
max: DefaultMaxRedirects,
client: &httpClient{
transport: tr,
endpoint: ep,
},
}
}
type ClientConfig struct {
Endpoints []string
Transport CancelableTransport
}
func New(cfg ClientConfig) (SyncableHTTPClient, error) {
return newHTTPClusterClient(cfg.Transport, cfg.Endpoints)
return newHTTPClusterClient(cfg.Transport, cfg.Endpoints, defaultHTTPClientFactory)
}
type SyncableHTTPClient interface {
@ -55,6 +65,8 @@ type HTTPClient interface {
Do(context.Context, HTTPAction) (*http.Response, []byte, error)
}
type httpClientFactory func(CancelableTransport, url.URL) HTTPClient
type HTTPAction interface {
HTTPRequest(url.URL) *http.Request
}
@ -67,8 +79,8 @@ type CancelableTransport interface {
CancelRequest(req *http.Request)
}
func newHTTPClusterClient(tr CancelableTransport, eps []string) (*httpClusterClient, error) {
c := &httpClusterClient{}
func newHTTPClusterClient(tr CancelableTransport, eps []string, cf httpClientFactory) (*httpClusterClient, error) {
c := &httpClusterClient{clientFactory: cf}
if err := c.reset(tr, eps); err != nil {
return nil, err
}
@ -76,37 +88,27 @@ func newHTTPClusterClient(tr CancelableTransport, eps []string) (*httpClusterCli
}
type httpClusterClient struct {
transport CancelableTransport
endpoints []string
clients []HTTPClient
clientFactory httpClientFactory
transport CancelableTransport
endpoints []url.URL
sync.RWMutex
}
func (c *httpClusterClient) reset(tr CancelableTransport, eps []string) error {
le := len(eps)
ne := make([]string, le)
if copy(ne, eps) != le {
return errors.New("copy call failed")
if len(eps) == 0 {
return ErrNoEndpoints
}
nc := make([]HTTPClient, len(ne))
for i, e := range ne {
u, err := url.Parse(e)
neps := make([]url.URL, len(eps))
for i, ep := range eps {
u, err := url.Parse(ep)
if err != nil {
return err
}
nc[i] = &redirectFollowingHTTPClient{
max: DefaultMaxRedirects,
client: &httpClient{
transport: tr,
endpoint: *u,
},
}
neps[i] = *u
}
c.endpoints = ne
c.clients = nc
c.endpoints = neps
c.transport = tr
return nil
@ -114,12 +116,24 @@ func (c *httpClusterClient) reset(tr CancelableTransport, eps []string) error {
func (c *httpClusterClient) Do(ctx context.Context, act HTTPAction) (resp *http.Response, body []byte, err error) {
c.RLock()
defer c.RUnlock()
leps := len(c.endpoints)
eps := make([]url.URL, leps)
n := copy(eps, c.endpoints)
tr := c.transport
c.RUnlock()
if len(c.clients) == 0 {
return nil, nil, ErrNoEndpoints
if leps == 0 {
err = ErrNoEndpoints
return
}
for _, hc := range c.clients {
if leps != n {
err = errors.New("unable to pick endpoint: copy failed")
return
}
for _, ep := range eps {
hc := c.clientFactory(tr, ep)
resp, body, err = hc.Do(ctx, act)
if err != nil {
if err == ErrTimeout || err == ErrCanceled {
@ -132,13 +146,20 @@ func (c *httpClusterClient) Do(ctx context.Context, act HTTPAction) (resp *http.
}
break
}
return
}
func (c *httpClusterClient) Endpoints() []string {
c.RLock()
defer c.RUnlock()
return c.endpoints
eps := make([]string, len(c.endpoints))
for i, ep := range c.endpoints {
eps[i] = ep.String()
}
return eps
}
func (c *httpClusterClient) Sync(ctx context.Context) error {
@ -155,9 +176,6 @@ func (c *httpClusterClient) Sync(ctx context.Context) error {
for _, m := range ms {
eps = append(eps, m.ClientURLs...)
}
if len(eps) == 0 {
return ErrNoEndpoints
}
return c.reset(c.transport, eps)
}

View File

@ -60,6 +60,15 @@ func (s *multiStaticHTTPClient) Do(context.Context, HTTPAction) (*http.Response,
return &r.resp, nil, r.err
}
func newStaticHTTPClientFactory(responses []staticHTTPResponse) httpClientFactory {
var cur int
return func(CancelableTransport, url.URL) HTTPClient {
r := responses[cur]
cur++
return &staticHTTPClient{resp: r.resp, err: r.err}
}
}
type fakeTransport struct {
respchan chan *http.Response
errchan chan error
@ -183,6 +192,7 @@ func TestHTTPClientDoCancelContextWaitForRoundTrip(t *testing.T) {
func TestHTTPClusterClientDo(t *testing.T) {
fakeErr := errors.New("fake!")
fakeURL := url.URL{}
tests := []struct {
client *httpClusterClient
wantCode int
@ -191,10 +201,13 @@ func TestHTTPClusterClientDo(t *testing.T) {
// first good response short-circuits Do
{
client: &httpClusterClient{
clients: []HTTPClient{
&staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}},
&staticHTTPClient{err: fakeErr},
},
endpoints: []url.URL{fakeURL, fakeURL},
clientFactory: newStaticHTTPClientFactory(
[]staticHTTPResponse{
staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}},
staticHTTPResponse{err: fakeErr},
},
),
},
wantCode: http.StatusTeapot,
},
@ -202,10 +215,13 @@ func TestHTTPClusterClientDo(t *testing.T) {
// fall through to good endpoint if err is arbitrary
{
client: &httpClusterClient{
clients: []HTTPClient{
&staticHTTPClient{err: fakeErr},
&staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}},
},
endpoints: []url.URL{fakeURL, fakeURL},
clientFactory: newStaticHTTPClientFactory(
[]staticHTTPResponse{
staticHTTPResponse{err: fakeErr},
staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}},
},
),
},
wantCode: http.StatusTeapot,
},
@ -213,10 +229,13 @@ func TestHTTPClusterClientDo(t *testing.T) {
// ErrTimeout short-circuits Do
{
client: &httpClusterClient{
clients: []HTTPClient{
&staticHTTPClient{err: ErrTimeout},
&staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}},
},
endpoints: []url.URL{fakeURL, fakeURL},
clientFactory: newStaticHTTPClientFactory(
[]staticHTTPResponse{
staticHTTPResponse{err: ErrTimeout},
staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}},
},
),
},
wantErr: ErrTimeout,
},
@ -224,10 +243,13 @@ func TestHTTPClusterClientDo(t *testing.T) {
// ErrCanceled short-circuits Do
{
client: &httpClusterClient{
clients: []HTTPClient{
&staticHTTPClient{err: ErrCanceled},
&staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}},
},
endpoints: []url.URL{fakeURL, fakeURL},
clientFactory: newStaticHTTPClientFactory(
[]staticHTTPResponse{
staticHTTPResponse{err: ErrCanceled},
staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}},
},
),
},
wantErr: ErrCanceled,
},
@ -235,7 +257,8 @@ func TestHTTPClusterClientDo(t *testing.T) {
// return err if there are no endpoints
{
client: &httpClusterClient{
clients: []HTTPClient{},
endpoints: []url.URL{},
clientFactory: defaultHTTPClientFactory,
},
wantErr: ErrNoEndpoints,
},
@ -243,10 +266,13 @@ func TestHTTPClusterClientDo(t *testing.T) {
// return err if all endpoints return arbitrary errors
{
client: &httpClusterClient{
clients: []HTTPClient{
&staticHTTPClient{err: fakeErr},
&staticHTTPClient{err: fakeErr},
},
endpoints: []url.URL{fakeURL, fakeURL},
clientFactory: newStaticHTTPClientFactory(
[]staticHTTPResponse{
staticHTTPResponse{err: fakeErr},
staticHTTPResponse{err: fakeErr},
},
),
},
wantErr: fakeErr,
},
@ -254,10 +280,13 @@ func TestHTTPClusterClientDo(t *testing.T) {
// 500-level errors cause Do to fallthrough to next endpoint
{
client: &httpClusterClient{
clients: []HTTPClient{
&staticHTTPClient{resp: http.Response{StatusCode: http.StatusBadGateway}},
&staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}},
},
endpoints: []url.URL{fakeURL, fakeURL},
clientFactory: newStaticHTTPClientFactory(
[]staticHTTPResponse{
staticHTTPResponse{resp: http.Response{StatusCode: http.StatusBadGateway}},
staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}},
},
),
},
wantCode: http.StatusTeapot,
},