diff --git a/file_system/file_system.go b/file_system/file_system.go index eb19e2ea8..43f850275 100644 --- a/file_system/file_system.go +++ b/file_system/file_system.go @@ -198,12 +198,6 @@ func (fs *FileSystem) Delete(keyPath string, recurisive bool, index uint64, term return nil, err } - err = n.Remove(recurisive) - - if err != nil { - return nil, err - } - e := newEvent(Delete, keyPath, index, term) if n.IsDir() { @@ -212,6 +206,18 @@ func (fs *FileSystem) Delete(keyPath string, recurisive bool, index uint64, term e.PrevValue = n.Value } + callback := func(path string) { + fs.WatcherHub.notifyWithPath(e, path, true) + } + + err = n.Remove(recurisive, callback) + + if err != nil { + return nil, err + } + + fs.WatcherHub.notify(e) + return e, nil } diff --git a/file_system/file_system_test.go b/file_system/file_system_test.go index 04f5ad9b6..16666f614 100644 --- a/file_system/file_system_test.go +++ b/file_system/file_system_test.go @@ -230,6 +230,29 @@ func TestTestAndSet(t *testing.T) { } } +func TestWatchRemove(t *testing.T) { + fs := New() + fs.Create("/foo/foo/foo", "bar", Permanent, 1, 1) + + // watch at a deeper path + c, _ := fs.WatcherHub.watch("/foo/foo/foo", false, 0) + fs.Delete("/foo", true, 2, 1) + e := <-c + if e.Key != "/foo" { + t.Fatal("watch for delete fails") + } + + fs.Create("/foo/foo/foo", "bar", Permanent, 3, 1) + // watch at a prefix + c, _ = fs.WatcherHub.watch("/foo", true, 0) + fs.Delete("/foo/foo/foo", false, 4, 1) + e = <-c + if e.Key != "/foo/foo/foo" { + t.Fatal("watch for delete fails") + } + +} + func createAndGet(fs *FileSystem, path string, t *testing.T) { _, err := fs.Create(path, "bar", Permanent, 1, 1) diff --git a/file_system/node.go b/file_system/node.go index 4d933f55a..761266768 100644 --- a/file_system/node.go +++ b/file_system/node.go @@ -65,7 +65,7 @@ func newDir(keyPath string, createIndex uint64, createTerm uint64, parent *Node, // Remove function remove the node. // If the node is a directory and recursive is true, the function will recursively remove // add nodes under the receiver node. -func (n *Node) Remove(recursive bool) error { +func (n *Node) Remove(recursive bool, callback func(path string)) error { n.mu.Lock() defer n.mu.Unlock() @@ -80,6 +80,11 @@ func (n *Node) Remove(recursive bool) error { // This is the only pointer to Node object // Handled by garbage collector delete(n.Parent.Children, name) + + if callback != nil { + callback(n.Path) + } + n.stopExpire <- true n.status = removed } @@ -92,13 +97,18 @@ func (n *Node) Remove(recursive bool) error { } for _, child := range n.Children { // delete all children - child.Remove(true) + child.Remove(true, callback) } // delete self _, name := path.Split(n.Path) if n.Parent.Children[name] == n { delete(n.Parent.Children, name) + + if callback != nil { + callback(n.Path) + } + n.stopExpire <- true n.status = removed } @@ -235,14 +245,14 @@ func (n *Node) IsDir() bool { func (n *Node) Expire() { duration := n.ExpireTime.Sub(time.Now()) if duration <= 0 { - n.Remove(true) + n.Remove(true, nil) return } select { // if timeout, delete the node case <-time.After(duration): - n.Remove(true) + n.Remove(true, nil) return // if stopped, return diff --git a/file_system/watcher.go b/file_system/watcher.go index 7cad8b15b..c17d0eb87 100644 --- a/file_system/watcher.go +++ b/file_system/watcher.go @@ -28,18 +28,18 @@ func newWatchHub(capacity int) *watcherHub { // If recursive is true, the first change after index under prefix will be sent to the event channel. // If recursive is false, the first change after index at prefix will be sent to the event channel. // If index is zero, watch will start from the current index + 1. -func (wh *watcherHub) watch(prefix string, recursive bool, index uint64) (error, <-chan *Event) { +func (wh *watcherHub) watch(prefix string, recursive bool, index uint64) (<-chan *Event, error) { eventChan := make(chan *Event, 1) e, err := wh.EventHistory.scan(prefix, index) if err != nil { - return err, nil + return nil, err } if e != nil { eventChan <- e - return nil, eventChan + return eventChan, nil } w := &watcher{ @@ -58,57 +58,55 @@ func (wh *watcherHub) watch(prefix string, recursive bool, index uint64) (error, wh.watchers[prefix] = l } - return nil, eventChan + return eventChan, nil +} + +func (wh *watcherHub) notifyWithPath(e *Event, path string, force bool) { + l, ok := wh.watchers[path] + + if ok { + + curr := l.Front() + notifiedAll := true + + for { + + if curr == nil { // we have reached the end of the list + + if notifiedAll { + // if we have notified all watcher in the list + // we can delete the list + delete(wh.watchers, path) + } + break + } + + next := curr.Next() // save the next + + w, _ := curr.Value.(*watcher) + + if w.recursive || force || e.Key == path { + w.eventChan <- e + l.Remove(curr) + } else { + notifiedAll = false + } + + curr = next // go to the next one + + } + } } func (wh *watcherHub) notify(e *Event) { segments := strings.Split(e.Key, "/") + currPath := "/" // walk through all the paths for _, segment := range segments { currPath = path.Join(currPath, segment) - - l, ok := wh.watchers[currPath] - - if ok { - - curr := l.Front() - notifiedAll := true - - for { - - if curr == nil { // we have reached the end of the list - - if notifiedAll { - // if we have notified all watcher in the list - // we can delete the list - delete(wh.watchers, currPath) - } - break - } - - next := curr.Next() // save the next - - w, _ := curr.Value.(*watcher) - - if w.recursive { - w.eventChan <- e - l.Remove(curr) - } else { - if e.Key == currPath { // only notify the same path - w.eventChan <- e - l.Remove(curr) - } else { // we do not notify all watcher in the list - notifiedAll = false - } - } - - curr = next // go to the next one - - } - } - + wh.notifyWithPath(e, currPath, false) } } diff --git a/file_system/watcher_test.go b/file_system/watcher_test.go index c63a489d7..b817e64ec 100644 --- a/file_system/watcher_test.go +++ b/file_system/watcher_test.go @@ -6,7 +6,7 @@ import ( func TestWatch(t *testing.T) { wh := newWatchHub(100) - err, c := wh.watch("/foo", true, 0) + c, err := wh.watch("/foo", true, 0) if err != nil { t.Fatal("%v", err) @@ -29,7 +29,7 @@ func TestWatch(t *testing.T) { t.Fatal("recv != send") } - _, c = wh.watch("/foo", false, 0) + c, _ = wh.watch("/foo", false, 0) e = newEvent(Set, "/foo/bar", 1, 0)