From c0b06a7a32fd48e2f922e31bd4dc7e9a8d2c557c Mon Sep 17 00:00:00 2001
From: Anthony Romano <anthony.romano@coreos.com>
Date: Sun, 28 Feb 2016 22:11:26 -0800
Subject: [PATCH] pkg/adt: interval tree

---
 pkg/adt/interval_tree.go      | 526 ++++++++++++++++++++++++++++++++++
 pkg/adt/interval_tree_test.go | 138 +++++++++
 2 files changed, 664 insertions(+)
 create mode 100644 pkg/adt/interval_tree.go
 create mode 100644 pkg/adt/interval_tree_test.go

diff --git a/pkg/adt/interval_tree.go b/pkg/adt/interval_tree.go
new file mode 100644
index 000000000..465c6200c
--- /dev/null
+++ b/pkg/adt/interval_tree.go
@@ -0,0 +1,526 @@
+// Copyright 2016 CoreOS, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package adt
+
+import (
+	"math"
+)
+
+// Comparable is an interface for trichotomic comparisons.
+type Comparable interface {
+	// Compare gives the result of a 3-way comparison
+	// a.Compare(b) = 1 => a > b
+	// a.Compare(b) = 0 => a == b
+	// a.Compare(b) = -1 => a < b
+	Compare(c Comparable) int
+}
+
+type rbcolor bool
+
+const black = true
+const red = false
+
+// Interval implements a Comparable interval [begin, end)
+// TODO: support different sorts of intervals: (a,b), [a,b], (a, b]
+type Interval struct {
+	Begin Comparable
+	End   Comparable
+}
+
+// Compare on an interval gives == if the interval overlaps.
+func (ivl *Interval) Compare(c Comparable) int {
+	ivl2 := c.(*Interval)
+	ivbCmpBegin := ivl.Begin.Compare(ivl2.Begin)
+	ivbCmpEnd := ivl.Begin.Compare(ivl2.End)
+	iveCmpBegin := ivl.End.Compare(ivl2.Begin)
+
+	// ivl is left of ivl2
+	if ivbCmpBegin < 0 && iveCmpBegin <= 0 {
+		return -1
+	}
+
+	// iv is right of iv2
+	if ivbCmpEnd >= 0 {
+		return 1
+	}
+
+	return 0
+}
+
+type intervalNode struct {
+	// iv is the interval-value pair entry.
+	iv IntervalValue
+	// max endpoint of all descendent nodes.
+	max Comparable
+	// left and right are sorted by low endpoint of key interval
+	left, right *intervalNode
+	// parent is the direct ancestor of the node
+	parent *intervalNode
+	c      rbcolor
+}
+
+func (x *intervalNode) color() rbcolor {
+	if x == nil {
+		return black
+	}
+	return x.c
+}
+
+func (n *intervalNode) height() int {
+	if n == nil {
+		return 0
+	}
+	ld := n.left.height()
+	rd := n.right.height()
+	if ld < rd {
+		return rd + 1
+	}
+	return ld + 1
+}
+
+func (x *intervalNode) min() *intervalNode {
+	for x.left != nil {
+		x = x.left
+	}
+	return x
+}
+
+// successor is the next in-order node in the tree
+func (x *intervalNode) successor() *intervalNode {
+	if x.right != nil {
+		return x.right.min()
+	}
+	y := x.parent
+	for y != nil && x == y.right {
+		x = y
+		y = y.parent
+	}
+	return y
+}
+
+// updateMax updates the maximum values for a node and its ancestors
+func (x *intervalNode) updateMax() {
+	for x != nil {
+		oldmax := x.max
+		max := x.iv.Ivl.End
+		if x.left != nil && x.left.max.Compare(max) > 0 {
+			max = x.left.max
+		}
+		if x.right != nil && x.right.max.Compare(max) > 0 {
+			max = x.right.max
+		}
+		if oldmax.Compare(max) == 0 {
+			break
+		}
+		x.max = max
+		x = x.parent
+	}
+}
+
+type nodeVisitor func(n *intervalNode) bool
+
+// visit will call a node visitor on each node that overlaps the given interval
+func (x *intervalNode) visit(iv *Interval, nv nodeVisitor) {
+	if x == nil {
+		return
+	}
+	v := iv.Compare(&x.iv.Ivl)
+	switch {
+	case v < 0:
+		x.left.visit(iv, nv)
+	case v > 0:
+		maxiv := Interval{x.iv.Ivl.Begin, x.max}
+		if maxiv.Compare(iv) == 0 {
+			x.left.visit(iv, nv)
+			x.right.visit(iv, nv)
+		}
+	default:
+		nv(x)
+		x.left.visit(iv, nv)
+		x.right.visit(iv, nv)
+	}
+}
+
+type IntervalValue struct {
+	Ivl Interval
+	Val interface{}
+}
+
+// IntervalTree represents a (mostly) textbook implementation of the
+// "Introduction to Algorithms" (Cormen et al, 2nd ed.) chapter 13 red-black tree
+// and chapter 14.3 interval tree with search supporting "stabbing queries".
+type IntervalTree struct {
+	root  *intervalNode
+	count int
+}
+
+// Delete removes the node with the given interval from the tree, returning
+// true if a node is in fact removed.
+func (ivt *IntervalTree) Delete(ivl Interval) bool {
+	z := ivt.find(ivl)
+	if z == nil {
+		return false
+	}
+
+	y := z
+	if z.left != nil && z.right != nil {
+		y = z.successor()
+	}
+
+	x := y.left
+	if x == nil {
+		x = y.right
+	}
+	if x != nil {
+		x.parent = y.parent
+	}
+
+	if y.parent == nil {
+		ivt.root = x
+	} else {
+		if y == y.parent.left {
+			y.parent.left = x
+		} else {
+			y.parent.right = x
+		}
+		y.parent.updateMax()
+	}
+	if y != z {
+		z.iv = y.iv
+		z.updateMax()
+	}
+
+	if y.color() == black && x != nil {
+		ivt.deleteFixup(x)
+	}
+
+	ivt.count--
+	return true
+}
+
+func (ivt *IntervalTree) deleteFixup(x *intervalNode) {
+	for x != ivt.root && x.color() == black && x.parent != nil {
+		if x == x.parent.left {
+			w := x.parent.right
+			if w.color() == red {
+				w.c = black
+				x.parent.c = red
+				ivt.rotateLeft(x.parent)
+				w = x.parent.right
+			}
+			if w == nil {
+				break
+			}
+			if w.left.color() == black && w.right.color() == black {
+				w.c = red
+				x = x.parent
+			} else {
+				if w.right.color() == black {
+					w.left.c = black
+					w.c = red
+					ivt.rotateRight(w)
+					w = x.parent.right
+				}
+				w.c = x.parent.color()
+				x.parent.c = black
+				w.right.c = black
+				ivt.rotateLeft(x.parent)
+				x = ivt.root
+			}
+		} else {
+			// same as above but with left and right exchanged
+			w := x.parent.left
+			if w.color() == red {
+				w.c = black
+				x.parent.c = red
+				ivt.rotateRight(x.parent)
+				w = x.parent.left
+			}
+			if w == nil {
+				break
+			}
+			if w.left.color() == black && w.right.color() == black {
+				w.c = red
+				x = x.parent
+			} else {
+				if w.left.color() == black {
+					w.right.c = black
+					w.c = red
+					ivt.rotateLeft(w)
+					w = x.parent.left
+				}
+				w.c = x.parent.color()
+				x.parent.c = black
+				w.left.c = black
+				ivt.rotateRight(x.parent)
+				x = ivt.root
+			}
+		}
+	}
+	if x != nil {
+		x.c = black
+	}
+}
+
+// Insert adds a node with the given interval into the tree.
+func (ivt *IntervalTree) Insert(ivl Interval, val interface{}) {
+	var y *intervalNode
+	z := &intervalNode{iv: IntervalValue{ivl, val}, max: ivl.End, c: red}
+	x := ivt.root
+	for x != nil {
+		y = x
+		if z.iv.Ivl.Begin.Compare(x.iv.Ivl.Begin) < 0 {
+			x = x.left
+		} else {
+			x = x.right
+		}
+	}
+
+	z.parent = y
+	if y == nil {
+		ivt.root = z
+	} else {
+		if z.iv.Ivl.Begin.Compare(y.iv.Ivl.Begin) < 0 {
+			y.left = z
+		} else {
+			y.right = z
+		}
+		y.updateMax()
+	}
+	z.c = red
+	ivt.insertFixup(z)
+	ivt.count++
+}
+
+func (ivt *IntervalTree) insertFixup(z *intervalNode) {
+	for z.parent != nil && z.parent.parent != nil && z.parent.color() == red {
+		if z.parent == z.parent.parent.left {
+			y := z.parent.parent.right
+			if y.color() == red {
+				y.c = black
+				z.parent.c = black
+				z.parent.parent.c = red
+				z = z.parent.parent
+			} else {
+				if z == z.parent.right {
+					z = z.parent
+					ivt.rotateLeft(z)
+				}
+				z.parent.c = black
+				z.parent.parent.c = red
+				ivt.rotateRight(z.parent.parent)
+			}
+		} else {
+			// same as then with left/right exchanged
+			y := z.parent.parent.left
+			if y.color() == red {
+				y.c = black
+				z.parent.c = black
+				z.parent.parent.c = red
+				z = z.parent.parent
+			} else {
+				if z == z.parent.left {
+					z = z.parent
+					ivt.rotateRight(z)
+				}
+				z.parent.c = black
+				z.parent.parent.c = red
+				ivt.rotateLeft(z.parent.parent)
+			}
+		}
+	}
+	ivt.root.c = black
+}
+
+// rotateLeft moves x so it is left of its right child
+func (ivt *IntervalTree) rotateLeft(x *intervalNode) {
+	y := x.right
+	x.right = y.left
+	if y.left != nil {
+		y.left.parent = x
+	}
+	x.updateMax()
+	ivt.replaceParent(x, y)
+	y.left = x
+	y.updateMax()
+}
+
+// rotateLeft moves x so it is right of its left child
+func (ivt *IntervalTree) rotateRight(x *intervalNode) {
+	if x == nil {
+		return
+	}
+	y := x.left
+	x.left = y.right
+	if y.right != nil {
+		y.right.parent = x
+	}
+	x.updateMax()
+	ivt.replaceParent(x, y)
+	y.right = x
+	y.updateMax()
+}
+
+// replaceParent replaces x's parent with y
+func (ivt *IntervalTree) replaceParent(x *intervalNode, y *intervalNode) {
+	y.parent = x.parent
+	if x.parent == nil {
+		ivt.root = y
+	} else {
+		if x == x.parent.left {
+			x.parent.left = y
+		} else {
+			x.parent.right = y
+		}
+		x.parent.updateMax()
+	}
+	x.parent = y
+}
+
+// Len gives the number of elements in the tree
+func (ivt *IntervalTree) Len() int { return ivt.count }
+
+// Height is the number of levels in the tree; one node has height 1.
+func (ivt *IntervalTree) Height() int { return ivt.root.height() }
+
+// MaxHeight is the expected maximum tree height given the number of nodes
+func (ivt *IntervalTree) MaxHeight() int {
+	return int((2 * math.Log2(float64(ivt.Len()+1))) + 0.5)
+}
+
+// InternalVisitor is used on tree searchs; return false to stop searching.
+type IntervalVisitor func(n *IntervalValue) bool
+
+// Visit calls a visitor function on every tree node intersecting the given interval.
+func (ivt *IntervalTree) Visit(ivl Interval, ivv IntervalVisitor) {
+	ivt.root.visit(&ivl, func(n *intervalNode) bool { return ivv(&n.iv) })
+}
+
+// find the exact node for a given interval
+func (ivt *IntervalTree) find(ivl Interval) (ret *intervalNode) {
+	f := func(n *intervalNode) bool {
+		if n.iv.Ivl != ivl {
+			return true
+		}
+		ret = n
+		return false
+	}
+	ivt.root.visit(&ivl, f)
+	return ret
+}
+
+// Find gets the IntervalValue for the node matching the given interval
+func (ivt *IntervalTree) Find(ivl Interval) (ret *IntervalValue) {
+	n := ivt.find(ivl)
+	if n == nil {
+		return nil
+	}
+	return &n.iv
+}
+
+// Contains returns true if there is some tree node intersecting the given interval.
+func (ivt *IntervalTree) Contains(iv Interval) bool {
+	x := ivt.root
+	for x != nil && iv.Compare(&x.iv.Ivl) != 0 {
+		if x.left != nil && x.left.max.Compare(iv.Begin) > 0 {
+			x = x.left
+		} else {
+			x = x.right
+		}
+	}
+	return x != nil
+}
+
+// Stab returns a slice with all elements in the tree intersecting the interval.
+func (ivt *IntervalTree) Stab(iv Interval) (ivs []*IntervalValue) {
+	f := func(n *IntervalValue) bool { ivs = append(ivs, n); return true }
+	ivt.Visit(iv, f)
+	return ivs
+}
+
+type StringComparable string
+
+func (s StringComparable) Compare(c Comparable) int {
+	sc := c.(StringComparable)
+	if s < sc {
+		return -1
+	}
+	if s > sc {
+		return 1
+	}
+	return 0
+}
+
+func NewStringInterval(begin, end string) Interval {
+	return Interval{StringComparable(begin), StringComparable(end)}
+}
+
+func NewStringPoint(s string) Interval {
+	return Interval{StringComparable(s), StringComparable(s + "\x00")}
+}
+
+// StringAffineComparable treats "" as > all other strings
+type StringAffineComparable string
+
+func (s StringAffineComparable) Compare(c Comparable) int {
+	sc := c.(StringAffineComparable)
+
+	if len(s) == 0 {
+		if len(sc) == 0 {
+			return 0
+		}
+		return 1
+	}
+	if len(sc) == 0 {
+		return -1
+	}
+
+	if s < sc {
+		return -1
+	}
+	if s > sc {
+		return 1
+	}
+	return 0
+}
+
+func NewStringAffineInterval(begin, end string) Interval {
+	return Interval{StringAffineComparable(begin), StringAffineComparable(end)}
+}
+func NewStringAffinePoint(s string) Interval {
+	return NewStringAffineInterval(s, s+"\x00")
+}
+
+func NewInt64Interval(a int64, b int64) Interval {
+	return Interval{Int64Comparable(a), Int64Comparable(b)}
+}
+
+func NewInt64Point(a int64) Interval {
+	return Interval{Int64Comparable(a), Int64Comparable(a + 1)}
+}
+
+type Int64Comparable int64
+
+func (v Int64Comparable) Compare(c Comparable) int {
+	vc := c.(Int64Comparable)
+	cmp := v - vc
+	if cmp < 0 {
+		return -1
+	}
+	if cmp > 0 {
+		return 1
+	}
+	return 0
+}
diff --git a/pkg/adt/interval_tree_test.go b/pkg/adt/interval_tree_test.go
new file mode 100644
index 000000000..6cc540265
--- /dev/null
+++ b/pkg/adt/interval_tree_test.go
@@ -0,0 +1,138 @@
+// Copyright 2016 CoreOS, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package adt
+
+import (
+	"math/rand"
+	"testing"
+	"time"
+)
+
+func TestIntervalTreeContains(t *testing.T) {
+	ivt := &IntervalTree{}
+	ivt.Insert(NewStringInterval("1", "3"), 123)
+
+	if ivt.Contains(NewStringPoint("0")) {
+		t.Errorf("contains 0")
+	}
+	if !ivt.Contains(NewStringPoint("1")) {
+		t.Errorf("missing 1")
+	}
+	if !ivt.Contains(NewStringPoint("11")) {
+		t.Errorf("missing 11")
+	}
+	if !ivt.Contains(NewStringPoint("2")) {
+		t.Errorf("missing 2")
+	}
+	if ivt.Contains(NewStringPoint("3")) {
+		t.Errorf("contains 3")
+	}
+}
+
+func TestIntervalTreeStringAffine(t *testing.T) {
+	ivt := &IntervalTree{}
+	ivt.Insert(NewStringAffineInterval("8", ""), 123)
+	if !ivt.Contains(NewStringAffinePoint("9")) {
+		t.Errorf("missing 9")
+	}
+	if ivt.Contains(NewStringAffinePoint("7")) {
+		t.Errorf("contains 7")
+	}
+}
+
+func TestIntervalTreeStab(t *testing.T) {
+	ivt := &IntervalTree{}
+	ivt.Insert(NewStringInterval("0", "1"), 123)
+	ivt.Insert(NewStringInterval("0", "2"), 456)
+	ivt.Insert(NewStringInterval("5", "6"), 789)
+	ivt.Insert(NewStringInterval("6", "8"), 999)
+	ivt.Insert(NewStringInterval("0", "3"), 0)
+
+	if ivt.root.max.Compare(StringComparable("8")) != 0 {
+		t.Fatalf("wrong root max got %v, expected 8", ivt.root.max)
+	}
+	if x := len(ivt.Stab(NewStringPoint("0"))); x != 3 {
+		t.Errorf("got %d, expected 3", x)
+	}
+	if x := len(ivt.Stab(NewStringPoint("1"))); x != 2 {
+		t.Errorf("got %d, expected 2", x)
+	}
+	if x := len(ivt.Stab(NewStringPoint("2"))); x != 1 {
+		t.Errorf("got %d, expected 1", x)
+	}
+	if x := len(ivt.Stab(NewStringPoint("3"))); x != 0 {
+		t.Errorf("got %d, expected 0", x)
+	}
+	if x := len(ivt.Stab(NewStringPoint("5"))); x != 1 {
+		t.Errorf("got %d, expected 1", x)
+	}
+	if x := len(ivt.Stab(NewStringPoint("55"))); x != 1 {
+		t.Errorf("got %d, expected 1", x)
+	}
+	if x := len(ivt.Stab(NewStringPoint("6"))); x != 1 {
+		t.Errorf("got %d, expected 1", x)
+	}
+}
+
+type xy struct {
+	x int64
+	y int64
+}
+
+func TestIntervalTreeRandom(t *testing.T) {
+	// generate unique intervals
+	ivs := make(map[xy]struct{})
+	ivt := &IntervalTree{}
+	maxv := 128
+	rand.Seed(time.Now().UnixNano())
+
+	for i := rand.Intn(maxv) + 1; i != 0; i-- {
+		x, y := int64(rand.Intn(maxv)), int64(rand.Intn(maxv))
+		if x > y {
+			t := x
+			x = y
+			y = t
+		} else if x == y {
+			y++
+		}
+		iv := xy{x, y}
+		if _, ok := ivs[iv]; ok {
+			// don't double insert
+			continue
+		}
+		ivt.Insert(NewInt64Interval(x, y), 123)
+		ivs[iv] = struct{}{}
+	}
+
+	for ab := range ivs {
+		for xy := range ivs {
+			v := xy.x + int64(rand.Intn(int(xy.y-xy.x)))
+			if slen := len(ivt.Stab(NewInt64Point(v))); slen == 0 {
+				t.Fatalf("expected %v stab non-zero for [%+v)", v, xy)
+			}
+			if !ivt.Contains(NewInt64Point(v)) {
+				t.Fatalf("did not get %d as expected for [%+v)", v, xy)
+			}
+		}
+		if !ivt.Delete(NewInt64Interval(ab.x, ab.y)) {
+			t.Errorf("did not delete %v as expected", ab)
+		}
+		delete(ivs, ab)
+	}
+
+	if ivt.Len() != 0 {
+		t.Errorf("got ivt.Len() = %v, expected 0", ivt.Len())
+	}
+}