From 172a32e5e33b9415f695500ffc202b1680c1acfc Mon Sep 17 00:00:00 2001 From: Jonathan Boulle Date: Tue, 23 Sep 2014 14:08:18 -0700 Subject: [PATCH] etcdserver: correct timeout and streaming handling This reintroduces the 'stream' parameter to support long-lived watch sessions. These sessions respect a server timeout (set to 5 minutes by default). --- etcdserver/etcdhttp/http.go | 100 ++++++++++++----- etcdserver/etcdhttp/http_test.go | 137 ++++------------------- etcdserver/etcdserverpb/etcdserver.pb.go | 29 +++++ etcdserver/etcdserverpb/etcdserver.proto | 1 + etcdserver/server.go | 2 +- 5 files changed, 123 insertions(+), 146 deletions(-) diff --git a/etcdserver/etcdhttp/http.go b/etcdserver/etcdhttp/http.go index b6bc18552..6b4361e42 100644 --- a/etcdserver/etcdhttp/http.go +++ b/etcdserver/etcdhttp/http.go @@ -7,12 +7,12 @@ import ( "io/ioutil" "log" "net/http" + "net/http/httputil" "net/url" "strconv" "strings" "time" - "github.com/coreos/etcd/elog" etcdErr "github.com/coreos/etcd/error" "github.com/coreos/etcd/etcdserver" "github.com/coreos/etcd/etcdserver/etcdserverpb" @@ -26,7 +26,11 @@ const ( machinesPrefix = "/v2/machines" raftPrefix = "/raft" - DefaultTimeout = 500 * time.Millisecond + // time to wait for response from EtcdServer requests + defaultServerTimeout = 500 * time.Millisecond + + // time to wait for a Watch request + defaultWatchTimeout = 5 * time.Minute ) var errClosed = errors.New("etcdhttp: client closed connection") @@ -39,7 +43,7 @@ func NewClientHandler(server etcdserver.Server, peers Peers, timeout time.Durati timeout: timeout, } if sh.timeout == 0 { - sh.timeout = DefaultTimeout + sh.timeout = defaultServerTimeout } mux := http.NewServeMux() mux.HandleFunc(keysPrefix, sh.serveKeys) @@ -89,23 +93,16 @@ func (h serverHandler) serveKeys(w http.ResponseWriter, r *http.Request) { return } - var ev *store.Event switch { case resp.Event != nil: - ev = resp.Event - case resp.Watcher != nil: - if ev, err = waitForEvent(ctx, w, resp.Watcher); err != nil { - http.Error(w, err.Error(), http.StatusGatewayTimeout) - return + if err := writeEvent(w, resp.Event); err != nil { + // Should never be reached + log.Println("error writing event: %v", err) } + case resp.Watcher != nil: + handleWatch(w, resp.Watcher, rr.Stream) default: writeError(w, errors.New("received response with no Event/Watcher!")) - return - } - - if err = writeEvent(w, ev); err != nil { - // Should never be reached - log.Println("error writing event: %v", err) } } @@ -187,7 +184,7 @@ func parseRequest(r *http.Request, id int64) (etcdserverpb.Request, error) { ) } - var rec, sort, wait bool + var rec, sort, wait, stream bool if rec, err = getBool(r.Form, "recursive"); err != nil { return emptyReq, etcdErr.NewRequestError( etcdErr.EcodeInvalidField, @@ -206,6 +203,19 @@ func parseRequest(r *http.Request, id int64) (etcdserverpb.Request, error) { `invalid value for "wait"`, ) } + if stream, err = getBool(r.Form, "stream"); err != nil { + return emptyReq, etcdErr.NewRequestError( + etcdErr.EcodeInvalidField, + `invalid value for "stream"`, + ) + } + + if wait && r.Method != "GET" { + return emptyReq, etcdErr.NewRequestError( + etcdErr.EcodeInvalidField, + `"wait" can only be used with GET requests`, + ) + } // prevExist is nullable, so leave it null if not specified var pe *bool @@ -231,6 +241,7 @@ func parseRequest(r *http.Request, id int64) (etcdserverpb.Request, error) { Recursive: rec, Since: wIdx, Sorted: sort, + Stream: stream, Wait: wait, } @@ -285,8 +296,9 @@ func writeError(w http.ResponseWriter, err error) { } } -// writeEvent serializes the given Event and writes the resulting JSON to the -// given ResponseWriter +// writeEvent serializes a single Event and writes the resulting +// JSON to the given ResponseWriter, along with the appropriate +// headers func writeEvent(w http.ResponseWriter, ev *store.Event) error { if ev == nil { return errors.New("cannot write empty Event!") @@ -301,25 +313,51 @@ func writeEvent(w http.ResponseWriter, ev *store.Event) error { return json.NewEncoder(w).Encode(ev) } -// waitForEvent waits for a given Watcher to return its associated -// event. It returns a non-nil error if the given Context times out -// or the given ResponseWriter triggers a CloseNotify. -func waitForEvent(ctx context.Context, w http.ResponseWriter, wa store.Watcher) (*store.Event, error) { - // TODO(bmizerany): support streaming? +func handleWatch(w http.ResponseWriter, wa store.Watcher, stream bool) { defer wa.Remove() + ech := wa.EventChan() + tch := time.After(defaultWatchTimeout) var nch <-chan bool if x, ok := w.(http.CloseNotifier); ok { nch = x.CloseNotify() } - select { - case ev := <-wa.EventChan(): - return ev, nil - case <-nch: - elog.TODO() - return nil, errClosed - case <-ctx.Done(): - return nil, ctx.Err() + + w.Header().Set("Content-Type", "application/json") + // WriteHeader will implicitly write a Transfer-Encoding: chunked header, so no need to do it explicitly + w.WriteHeader(http.StatusOK) + + // Ensure headers are flushed early, in case of long polling + w.(http.Flusher).Flush() + + cw := httputil.NewChunkedWriter(w) + + for { + select { + case <-nch: + // Client closed connection. Nothing to do. + return + case <-tch: + cw.Close() + return + case ev, ok := <-ech: + if !ok { + // If the channel is closed this may be an indication of + // that notifications are much more than we are able to + // send to the client in time. Then we simply end streaming. + return + } + if err := json.NewEncoder(cw).Encode(ev); err != nil { + // Should never be reached + log.Println("error writing event: %v", err) + return + } + if !stream { + return + } + w.(http.Flusher).Flush() + } } + } // allowMethod verifies that the given method is one of the allowed methods, diff --git a/etcdserver/etcdhttp/http_test.go b/etcdserver/etcdhttp/http_test.go index f0f75f2a7..77114d5ea 100644 --- a/etcdserver/etcdhttp/http_test.go +++ b/etcdserver/etcdhttp/http_test.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "errors" + "fmt" "io" "net/http" "net/http/httptest" @@ -11,7 +12,6 @@ import ( "path" "reflect" "strings" - "sync" "testing" "time" @@ -36,8 +36,12 @@ func mustNewURL(t *testing.T, s string) *url.URL { // mustNewRequest takes a path, appends it to the standard keysPrefix, and constructs // a GET *http.Request referencing the resulting URL func mustNewRequest(t *testing.T, p string) *http.Request { + return mustNewMethodRequest(t, "GET", p) +} + +func mustNewMethodRequest(t *testing.T, m, p string) *http.Request { return &http.Request{ - Method: "GET", + Method: m, URL: mustNewURL(t, path.Join(keysPrefix, p)), } } @@ -99,7 +103,7 @@ func TestBadParseRequest(t *testing.T) { mustNewForm(t, "foo", url.Values{"ttl": []string{"-1"}}), etcdErr.EcodeTTLNaN, }, - // bad values for recursive, sorted, wait, prevExist + // bad values for recursive, sorted, wait, prevExist, stream { mustNewForm(t, "foo", url.Values{"recursive": []string{"hahaha"}}), etcdErr.EcodeInvalidField, @@ -136,6 +140,19 @@ func TestBadParseRequest(t *testing.T) { mustNewForm(t, "foo", url.Values{"prevExist": []string{"#2"}}), etcdErr.EcodeInvalidField, }, + { + mustNewForm(t, "foo", url.Values{"stream": []string{"zzz"}}), + etcdErr.EcodeInvalidField, + }, + { + mustNewForm(t, "foo", url.Values{"stream": []string{"something"}}), + etcdErr.EcodeInvalidField, + }, + // wait is only valid with GET requests + { + mustNewMethodRequest(t, "HEAD", "foo?wait=true"), + etcdErr.EcodeInvalidField, + }, // query values are considered { mustNewRequest(t, "foo?prevExist=wrong"), @@ -256,14 +273,10 @@ func TestGoodParseRequest(t *testing.T) { }, { // wait specified - mustNewForm( - t, - "foo", - url.Values{"wait": []string{"true"}}, - ), + mustNewRequest(t, "foo?wait=true"), etcdserverpb.Request{ Id: 1234, - Method: "PUT", + Method: "GET", Wait: true, Path: "/foo", }, @@ -492,100 +505,6 @@ func (w *dummyWatcher) EventChan() chan *store.Event { } func (w *dummyWatcher) Remove() {} -type dummyResponseWriter struct { - cnchan chan bool - http.ResponseWriter -} - -func (rw *dummyResponseWriter) CloseNotify() <-chan bool { - return rw.cnchan -} - -func TestWaitForEventChan(t *testing.T) { - ctx := context.Background() - ec := make(chan *store.Event) - dw := &dummyWatcher{ - echan: ec, - } - w := httptest.NewRecorder() - var wg sync.WaitGroup - var ev *store.Event - var err error - wg.Add(1) - go func() { - ev, err = waitForEvent(ctx, w, dw) - wg.Done() - }() - ec <- &store.Event{ - Action: store.Get, - Node: &store.NodeExtern{ - Key: "/foo/bar", - ModifiedIndex: 12345, - }, - } - wg.Wait() - want := &store.Event{ - Action: store.Get, - Node: &store.NodeExtern{ - Key: "/foo/bar", - ModifiedIndex: 12345, - }, - } - if !reflect.DeepEqual(ev, want) { - t.Fatalf("bad event: got %#v, want %#v", ev, want) - } - if err != nil { - t.Fatalf("unexpected error: %v", err) - } -} - -func TestWaitForEventCloseNotify(t *testing.T) { - ctx := context.Background() - dw := &dummyWatcher{} - cnchan := make(chan bool) - w := &dummyResponseWriter{ - cnchan: cnchan, - } - var wg sync.WaitGroup - var ev *store.Event - var err error - wg.Add(1) - go func() { - ev, err = waitForEvent(ctx, w, dw) - wg.Done() - }() - close(cnchan) - wg.Wait() - if ev != nil { - t.Fatalf("non-nil Event returned with CloseNotifier: %v", ev) - } - if err == nil { - t.Fatalf("nil err returned with CloseNotifier!") - } -} - -func TestWaitForEventCancelledContext(t *testing.T) { - cctx, cancel := context.WithCancel(context.Background()) - dw := &dummyWatcher{} - w := httptest.NewRecorder() - var wg sync.WaitGroup - var ev *store.Event - var err error - wg.Add(1) - go func() { - ev, err = waitForEvent(cctx, w, dw) - wg.Done() - }() - cancel() - wg.Wait() - if ev != nil { - t.Fatalf("non-nil Event returned with cancelled context: %v", ev) - } - if err == nil { - t.Fatalf("nil err returned with cancelled context!") - } -} - func TestV2MachinesEndpoint(t *testing.T) { tests := []struct { method string @@ -950,17 +869,6 @@ func TestBadServeKeys(t *testing.T) { http.StatusInternalServerError, }, - { - // timeout waiting for event (watcher never returns) - mustNewRequest(t, "foo"), - &resServer{ - etcdserver.Response{ - Watcher: &dummyWatcher{}, - }, - }, - - http.StatusGatewayTimeout, - }, { // non-event/watcher response from etcdserver.Server mustNewRequest(t, "foo"), @@ -1056,6 +964,7 @@ func TestServeKeysWatch(t *testing.T) { Node: &store.NodeExtern{}, }, ) + wbody = fmt.Sprintf("%x\r\n%s\r\n", len(wbody), wbody) if rw.Code != wcode { t.Errorf("got code=%d, want %d", rw.Code, wcode) diff --git a/etcdserver/etcdserverpb/etcdserver.pb.go b/etcdserver/etcdserverpb/etcdserver.pb.go index a4176a31f..e81b21508 100644 --- a/etcdserver/etcdserverpb/etcdserver.pb.go +++ b/etcdserver/etcdserverpb/etcdserver.pb.go @@ -43,6 +43,7 @@ type Request struct { Sorted bool `protobuf:"varint,13,req,name=sorted" json:"sorted"` Quorum bool `protobuf:"varint,14,req,name=quorum" json:"quorum"` Time int64 `protobuf:"varint,15,req,name=time" json:"time"` + Stream bool `protobuf:"varint,16,req,name=stream" json:"stream"` XXX_unrecognized []byte `json:"-"` } @@ -337,6 +338,23 @@ func (m *Request) Unmarshal(data []byte) error { break } } + case 16: + if wireType != 0 { + return code_google_com_p_gogoprotobuf_proto.ErrWrongType + } + var v int + for shift := uint(0); ; shift += 7 { + if index >= l { + return io.ErrUnexpectedEOF + } + b := data[index] + index++ + v |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Stream = bool(v != 0) default: var sizeOfWire int for { @@ -384,6 +402,7 @@ func (m *Request) Size() (n int) { n += 2 n += 2 n += 1 + sovEtcdserver(uint64(m.Time)) + n += 3 if m.XXX_unrecognized != nil { n += len(m.XXX_unrecognized) } @@ -499,6 +518,16 @@ func (m *Request) MarshalTo(data []byte) (n int, err error) { data[i] = 0x78 i++ i = encodeVarintEtcdserver(data, i, uint64(m.Time)) + data[i] = 0x80 + i++ + data[i] = 0x1 + i++ + if m.Stream { + data[i] = 1 + } else { + data[i] = 0 + } + i++ if m.XXX_unrecognized != nil { i += copy(data[i:], m.XXX_unrecognized) } diff --git a/etcdserver/etcdserverpb/etcdserver.proto b/etcdserver/etcdserverpb/etcdserver.proto index bb34ac250..a0b98bb0f 100644 --- a/etcdserver/etcdserverpb/etcdserver.proto +++ b/etcdserver/etcdserverpb/etcdserver.proto @@ -23,4 +23,5 @@ message Request { required bool sorted = 13 [(gogoproto.nullable) = false]; required bool quorum = 14 [(gogoproto.nullable) = false]; required int64 time = 15 [(gogoproto.nullable) = false]; + required bool stream = 16 [(gogoproto.nullable) = false]; } diff --git a/etcdserver/server.go b/etcdserver/server.go index b0f16a1c9..4aeb2c3f0 100644 --- a/etcdserver/server.go +++ b/etcdserver/server.go @@ -213,7 +213,7 @@ func (s *EtcdServer) Do(ctx context.Context, r pb.Request) (Response, error) { case "GET": switch { case r.Wait: - wc, err := s.Store.Watch(r.Path, r.Recursive, false, r.Since) + wc, err := s.Store.Watch(r.Path, r.Recursive, r.Stream, r.Since) if err != nil { return Response{}, err }