From f67bdc2eedacd1cd59e62cbee8f52d0529b81fad Mon Sep 17 00:00:00 2001 From: Anthony Romano Date: Mon, 3 Apr 2017 13:34:13 -0700 Subject: [PATCH] *: support checking that an interval tree's keys cover an entire interval --- auth/range_perm_cache.go | 46 ++++---------------- mvcc/watcher_group.go | 2 +- pkg/adt/interval_tree.go | 28 +++++++++++- pkg/adt/interval_tree_test.go | 80 +++++++++++++++++++++++++++++++---- 4 files changed, 106 insertions(+), 50 deletions(-) diff --git a/auth/range_perm_cache.go b/auth/range_perm_cache.go index f69807edd..8f4d9f5db 100644 --- a/auth/range_perm_cache.go +++ b/auth/range_perm_cache.go @@ -66,59 +66,29 @@ func getMergedPerms(tx backend.BatchTx, userName string) *unifiedRangePermission } func checkKeyInterval(cachedPerms *unifiedRangePermissions, key, rangeEnd string, permtyp authpb.Permission_Type) bool { - var tocheck *adt.IntervalTree - + ivl := adt.NewStringInterval(key, rangeEnd) switch permtyp { case authpb.READ: - tocheck = cachedPerms.readPerms + return cachedPerms.readPerms.Contains(ivl) case authpb.WRITE: - tocheck = cachedPerms.writePerms + return cachedPerms.writePerms.Contains(ivl) default: plog.Panicf("unknown auth type: %v", permtyp) } - - ivl := adt.NewStringInterval(key, rangeEnd) - - isContiguous := true - var maxEnd, minBegin adt.Comparable - - tocheck.Visit(ivl, func(n *adt.IntervalValue) bool { - if minBegin == nil { - minBegin = n.Ivl.Begin - maxEnd = n.Ivl.End - return true - } - - if maxEnd.Compare(n.Ivl.Begin) < 0 { - isContiguous = false - return false - } - - if n.Ivl.End.Compare(maxEnd) > 0 { - maxEnd = n.Ivl.End - } - - return true - }) - - return isContiguous && maxEnd.Compare(ivl.End) >= 0 && minBegin.Compare(ivl.Begin) <= 0 + return false } func checkKeyPoint(cachedPerms *unifiedRangePermissions, key string, permtyp authpb.Permission_Type) bool { - var tocheck *adt.IntervalTree - + pt := adt.NewStringPoint(key) switch permtyp { case authpb.READ: - tocheck = cachedPerms.readPerms + return cachedPerms.readPerms.Intersects(pt) case authpb.WRITE: - tocheck = cachedPerms.writePerms + return cachedPerms.writePerms.Intersects(pt) default: plog.Panicf("unknown auth type: %v", permtyp) } - - pt := adt.NewStringPoint(key) - - return tocheck.Contains(pt) + return false } func (as *authStore) isRangeOpPermitted(tx backend.BatchTx, userName string, key, rangeEnd string, permtyp authpb.Permission_Type) bool { diff --git a/mvcc/watcher_group.go b/mvcc/watcher_group.go index 2710c1cc9..6ef1d0ce8 100644 --- a/mvcc/watcher_group.go +++ b/mvcc/watcher_group.go @@ -183,7 +183,7 @@ func (wg *watcherGroup) add(wa *watcher) { // contains is whether the given key has a watcher in the group. func (wg *watcherGroup) contains(key string) bool { _, ok := wg.keyWatchers[key] - return ok || wg.ranges.Contains(adt.NewStringAffinePoint(key)) + return ok || wg.ranges.Intersects(adt.NewStringAffinePoint(key)) } // size gives the number of unique watchers in the group. diff --git a/pkg/adt/interval_tree.go b/pkg/adt/interval_tree.go index a93625045..9c5afb3f0 100644 --- a/pkg/adt/interval_tree.go +++ b/pkg/adt/interval_tree.go @@ -437,8 +437,8 @@ func (ivt *IntervalTree) Find(ivl Interval) (ret *IntervalValue) { return &n.iv } -// Contains returns true if there is some tree node intersecting the given interval. -func (ivt *IntervalTree) Contains(iv Interval) bool { +// Intersects returns true if there is some tree node intersecting the given interval. +func (ivt *IntervalTree) Intersects(iv Interval) bool { x := ivt.root for x != nil && iv.Compare(&x.iv.Ivl) != 0 { if x.left != nil && x.left.max.Compare(iv.Begin) > 0 { @@ -450,6 +450,30 @@ func (ivt *IntervalTree) Contains(iv Interval) bool { return x != nil } +// Contains returns true if the interval tree's keys cover the entire given interval. +func (ivt *IntervalTree) Contains(ivl Interval) bool { + var maxEnd, minBegin Comparable + + isContiguous := true + ivt.Visit(ivl, func(n *IntervalValue) bool { + if minBegin == nil { + minBegin = n.Ivl.Begin + maxEnd = n.Ivl.End + return true + } + if maxEnd.Compare(n.Ivl.Begin) < 0 { + isContiguous = false + return false + } + if n.Ivl.End.Compare(maxEnd) > 0 { + maxEnd = n.Ivl.End + } + return true + }) + + return isContiguous && minBegin != nil && maxEnd.Compare(ivl.End) >= 0 && minBegin.Compare(ivl.Begin) <= 0 +} + // Stab returns a slice with all elements in the tree intersecting the interval. func (ivt *IntervalTree) Stab(iv Interval) (ivs []*IntervalValue) { if ivt.count == 0 { diff --git a/pkg/adt/interval_tree_test.go b/pkg/adt/interval_tree_test.go index f9770b3a5..493c11fa0 100644 --- a/pkg/adt/interval_tree_test.go +++ b/pkg/adt/interval_tree_test.go @@ -20,23 +20,23 @@ import ( "time" ) -func TestIntervalTreeContains(t *testing.T) { +func TestIntervalTreeIntersects(t *testing.T) { ivt := &IntervalTree{} ivt.Insert(NewStringInterval("1", "3"), 123) - if ivt.Contains(NewStringPoint("0")) { + if ivt.Intersects(NewStringPoint("0")) { t.Errorf("contains 0") } - if !ivt.Contains(NewStringPoint("1")) { + if !ivt.Intersects(NewStringPoint("1")) { t.Errorf("missing 1") } - if !ivt.Contains(NewStringPoint("11")) { + if !ivt.Intersects(NewStringPoint("11")) { t.Errorf("missing 11") } - if !ivt.Contains(NewStringPoint("2")) { + if !ivt.Intersects(NewStringPoint("2")) { t.Errorf("missing 2") } - if ivt.Contains(NewStringPoint("3")) { + if ivt.Intersects(NewStringPoint("3")) { t.Errorf("contains 3") } } @@ -44,10 +44,10 @@ func TestIntervalTreeContains(t *testing.T) { func TestIntervalTreeStringAffine(t *testing.T) { ivt := &IntervalTree{} ivt.Insert(NewStringAffineInterval("8", ""), 123) - if !ivt.Contains(NewStringAffinePoint("9")) { + if !ivt.Intersects(NewStringAffinePoint("9")) { t.Errorf("missing 9") } - if ivt.Contains(NewStringAffinePoint("7")) { + if ivt.Intersects(NewStringAffinePoint("7")) { t.Errorf("contains 7") } } @@ -122,7 +122,7 @@ func TestIntervalTreeRandom(t *testing.T) { if slen := len(ivt.Stab(NewInt64Point(v))); slen == 0 { t.Fatalf("expected %v stab non-zero for [%+v)", v, xy) } - if !ivt.Contains(NewInt64Point(v)) { + if !ivt.Intersects(NewInt64Point(v)) { t.Fatalf("did not get %d as expected for [%+v)", v, xy) } } @@ -231,3 +231,65 @@ func TestIntervalTreeVisitExit(t *testing.T) { } } } + +// TestIntervalTreeContains tests that contains returns true iff the ivt maps the entire interval. +func TestIntervalTreeContains(t *testing.T) { + tests := []struct { + ivls []Interval + chkIvl Interval + + wContains bool + }{ + { + ivls: []Interval{NewInt64Interval(1, 10)}, + chkIvl: NewInt64Interval(0, 100), + + wContains: false, + }, + { + ivls: []Interval{NewInt64Interval(1, 10)}, + chkIvl: NewInt64Interval(1, 10), + + wContains: true, + }, + { + ivls: []Interval{NewInt64Interval(1, 10)}, + chkIvl: NewInt64Interval(2, 8), + + wContains: true, + }, + { + ivls: []Interval{NewInt64Interval(1, 5), NewInt64Interval(6, 10)}, + chkIvl: NewInt64Interval(1, 10), + + wContains: false, + }, + { + ivls: []Interval{NewInt64Interval(1, 5), NewInt64Interval(3, 10)}, + chkIvl: NewInt64Interval(1, 10), + + wContains: true, + }, + { + ivls: []Interval{NewInt64Interval(1, 4), NewInt64Interval(4, 7), NewInt64Interval(3, 10)}, + chkIvl: NewInt64Interval(1, 10), + + wContains: true, + }, + { + ivls: []Interval{}, + chkIvl: NewInt64Interval(1, 10), + + wContains: false, + }, + } + for i, tt := range tests { + ivt := &IntervalTree{} + for _, ivl := range tt.ivls { + ivt.Insert(ivl, struct{}{}) + } + if v := ivt.Contains(tt.chkIvl); v != tt.wContains { + t.Errorf("#%d: ivt.Contains got %v, expected %v", i, v, tt.wContains) + } + } +}