diff --git a/client/keys.go b/client/keys.go index ffa6bf088..38d5813a2 100644 --- a/client/keys.go +++ b/client/keys.go @@ -71,6 +71,7 @@ type Response struct { Action string `json:"action"` Node *Node `json:"node"` PrevNode *Node `json:"prevNode"` + Index uint64 } type Nodes []*Node @@ -107,7 +108,7 @@ func (k *httpKeysAPI) Create(ctx context.Context, key, val string, ttl time.Dura return nil, err } - return unmarshalHTTPResponse(resp.StatusCode, body) + return unmarshalHTTPResponse(resp.StatusCode, resp.Header, body) } func (k *httpKeysAPI) Get(ctx context.Context, key string) (*Response, error) { @@ -122,7 +123,7 @@ func (k *httpKeysAPI) Get(ctx context.Context, key string) (*Response, error) { return nil, err } - return unmarshalHTTPResponse(resp.StatusCode, body) + return unmarshalHTTPResponse(resp.StatusCode, resp.Header, body) } func (k *httpKeysAPI) Watch(key string, idx uint64) Watcher { @@ -160,7 +161,7 @@ func (hw *httpWatcher) Next(ctx context.Context) (*Response, error) { return nil, err } - resp, err := unmarshalHTTPResponse(httpresp.StatusCode, body) + resp, err := unmarshalHTTPResponse(httpresp.StatusCode, httpresp.Header, body) if err != nil { return nil, err } @@ -243,10 +244,10 @@ func (c *createAction) HTTPRequest(ep url.URL) *http.Request { return req } -func unmarshalHTTPResponse(code int, body []byte) (res *Response, err error) { +func unmarshalHTTPResponse(code int, header http.Header, body []byte) (res *Response, err error) { switch code { case http.StatusOK, http.StatusCreated: - res, err = unmarshalSuccessfulResponse(body) + res, err = unmarshalSuccessfulResponse(header, body) default: err = unmarshalErrorResponse(code) } @@ -254,13 +255,18 @@ func unmarshalHTTPResponse(code int, body []byte) (res *Response, err error) { return } -func unmarshalSuccessfulResponse(body []byte) (*Response, error) { +func unmarshalSuccessfulResponse(header http.Header, body []byte) (*Response, error) { var res Response err := json.Unmarshal(body, &res) if err != nil { return nil, err } - + if header.Get("X-Etcd-Index") != "" { + res.Index, err = strconv.ParseUint(header.Get("X-Etcd-Index"), 10, 64) + } + if err != nil { + return nil, err + } return &res, nil } @@ -273,6 +279,8 @@ func unmarshalErrorResponse(code int) error { case http.StatusInternalServerError: // this isn't necessarily true return ErrNoLeader + case http.StatusGatewayTimeout: + return ErrTimeout default: } diff --git a/client/keys_test.go b/client/keys_test.go index 85bca1afa..72a4ceabe 100644 --- a/client/keys_test.go +++ b/client/keys_test.go @@ -255,40 +255,46 @@ func assertResponse(got http.Request, wantURL *url.URL, wantHeader http.Header, func TestUnmarshalSuccessfulResponse(t *testing.T) { tests := []struct { + indexHeader string body string res *Response expectError bool }{ // Neither PrevNode or Node { + "1", `{"action":"delete"}`, - &Response{Action: "delete"}, + &Response{Action: "delete", Index: 1}, false, }, // PrevNode { + "15", `{"action":"delete", "prevNode": {"key": "/foo", "value": "bar", "modifiedIndex": 12, "createdIndex": 10}}`, - &Response{Action: "delete", PrevNode: &Node{Key: "/foo", Value: "bar", ModifiedIndex: 12, CreatedIndex: 10}}, + &Response{Action: "delete", Index: 15, PrevNode: &Node{Key: "/foo", Value: "bar", ModifiedIndex: 12, CreatedIndex: 10}}, false, }, // Node { + "15", `{"action":"get", "node": {"key": "/foo", "value": "bar", "modifiedIndex": 12, "createdIndex": 10}}`, - &Response{Action: "get", Node: &Node{Key: "/foo", Value: "bar", ModifiedIndex: 12, CreatedIndex: 10}}, + &Response{Action: "get", Index: 15, Node: &Node{Key: "/foo", Value: "bar", ModifiedIndex: 12, CreatedIndex: 10}}, false, }, // PrevNode and Node { + "15", `{"action":"update", "prevNode": {"key": "/foo", "value": "baz", "modifiedIndex": 10, "createdIndex": 10}, "node": {"key": "/foo", "value": "bar", "modifiedIndex": 12, "createdIndex": 10}}`, - &Response{Action: "update", PrevNode: &Node{Key: "/foo", Value: "baz", ModifiedIndex: 10, CreatedIndex: 10}, Node: &Node{Key: "/foo", Value: "bar", ModifiedIndex: 12, CreatedIndex: 10}}, + &Response{Action: "update", Index: 15, PrevNode: &Node{Key: "/foo", Value: "baz", ModifiedIndex: 10, CreatedIndex: 10}, Node: &Node{Key: "/foo", Value: "bar", ModifiedIndex: 12, CreatedIndex: 10}}, false, }, // Garbage in body { + "", `garbage`, nil, true, @@ -296,7 +302,9 @@ func TestUnmarshalSuccessfulResponse(t *testing.T) { } for i, tt := range tests { - res, err := unmarshalSuccessfulResponse([]byte(tt.body)) + h := make(http.Header) + h.Add("X-Etcd-Index", tt.indexHeader) + res, err := unmarshalSuccessfulResponse(h, []byte(tt.body)) if tt.expectError != (err != nil) { t.Errorf("#%d: expectError=%t, err=%v", i, tt.expectError, err) } @@ -312,7 +320,9 @@ func TestUnmarshalSuccessfulResponse(t *testing.T) { if res.Action != tt.res.Action { t.Errorf("#%d: Action=%s, expected %s", i, res.Action, tt.res.Action) } - + if res.Index != tt.res.Index { + t.Errorf("#%d: Index=%d, expected %d", i, res.Index, tt.res.Index) + } if !reflect.DeepEqual(res.Node, tt.res.Node) { t.Errorf("#%d: Node=%v, expected %v", i, res.Node, tt.res.Node) } @@ -350,7 +360,7 @@ func TestUnmarshalErrorResponse(t *testing.T) { {http.StatusNotImplemented, unrecognized}, {http.StatusBadGateway, unrecognized}, {http.StatusServiceUnavailable, unrecognized}, - {http.StatusGatewayTimeout, unrecognized}, + {http.StatusGatewayTimeout, ErrTimeout}, {http.StatusHTTPVersionNotSupported, unrecognized}, } diff --git a/discovery/discovery.go b/discovery/discovery.go index 1b50a291c..9ea7ea65f 100644 --- a/discovery/discovery.go +++ b/discovery/discovery.go @@ -20,6 +20,7 @@ import ( "errors" "fmt" "log" + "math" "net/http" "net/url" "path" @@ -44,9 +45,9 @@ var ( ErrTooManyRetries = errors.New("discovery: too many retries") ) -const ( +var ( // Number of retries discovery will attempt before giving up and erroring out. - nRetries = uint(3) + nRetries = uint(math.MaxUint32) ) // JoinCluster will connect to the discovery service at the given url, and @@ -135,7 +136,7 @@ func newDiscovery(durl, dproxyurl string, id types.ID) (*discovery, error) { func (d *discovery) joinCluster(config string) (string, error) { // fast path: if the cluster is full, return the error // do not need to register to the cluster in this case. - if _, _, err := d.checkCluster(); err != nil { + if _, _, _, err := d.checkCluster(); err != nil { return "", err } @@ -146,12 +147,12 @@ func (d *discovery) joinCluster(config string) (string, error) { return "", err } - nodes, size, err := d.checkCluster() + nodes, size, index, err := d.checkCluster() if err != nil { return "", err } - all, err := d.waitNodes(nodes, size) + all, err := d.waitNodes(nodes, size, index) if err != nil { return "", err } @@ -160,7 +161,7 @@ func (d *discovery) joinCluster(config string) (string, error) { } func (d *discovery) getCluster() (string, error) { - nodes, size, err := d.checkCluster() + nodes, size, index, err := d.checkCluster() if err != nil { if err == ErrFullCluster { return nodesToCluster(nodes), nil @@ -168,7 +169,7 @@ func (d *discovery) getCluster() (string, error) { return "", err } - all, err := d.waitNodes(nodes, size) + all, err := d.waitNodes(nodes, size, index) if err != nil { return "", err } @@ -189,7 +190,7 @@ func (d *discovery) createSelf(contents string) error { return err } -func (d *discovery) checkCluster() (client.Nodes, int, error) { +func (d *discovery) checkCluster() (client.Nodes, int, uint64, error) { configKey := path.Join("/", d.cluster, "_config") ctx, cancel := context.WithTimeout(context.Background(), client.DefaultRequestTimeout) // find cluster size @@ -197,16 +198,16 @@ func (d *discovery) checkCluster() (client.Nodes, int, error) { cancel() if err != nil { if err == client.ErrKeyNoExist { - return nil, 0, ErrSizeNotFound + return nil, 0, 0, ErrSizeNotFound } if err == client.ErrTimeout { return d.checkClusterRetry() } - return nil, 0, err + return nil, 0, 0, err } size, err := strconv.Atoi(resp.Node.Value) if err != nil { - return nil, 0, ErrBadSizeKey + return nil, 0, 0, ErrBadSizeKey } ctx, cancel = context.WithTimeout(context.Background(), client.DefaultRequestTimeout) @@ -216,7 +217,7 @@ func (d *discovery) checkCluster() (client.Nodes, int, error) { if err == client.ErrTimeout { return d.checkClusterRetry() } - return nil, 0, err + return nil, 0, 0, err } nodes := make(client.Nodes, 0) // append non-config keys to nodes @@ -235,10 +236,10 @@ func (d *discovery) checkCluster() (client.Nodes, int, error) { break } if i >= size-1 { - return nodes[:size], size, ErrFullCluster + return nodes[:size], size, resp.Index, ErrFullCluster } } - return nodes, size, nil + return nodes, size, resp.Index, nil } func (d *discovery) logAndBackoffForRetry(step string) { @@ -248,31 +249,31 @@ func (d *discovery) logAndBackoffForRetry(step string) { d.clock.Sleep(retryTime) } -func (d *discovery) checkClusterRetry() (client.Nodes, int, error) { +func (d *discovery) checkClusterRetry() (client.Nodes, int, uint64, error) { if d.retries < nRetries { d.logAndBackoffForRetry("cluster status check") return d.checkCluster() } - return nil, 0, ErrTooManyRetries + return nil, 0, 0, ErrTooManyRetries } func (d *discovery) waitNodesRetry() (client.Nodes, error) { if d.retries < nRetries { d.logAndBackoffForRetry("waiting for other nodes") - nodes, n, err := d.checkCluster() + nodes, n, index, err := d.checkCluster() if err != nil { return nil, err } - return d.waitNodes(nodes, n) + return d.waitNodes(nodes, n, index) } return nil, ErrTooManyRetries } -func (d *discovery) waitNodes(nodes client.Nodes, size int) (client.Nodes, error) { +func (d *discovery) waitNodes(nodes client.Nodes, size int, index uint64) (client.Nodes, error) { if len(nodes) > size { nodes = nodes[:size] } - w := d.c.RecursiveWatch(d.cluster, nodes[len(nodes)-1].ModifiedIndex+1) + w := d.c.RecursiveWatch(d.cluster, index) all := make(client.Nodes, len(nodes)) copy(all, nodes) for _, n := range all { diff --git a/discovery/discovery_test.go b/discovery/discovery_test.go index c32baca4f..faea82a11 100644 --- a/discovery/discovery_test.go +++ b/discovery/discovery_test.go @@ -18,6 +18,7 @@ package discovery import ( "errors" + "math" "math/rand" "net/http" "reflect" @@ -31,6 +32,10 @@ import ( "github.com/coreos/etcd/client" ) +const ( + maxRetryInTest = 3 +) + func TestNewProxyFuncUnset(t *testing.T) { pf, err := newProxyFunc("") if pf != nil { @@ -89,6 +94,7 @@ func TestCheckCluster(t *testing.T) { tests := []struct { nodes []*client.Node + index uint64 werr error wsize int }{ @@ -102,6 +108,7 @@ func TestCheckCluster(t *testing.T) { {Key: "/1000/3", CreatedIndex: 4}, {Key: "/1000/4", CreatedIndex: 5}, }, + 5, nil, 3, }, @@ -115,6 +122,7 @@ func TestCheckCluster(t *testing.T) { {Key: self, CreatedIndex: 4}, {Key: "/1000/4", CreatedIndex: 5}, }, + 5, nil, 3, }, @@ -128,6 +136,7 @@ func TestCheckCluster(t *testing.T) { {Key: "/1000/4", CreatedIndex: 4}, {Key: self, CreatedIndex: 5}, }, + 5, ErrFullCluster, 3, }, @@ -139,6 +148,7 @@ func TestCheckCluster(t *testing.T) { {Key: "/1000/2", CreatedIndex: 2}, {Key: "/1000/3", CreatedIndex: 3}, }, + 3, nil, 3, }, @@ -150,6 +160,7 @@ func TestCheckCluster(t *testing.T) { {Key: "/1000/3", CreatedIndex: 3}, {Key: "/1000/4", CreatedIndex: 4}, }, + 3, ErrFullCluster, 3, }, @@ -158,12 +169,14 @@ func TestCheckCluster(t *testing.T) { []*client.Node{ {Key: "/1000/_config/size", Value: "bad", CreatedIndex: 1}, }, + 0, ErrBadSizeKey, 0, }, { // no size key []*client.Node{}, + 0, ErrSizeNotFound, 0, }, @@ -172,12 +185,13 @@ func TestCheckCluster(t *testing.T) { for i, tt := range tests { rs := make([]*client.Response, 0) if len(tt.nodes) > 0 { - rs = append(rs, &client.Response{Node: tt.nodes[0]}) + rs = append(rs, &client.Response{Node: tt.nodes[0], Index: tt.index}) rs = append(rs, &client.Response{ Node: &client.Node{ Key: cluster, Nodes: tt.nodes[1:], }, + Index: tt.index, }) } c := &clientWithResp{rs: rs} @@ -190,12 +204,12 @@ func TestCheckCluster(t *testing.T) { for _, d := range []discovery{d, dRetry} { go func() { - for i := uint(1); i <= nRetries; i++ { + for i := uint(1); i <= maxRetryInTest; i++ { fc.BlockUntil(1) fc.Advance(time.Second * (0x1 << i)) } }() - ns, size, err := d.checkCluster() + ns, size, index, err := d.checkCluster() if err != tt.werr { t.Errorf("#%d: err = %v, want %v", i, err, tt.werr) } @@ -205,6 +219,9 @@ func TestCheckCluster(t *testing.T) { if size != tt.wsize { t.Errorf("#%d: size = %v, want %d", i, size, tt.wsize) } + if index != tt.index { + t.Errorf("#%d: index = %v, want %d", i, index, tt.index) + } } } } @@ -278,12 +295,12 @@ func TestWaitNodes(t *testing.T) { for _, d := range []*discovery{d, dRetry} { go func() { - for i := uint(1); i <= nRetries; i++ { + for i := uint(1); i <= maxRetryInTest; i++ { fc.BlockUntil(1) fc.Advance(time.Second * (0x1 << i)) } }() - g, err := d.waitNodes(tt.nodes, 3) + g, err := d.waitNodes(tt.nodes, 3, 0) // we do not care about index in this test if err != nil { t.Errorf("#%d: err = %v, want %v", i, err, nil) } @@ -368,6 +385,9 @@ func TestSortableNodes(t *testing.T) { } func TestRetryFailure(t *testing.T) { + nRetries = maxRetryInTest + defer func() { nRetries = math.MaxUint32 }() + cluster := "1000" c := &clientWithRetry{failTimes: 4} fc := clockwork.NewFakeClock() @@ -378,12 +398,12 @@ func TestRetryFailure(t *testing.T) { clock: fc, } go func() { - for i := uint(1); i <= nRetries; i++ { + for i := uint(1); i <= maxRetryInTest; i++ { fc.BlockUntil(1) fc.Advance(time.Second * (0x1 << i)) } }() - if _, _, err := d.checkCluster(); err != ErrTooManyRetries { + if _, _, _, err := d.checkCluster(); err != ErrTooManyRetries { t.Errorf("err = %v, want %v", err, ErrTooManyRetries) } }