diff --git a/mvcc/watchable_store.go b/mvcc/watchable_store.go index b9664bd81..2d8545942 100644 --- a/mvcc/watchable_store.go +++ b/mvcc/watchable_store.go @@ -144,7 +144,6 @@ func (s *watchableStore) watch(key, end []byte, startRev int64, id WatchID, ch c func (s *watchableStore) cancelWatcher(wa *watcher) { for { s.mu.Lock() - if s.unsynced.delete(wa) { slowWatcherGauge.Dec() break @@ -152,6 +151,9 @@ func (s *watchableStore) cancelWatcher(wa *watcher) { break } else if wa.compacted { break + } else if wa.ch == nil { + // already canceled (e.g., cancel/close race) + break } if !wa.victim { @@ -177,6 +179,7 @@ func (s *watchableStore) cancelWatcher(wa *watcher) { } watcherGauge.Dec() + wa.ch = nil s.mu.Unlock() } @@ -425,7 +428,6 @@ func (s *watchableStore) notify(rev int64, evs []mvccpb.Event) { if eb.revs != 1 { plog.Panicf("unexpected multiple revisions in notification") } - if w.send(WatchResponse{WatchID: w.id, Events: eb.evs, Revision: rev}) { pendingEventsGauge.Add(float64(len(eb.evs))) } else { diff --git a/mvcc/watchable_store_test.go b/mvcc/watchable_store_test.go index 93c7cc954..52e1b90c0 100644 --- a/mvcc/watchable_store_test.go +++ b/mvcc/watchable_store_test.go @@ -539,3 +539,49 @@ func TestWatchVictims(t *testing.T) { default: } } + +// TestStressWatchCancelClose tests closing a watch stream while +// canceling its watches. +func TestStressWatchCancelClose(t *testing.T) { + b, tmpPath := backend.NewDefaultTmpBackend() + s := newWatchableStore(b, &lease.FakeLessor{}, nil) + + defer func() { + s.store.Close() + os.Remove(tmpPath) + }() + + testKey, testValue := []byte("foo"), []byte("bar") + var wg sync.WaitGroup + readyc := make(chan struct{}) + wg.Add(100) + for i := 0; i < 100; i++ { + go func() { + defer wg.Done() + w := s.NewWatchStream() + ids := make([]WatchID, 10) + for i := range ids { + ids[i] = w.Watch(testKey, nil, 0) + } + <-readyc + wg.Add(1 + len(ids)/2) + for i := range ids[:len(ids)/2] { + go func(n int) { + defer wg.Done() + w.Cancel(ids[n]) + }(i) + } + go func() { + defer wg.Done() + w.Close() + }() + }() + } + + close(readyc) + for i := 0; i < 100; i++ { + s.Put(testKey, testValue, lease.NoLease) + } + + wg.Wait() +} diff --git a/mvcc/watcher.go b/mvcc/watcher.go index 9468d4269..bc0c6322f 100644 --- a/mvcc/watcher.go +++ b/mvcc/watcher.go @@ -129,16 +129,25 @@ func (ws *watchStream) Chan() <-chan WatchResponse { func (ws *watchStream) Cancel(id WatchID) error { ws.mu.Lock() cancel, ok := ws.cancels[id] + w := ws.watchers[id] ok = ok && !ws.closed - if ok { - delete(ws.cancels, id) - delete(ws.watchers, id) - } ws.mu.Unlock() + if !ok { return ErrWatcherNotExist } cancel() + + ws.mu.Lock() + // The watch isn't removed until cancel so that if Close() is called, + // it will wait for the cancel. Otherwise, Close() could close the + // watch channel while the store is still posting events. + if ww := ws.watchers[id]; ww == w { + delete(ws.cancels, id) + delete(ws.watchers, id) + } + ws.mu.Unlock() + return nil }