From 25e3ce1febbfe793b468c95b4e282dbfacac0ece Mon Sep 17 00:00:00 2001 From: Anthony Romano Date: Wed, 22 Mar 2017 22:27:05 -0700 Subject: [PATCH] adt: Visit() interval trees in sorted order and terminate early For all intervals [x, y), Visit will visit intervals in ascending order sorted by x. Also fixes a bug where Visit would not terminate the search when requested by the visitor function. --- pkg/adt/interval_tree.go | 21 +++++--- pkg/adt/interval_tree_test.go | 95 +++++++++++++++++++++++++++++++++++ 2 files changed, 108 insertions(+), 8 deletions(-) diff --git a/pkg/adt/interval_tree.go b/pkg/adt/interval_tree.go index afddcc7d6..a93625045 100644 --- a/pkg/adt/interval_tree.go +++ b/pkg/adt/interval_tree.go @@ -134,25 +134,29 @@ func (x *intervalNode) updateMax() { type nodeVisitor func(n *intervalNode) bool // visit will call a node visitor on each node that overlaps the given interval -func (x *intervalNode) visit(iv *Interval, nv nodeVisitor) { +func (x *intervalNode) visit(iv *Interval, nv nodeVisitor) bool { if x == nil { - return + return true } v := iv.Compare(&x.iv.Ivl) switch { case v < 0: - x.left.visit(iv, nv) + if !x.left.visit(iv, nv) { + return false + } case v > 0: maxiv := Interval{x.iv.Ivl.Begin, x.max} if maxiv.Compare(iv) == 0 { - x.left.visit(iv, nv) - x.right.visit(iv, nv) + if !x.left.visit(iv, nv) || !x.right.visit(iv, nv) { + return false + } } default: - nv(x) - x.left.visit(iv, nv) - x.right.visit(iv, nv) + if !x.left.visit(iv, nv) || !nv(x) || !x.right.visit(iv, nv) { + return false + } } + return true } type IntervalValue struct { @@ -406,6 +410,7 @@ func (ivt *IntervalTree) MaxHeight() int { type IntervalVisitor func(n *IntervalValue) bool // Visit calls a visitor function on every tree node intersecting the given interval. +// It will visit each interval [x, y) in ascending order sorted on x. func (ivt *IntervalTree) Visit(ivl Interval, ivv IntervalVisitor) { ivt.root.visit(&ivl, func(n *intervalNode) bool { return ivv(&n.iv) }) } diff --git a/pkg/adt/interval_tree_test.go b/pkg/adt/interval_tree_test.go index f8d038b72..f9770b3a5 100644 --- a/pkg/adt/interval_tree_test.go +++ b/pkg/adt/interval_tree_test.go @@ -136,3 +136,98 @@ func TestIntervalTreeRandom(t *testing.T) { t.Errorf("got ivt.Len() = %v, expected 0", ivt.Len()) } } + +// TestIntervalTreeSortedVisit tests that intervals are visited in sorted order. +func TestIntervalTreeSortedVisit(t *testing.T) { + tests := []struct { + ivls []Interval + visitRange Interval + }{ + { + ivls: []Interval{NewInt64Interval(1, 10), NewInt64Interval(2, 5), NewInt64Interval(3, 6)}, + visitRange: NewInt64Interval(0, 100), + }, + { + ivls: []Interval{NewInt64Interval(1, 10), NewInt64Interval(10, 12), NewInt64Interval(3, 6)}, + visitRange: NewInt64Interval(0, 100), + }, + { + ivls: []Interval{NewInt64Interval(2, 3), NewInt64Interval(3, 4), NewInt64Interval(6, 7), NewInt64Interval(5, 6)}, + visitRange: NewInt64Interval(0, 100), + }, + { + ivls: []Interval{ + NewInt64Interval(2, 3), + NewInt64Interval(2, 4), + NewInt64Interval(3, 7), + NewInt64Interval(2, 5), + NewInt64Interval(3, 8), + NewInt64Interval(3, 5), + }, + visitRange: NewInt64Interval(0, 100), + }, + } + for i, tt := range tests { + ivt := &IntervalTree{} + for _, ivl := range tt.ivls { + ivt.Insert(ivl, struct{}{}) + } + last := tt.ivls[0].Begin + count := 0 + chk := func(iv *IntervalValue) bool { + if last.Compare(iv.Ivl.Begin) > 0 { + t.Errorf("#%d: expected less than %d, got interval %+v", i, last, iv.Ivl) + } + last = iv.Ivl.Begin + count++ + return true + } + ivt.Visit(tt.visitRange, chk) + if count != len(tt.ivls) { + t.Errorf("#%d: did not cover all intervals. expected %d, got %d", i, len(tt.ivls), count) + } + } +} + +// TestIntervalTreeVisitExit tests that visiting can be stopped. +func TestIntervalTreeVisitExit(t *testing.T) { + ivls := []Interval{NewInt64Interval(1, 10), NewInt64Interval(2, 5), NewInt64Interval(3, 6), NewInt64Interval(4, 8)} + ivlRange := NewInt64Interval(0, 100) + tests := []struct { + f IntervalVisitor + + wcount int + }{ + { + f: func(n *IntervalValue) bool { return false }, + wcount: 1, + }, + { + f: func(n *IntervalValue) bool { return n.Ivl.Begin.Compare(ivls[0].Begin) <= 0 }, + wcount: 2, + }, + { + f: func(n *IntervalValue) bool { return n.Ivl.Begin.Compare(ivls[2].Begin) < 0 }, + wcount: 3, + }, + { + f: func(n *IntervalValue) bool { return true }, + wcount: 4, + }, + } + + for i, tt := range tests { + ivt := &IntervalTree{} + for _, ivl := range ivls { + ivt.Insert(ivl, struct{}{}) + } + count := 0 + ivt.Visit(ivlRange, func(n *IntervalValue) bool { + count++ + return tt.f(n) + }) + if count != tt.wcount { + t.Errorf("#%d: expected count %d, got %d", i, tt.wcount, count) + } + } +}