grpcproxy: shut down watcher proxy when client context is done

This commit is contained in:
Anthony Romano 2016-09-01 15:20:50 -07:00
parent 26999db927
commit d3ecebd14e
2 changed files with 25 additions and 12 deletions

View File

@ -31,16 +31,25 @@ type watchProxy struct {
mu sync.Mutex mu sync.Mutex
nextStreamID int64 nextStreamID int64
ctx context.Context
} }
func NewWatchProxy(c *clientv3.Client) pb.WatchServer { func NewWatchProxy(c *clientv3.Client) pb.WatchServer {
return &watchProxy{ wp := &watchProxy{
cw: c.Watcher, cw: c.Watcher,
wgs: watchergroups{ wgs: watchergroups{
cw: c.Watcher, cw: c.Watcher,
groups: make(map[watchRange]*watcherGroup), groups: make(map[watchRange]*watcherGroup),
proxyCtx: c.Ctx(),
}, },
ctx: c.Ctx(),
} }
go func() {
<-wp.ctx.Done()
wp.wgs.stop()
}()
return wp
} }
func (wp *watchProxy) Watch(stream pb.Watch_WatchServer) (err error) { func (wp *watchProxy) Watch(stream pb.Watch_WatchServer) (err error) {
@ -58,13 +67,13 @@ func (wp *watchProxy) Watch(stream pb.Watch_WatchServer) (err error) {
ctrlCh: make(chan *pb.WatchResponse, 10), ctrlCh: make(chan *pb.WatchResponse, 10),
watchCh: make(chan *pb.WatchResponse, 10), watchCh: make(chan *pb.WatchResponse, 10),
proxyCtx: wp.ctx,
} }
go sws.recvLoop() go sws.recvLoop()
sws.sendLoop() sws.sendLoop()
return wp.ctx.Err()
return nil
} }
type serverWatchStream struct { type serverWatchStream struct {
@ -81,6 +90,8 @@ type serverWatchStream struct {
watchCh chan *pb.WatchResponse watchCh chan *pb.WatchResponse
nextWatcherID int64 nextWatcherID int64
proxyCtx context.Context
} }
func (sws *serverWatchStream) close() { func (sws *serverWatchStream) close() {
@ -89,8 +100,8 @@ func (sws *serverWatchStream) close() {
var wg sync.WaitGroup var wg sync.WaitGroup
sws.mu.Lock() sws.mu.Lock()
wg.Add(len(sws.singles))
for _, ws := range sws.singles { for _, ws := range sws.singles {
wg.Add(1)
ws.stop() ws.stop()
// copy the range variable to avoid race // copy the range variable to avoid race
copyws := ws copyws := ws
@ -100,10 +111,7 @@ func (sws *serverWatchStream) close() {
}() }()
} }
sws.mu.Unlock() sws.mu.Unlock()
wg.Wait() wg.Wait()
sws.groups.stop()
} }
func (sws *serverWatchStream) recvLoop() error { func (sws *serverWatchStream) recvLoop() error {
@ -166,6 +174,8 @@ func (sws *serverWatchStream) sendLoop() {
if err := sws.gRPCStream.Send(c); err != nil { if err := sws.gRPCStream.Send(c); err != nil {
return return
} }
case <-sws.proxyCtx.Done():
return
} }
} }
} }
@ -182,7 +192,7 @@ func (sws *serverWatchStream) addDedicatedWatcher(w watcher, rev int64) {
sws.mu.Lock() sws.mu.Lock()
defer sws.mu.Unlock() defer sws.mu.Unlock()
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(sws.proxyCtx)
wch := sws.cw.Watch(ctx, wch := sws.cw.Watch(ctx,
w.wr.key, clientv3.WithRange(w.wr.end), w.wr.key, clientv3.WithRange(w.wr.end),

View File

@ -27,6 +27,8 @@ type watchergroups struct {
mu sync.Mutex mu sync.Mutex
groups map[watchRange]*watcherGroup groups map[watchRange]*watcherGroup
idToGroup map[receiverID]*watcherGroup idToGroup map[receiverID]*watcherGroup
proxyCtx context.Context
} }
func (wgs *watchergroups) addWatcher(rid receiverID, w watcher) { func (wgs *watchergroups) addWatcher(rid receiverID, w watcher) {
@ -40,7 +42,7 @@ func (wgs *watchergroups) addWatcher(rid receiverID, w watcher) {
return return
} }
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(wgs.proxyCtx)
wch := wgs.cw.Watch(ctx, w.wr.key, wch := wgs.cw.Watch(ctx, w.wr.key,
clientv3.WithRange(w.wr.end), clientv3.WithRange(w.wr.end),
@ -98,4 +100,5 @@ func (wgs *watchergroups) stop() {
for _, wg := range wgs.groups { for _, wg := range wgs.groups {
wg.stop() wg.stop()
} }
wgs.groups = nil
} }