From 1aeeb38459f496288b2b3c1bfe44941bb70befcf Mon Sep 17 00:00:00 2001 From: Anthony Romano Date: Mon, 7 Nov 2016 13:03:47 -0800 Subject: [PATCH] clientv3: let watchers cancel when reconnecting --- clientv3/watch.go | 98 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 70 insertions(+), 28 deletions(-) diff --git a/clientv3/watch.go b/clientv3/watch.go index 2b0e657ca..08f7d56da 100644 --- a/clientv3/watch.go +++ b/clientv3/watch.go @@ -126,8 +126,6 @@ type watchGrpcStream struct { reqc chan *watchRequest // respc receives data from the watch client respc chan *pb.WatchResponse - // stopc is sent to the main goroutine to stop all processing - stopc chan struct{} // donec closes to broadcast shutdown donec chan struct{} // errc transmits errors from grpc Recv to the watch stream reconn logic @@ -213,7 +211,6 @@ func (w *watcher) newWatcherGrpcStream(inctx context.Context) *watchGrpcStream { respc: make(chan *pb.WatchResponse), reqc: make(chan *watchRequest), - stopc: make(chan struct{}), donec: make(chan struct{}), errc: make(chan error, 1), closingc: make(chan *watcherStream), @@ -319,7 +316,7 @@ func (w *watcher) Close() (err error) { } func (w *watchGrpcStream) Close() (err error) { - close(w.stopc) + w.cancel() <-w.donec select { case err = <-w.errc: @@ -366,7 +363,7 @@ func (w *watchGrpcStream) closeSubstream(ws *watcherStream) { // close subscriber's channel if closeErr := w.closeErr; closeErr != nil && ws.initReq.ctx.Err() == nil { go w.sendCloseSubstream(ws, &WatchResponse{closeErr: w.closeErr}) - } else { + } else if ws.outc != nil { close(ws.outc) } if ws.id != -1 { @@ -493,7 +490,7 @@ func (w *watchGrpcStream) run() { wc.Send(ws.initReq.toPB()) } cancelSet = make(map[int64]struct{}) - case <-w.stopc: + case <-w.ctx.Done(): return case ws := <-w.closingc: w.closeSubstream(ws) @@ -632,7 +629,8 @@ func (w *watchGrpcStream) serveSubstream(ws *watcherStream, resumec chan struct{ // TODO pause channel if buffer gets too large ws.buf = append(ws.buf, wr) - + case <-w.ctx.Done(): + return case <-ws.initReq.ctx.Done(): return case <-resumec: @@ -644,34 +642,78 @@ func (w *watchGrpcStream) serveSubstream(ws *watcherStream, resumec chan struct{ } func (w *watchGrpcStream) newWatchClient() (pb.Watch_WatchClient, error) { - // connect to grpc stream + // mark all substreams as resuming + close(w.resumec) + w.resumec = make(chan struct{}) + w.joinSubstreams() + for _, ws := range w.substreams { + ws.id = -1 + w.resuming = append(w.resuming, ws) + } + // strip out nils, if any + var resuming []*watcherStream + for _, ws := range w.resuming { + if ws != nil { + resuming = append(resuming, ws) + } + } + w.resuming = resuming + w.substreams = make(map[int64]*watcherStream) + + // connect to grpc stream while accepting watcher cancelation + stopc := make(chan struct{}) + donec := w.waitCancelSubstreams(stopc) wc, err := w.openWatchClient() + close(stopc) + <-donec + + // serve all non-closing streams, even if there's a client error + // so that the teardown path can shutdown the streams as expected. + for _, ws := range w.resuming { + if ws.closing { + continue + } + ws.donec = make(chan struct{}) + go w.serveSubstream(ws, w.resumec) + } + if err != nil { return nil, v3rpc.Error(err) } - // mark all substreams as resuming - if len(w.substreams)+len(w.resuming) > 0 { - close(w.resumec) - w.resumec = make(chan struct{}) - w.joinSubstreams() - for _, ws := range w.substreams { - ws.id = -1 - w.resuming = append(w.resuming, ws) - } - for _, ws := range w.resuming { - if ws == nil || ws.closing { - continue - } - ws.donec = make(chan struct{}) - go w.serveSubstream(ws, w.resumec) - } - } - w.substreams = make(map[int64]*watcherStream) + // receive data from new grpc stream go w.serveWatchClient(wc) return wc, nil } +func (w *watchGrpcStream) waitCancelSubstreams(stopc <-chan struct{}) <-chan struct{} { + var wg sync.WaitGroup + wg.Add(len(w.resuming)) + donec := make(chan struct{}) + for i := range w.resuming { + go func(ws *watcherStream) { + defer wg.Done() + if ws.closing { + return + } + select { + case <-ws.initReq.ctx.Done(): + // closed ws will be removed from resuming + ws.closing = true + close(ws.outc) + ws.outc = nil + go func() { w.closingc <- ws }() + case <-stopc: + } + }(w.resuming[i]) + } + go func() { + defer close(donec) + wg.Wait() + }() + return donec +} + // joinSubstream waits for all substream goroutines to complete func (w *watchGrpcStream) joinSubstreams() { for _, ws := range w.substreams { @@ -688,9 +730,9 @@ func (w *watchGrpcStream) joinSubstreams() { func (w *watchGrpcStream) openWatchClient() (ws pb.Watch_WatchClient, err error) { for { select { - case <-w.stopc: + case <-w.ctx.Done(): if err == nil { - return nil, context.Canceled + return nil, w.ctx.Err() } return nil, err default: