diff --git a/client/client.go b/client/client.go index a931fd8be..278e02c83 100644 --- a/client/client.go +++ b/client/client.go @@ -87,6 +87,23 @@ type Config struct { // Password is the password for the specified user to add as an authorization header // to the request. Password string + + // HeaderTimeoutPerRequest specifies the time limit to wait for response + // header in a single request made by the Client. The timeout includes + // connection time, any redirects, and header wait time. + // + // For non-watch GET request, server returns the response body immediately. + // For PUT/POST/DELETE request, server will attempt to commit request + // before responding, which is expected to take `100ms + 2 * RTT`. + // For watch request, server returns the header immediately to notify Client + // watch start. But if server is behind some kind of proxy, the response + // header may be cached at proxy, and Client cannot rely on this behavior. + // + // One API call may send multiple requests to different etcd servers until it + // succeeds. Use context of the API to specify the overall timeout. + // + // A HeaderTimeoutPerRequest of zero means no timeout. + HeaderTimeoutPerRequest time.Duration } func (cfg *Config) transport() CancelableTransport { @@ -150,7 +167,7 @@ type Client interface { func New(cfg Config) (Client, error) { c := &httpClusterClient{ - clientFactory: newHTTPClientFactory(cfg.transport(), cfg.checkRedirect()), + clientFactory: newHTTPClientFactory(cfg.transport(), cfg.checkRedirect(), cfg.HeaderTimeoutPerRequest), rand: rand.New(rand.NewSource(int64(time.Now().Nanosecond()))), } if cfg.Username != "" { @@ -169,13 +186,14 @@ type httpClient interface { Do(context.Context, httpAction) (*http.Response, []byte, error) } -func newHTTPClientFactory(tr CancelableTransport, cr CheckRedirectFunc) httpClientFactory { +func newHTTPClientFactory(tr CancelableTransport, cr CheckRedirectFunc, headerTimeout time.Duration) httpClientFactory { return func(ep url.URL) httpClient { return &redirectFollowingHTTPClient{ checkRedirect: cr, client: &simpleHTTPClient{ - transport: tr, - endpoint: ep, + transport: tr, + endpoint: ep, + headerTimeout: headerTimeout, }, } } @@ -353,8 +371,9 @@ type roundTripResponse struct { } type simpleHTTPClient struct { - transport CancelableTransport - endpoint url.URL + transport CancelableTransport + endpoint url.URL + headerTimeout time.Duration } func (c *simpleHTTPClient) Do(ctx context.Context, act httpAction) (*http.Response, []byte, error) { @@ -364,6 +383,12 @@ func (c *simpleHTTPClient) Do(ctx context.Context, act httpAction) (*http.Respon return nil, nil, err } + hctx, hcancel := context.WithCancel(ctx) + if c.headerTimeout > 0 { + hctx, hcancel = context.WithTimeout(ctx, c.headerTimeout) + } + defer hcancel() + rtchan := make(chan roundTripResponse, 1) go func() { resp, err := c.transport.RoundTrip(req) @@ -377,12 +402,19 @@ func (c *simpleHTTPClient) Do(ctx context.Context, act httpAction) (*http.Respon select { case rtresp := <-rtchan: resp, err = rtresp.resp, rtresp.err - case <-ctx.Done(): + case <-hctx.Done(): // cancel and wait for request to actually exit before continuing c.transport.CancelRequest(req) rtresp := <-rtchan resp = rtresp.resp - err = ctx.Err() + switch { + case ctx.Err() != nil: + err = ctx.Err() + case hctx.Err() != nil: + err = fmt.Errorf("client: endpoint %s exceeded header timeout", c.endpoint) + default: + panic("failed to get error from context") + } } // always check for resp nil-ness to deal with possible diff --git a/client/client_test.go b/client/client_test.go index 3a020d46f..25d665929 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -297,6 +297,27 @@ func TestSimpleHTTPClientDoCancelContextWaitForRoundTrip(t *testing.T) { } } +func TestSimpleHTTPClientDoHeaderTimeout(t *testing.T) { + tr := newFakeTransport() + tr.finishCancel <- struct{}{} + c := &simpleHTTPClient{transport: tr, headerTimeout: time.Millisecond} + + errc := make(chan error) + go func() { + _, _, err := c.Do(context.Background(), &fakeAction{}) + errc <- err + }() + + select { + case err := <-errc: + if err == nil { + t.Fatalf("expected non-nil error, got nil") + } + case <-time.After(time.Second): + t.Fatalf("unexpected timeout when waitting for the test to finish") + } +} + func TestHTTPClusterClientDo(t *testing.T) { fakeErr := errors.New("fake!") fakeURL := url.URL{} @@ -337,21 +358,6 @@ func TestHTTPClusterClientDo(t *testing.T) { wantPinned: 1, }, - // context.DeadlineExceeded short-circuits Do - { - client: &httpClusterClient{ - endpoints: []url.URL{fakeURL, fakeURL}, - clientFactory: newStaticHTTPClientFactory( - []staticHTTPResponse{ - staticHTTPResponse{err: context.DeadlineExceeded}, - staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}}, - }, - ), - rand: rand.New(rand.NewSource(0)), - }, - wantErr: &ClusterError{Errors: []error{context.DeadlineExceeded}}, - }, - // context.Canceled short-circuits Do { client: &httpClusterClient{ @@ -371,7 +377,7 @@ func TestHTTPClusterClientDo(t *testing.T) { { client: &httpClusterClient{ endpoints: []url.URL{}, - clientFactory: newHTTPClientFactory(nil, nil), + clientFactory: newHTTPClientFactory(nil, nil, 0), rand: rand.New(rand.NewSource(0)), }, wantErr: ErrNoEndpoints, @@ -434,6 +440,34 @@ func TestHTTPClusterClientDo(t *testing.T) { } } +func TestHTTPClusterClientDoDeadlineExceedContext(t *testing.T) { + fakeURL := url.URL{} + tr := newFakeTransport() + tr.finishCancel <- struct{}{} + c := &httpClusterClient{ + clientFactory: newHTTPClientFactory(tr, DefaultCheckRedirect, 0), + endpoints: []url.URL{fakeURL}, + } + + errc := make(chan error) + go func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() + _, _, err := c.Do(ctx, &fakeAction{}) + errc <- err + }() + + select { + case err := <-errc: + werr := &ClusterError{Errors: []error{context.DeadlineExceeded}} + if !reflect.DeepEqual(err, werr) { + t.Errorf("err = %+v, want %+v", err, werr) + } + case <-time.After(time.Second): + t.Fatalf("unexpected timeout when waitting for request to deadline exceed") + } +} + func TestRedirectedHTTPAction(t *testing.T) { act := &redirectedHTTPAction{ action: &staticHTTPAction{