pkg/adt: fix interval tree black-height property based on rbtree

Author: xkey <xk33430@ly.com>
ref. https://github.com/etcd-io/etcd/pull/10978

Signed-off-by: Gyuho Lee <leegyuho@amazon.com>
This commit is contained in:
xkey 2019-08-05 11:40:18 +08:00 committed by Gyuho Lee
parent 9ff86fe516
commit bb7df24af4
2 changed files with 101 additions and 86 deletions

View File

@ -87,39 +87,39 @@ type intervalNode struct {
c rbcolor
}
func (x *intervalNode) color() rbcolor {
if x == nil {
func (x *intervalNode) color(sentinel *intervalNode) rbcolor {
if x == sentinel {
return black
}
return x.c
}
func (x *intervalNode) height() int {
if x == nil {
func (x *intervalNode) height(sentinel *intervalNode) int {
if x == sentinel {
return 0
}
ld := x.left.height()
rd := x.right.height()
ld := x.left.height(sentinel)
rd := x.right.height(sentinel)
if ld < rd {
return rd + 1
}
return ld + 1
}
func (x *intervalNode) min() *intervalNode {
for x.left != nil {
func (x *intervalNode) min(sentinel *intervalNode) *intervalNode {
for x.left != sentinel {
x = x.left
}
return x
}
// successor is the next in-order node in the tree
func (x *intervalNode) successor() *intervalNode {
if x.right != nil {
return x.right.min()
func (x *intervalNode) successor(sentinel *intervalNode) *intervalNode {
if x.right != sentinel {
return x.right.min(sentinel)
}
y := x.parent
for y != nil && x == y.right {
for y != sentinel && x == y.right {
x = y
y = y.parent
}
@ -127,14 +127,14 @@ func (x *intervalNode) successor() *intervalNode {
}
// updateMax updates the maximum values for a node and its ancestors
func (x *intervalNode) updateMax() {
for x != nil {
func (x *intervalNode) updateMax(sentinel *intervalNode) {
for x != sentinel {
oldmax := x.max
max := x.iv.Ivl.End
if x.left != nil && x.left.max.Compare(max) > 0 {
if x.left != sentinel && x.left.max.Compare(max) > 0 {
max = x.left.max
}
if x.right != nil && x.right.max.Compare(max) > 0 {
if x.right != sentinel && x.right.max.Compare(max) > 0 {
max = x.right.max
}
if oldmax.Compare(max) == 0 {
@ -148,25 +148,25 @@ func (x *intervalNode) updateMax() {
type nodeVisitor func(n *intervalNode) bool
// visit will call a node visitor on each node that overlaps the given interval
func (x *intervalNode) visit(iv *Interval, nv nodeVisitor) bool {
if x == nil {
func (x *intervalNode) visit(iv *Interval, sentinel *intervalNode, nv nodeVisitor) bool {
if x == sentinel {
return true
}
v := iv.Compare(&x.iv.Ivl)
switch {
case v < 0:
if !x.left.visit(iv, nv) {
if !x.left.visit(iv, sentinel, nv) {
return false
}
case v > 0:
maxiv := Interval{x.iv.Ivl.Begin, x.max}
if maxiv.Compare(iv) == 0 {
if !x.left.visit(iv, nv) || !x.right.visit(iv, nv) {
if !x.left.visit(iv, sentinel, nv) || !x.right.visit(iv, sentinel, nv) {
return false
}
}
default:
if !x.left.visit(iv, nv) || !nv(x) || !x.right.visit(iv, nv) {
if !x.left.visit(iv, sentinel, nv) || !nv(x) || !x.right.visit(iv, sentinel, nv) {
return false
}
}
@ -211,9 +211,18 @@ type IntervalTree interface {
// NewIntervalTree returns a new interval tree.
func NewIntervalTree() IntervalTree {
sentinel := &intervalNode{
iv: IntervalValue{},
max: nil,
left: nil,
right: nil,
parent: nil,
c: black,
}
return &intervalTree{
root: nil,
count: 0,
root: sentinel,
count: 0,
sentinel: sentinel,
}
}
@ -221,9 +230,11 @@ type intervalTree struct {
root *intervalNode
count int
// TODO: use 'sentinel' as a dummy object to simplify boundary conditions
// red-black NIL node
// use 'sentinel' as a dummy object to simplify boundary conditions
// use the sentinel to treat a nil child of a node x as an ordinary node whose parent is x
// use one shared sentinel to represent all nil leaves and the root's parent
sentinel *intervalNode
}
// TODO: make this consistent with textbook implementation
@ -263,24 +274,25 @@ type intervalTree struct {
// true if a node is in fact removed.
func (ivt *intervalTree) Delete(ivl Interval) bool {
z := ivt.find(ivl)
if z == nil {
if z == ivt.sentinel {
return false
}
y := z
if z.left != nil && z.right != nil {
y = z.successor()
if z.left != ivt.sentinel && z.right != ivt.sentinel {
y = z.successor(ivt.sentinel)
}
x := y.left
if x == nil {
x := ivt.sentinel
if y.left != ivt.sentinel {
x = y.left
} else if y.right != ivt.sentinel {
x = y.right
}
if x != nil {
x.parent = y.parent
}
if y.parent == nil {
x.parent = y.parent
if y.parent == ivt.sentinel {
ivt.root = x
} else {
if y == y.parent.left {
@ -288,14 +300,14 @@ func (ivt *intervalTree) Delete(ivl Interval) bool {
} else {
y.parent.right = x
}
y.parent.updateMax()
y.parent.updateMax(ivt.sentinel)
}
if y != z {
z.iv = y.iv
z.updateMax()
z.updateMax(ivt.sentinel)
}
if y.color() == black && x != nil {
if y.color(ivt.sentinel) == black {
ivt.deleteFixup(x)
}
@ -348,10 +360,10 @@ func (ivt *intervalTree) Delete(ivl Interval) bool {
// 40. x.color = BLACK
//
func (ivt *intervalTree) deleteFixup(x *intervalNode) {
for x != ivt.root && x.color() == black && x.parent != nil {
for x != ivt.root && x.color(ivt.sentinel) == black {
if x == x.parent.left { // line 3-20
w := x.parent.right
if w.color() == red {
if w.color(ivt.sentinel) == red {
w.c = black
x.parent.c = red
ivt.rotateLeft(x.parent)
@ -360,28 +372,26 @@ func (ivt *intervalTree) deleteFixup(x *intervalNode) {
if w == nil {
break
}
if w.left.color() == black && w.right.color() == black {
if w.left.color(ivt.sentinel) == black && w.right.color(ivt.sentinel) == black {
w.c = red
x = x.parent
} else {
if w.right.color() == black {
if w.right.color(ivt.sentinel) == black {
w.left.c = black
w.c = red
ivt.rotateRight(w)
w = x.parent.right
}
w.c = x.parent.color()
w.c = x.parent.color(ivt.sentinel)
x.parent.c = black
w.right.c = black
ivt.rotateLeft(x.parent)
x = ivt.root
}
} else { // line 22-38
// same as above but with left and right exchanged
w := x.parent.left
if w.color() == red {
if w.color(ivt.sentinel) == red {
w.c = black
x.parent.c = red
ivt.rotateRight(x.parent)
@ -390,17 +400,17 @@ func (ivt *intervalTree) deleteFixup(x *intervalNode) {
if w == nil {
break
}
if w.left.color() == black && w.right.color() == black {
if w.left.color(ivt.sentinel) == black && w.right.color(ivt.sentinel) == black {
w.c = red
x = x.parent
} else {
if w.left.color() == black {
if w.left.color(ivt.sentinel) == black {
w.right.c = black
w.c = red
ivt.rotateLeft(w)
w = x.parent.left
}
w.c = x.parent.color()
w.c = x.parent.color(ivt.sentinel)
x.parent.c = black
w.left.c = black
ivt.rotateRight(x.parent)
@ -419,9 +429,9 @@ func (ivt *intervalTree) createIntervalNode(ivl Interval, val interface{}) *inte
iv: IntervalValue{ivl, val},
max: ivl.End,
c: red,
left: nil,
right: nil,
parent: nil,
left: ivt.sentinel,
right: ivt.sentinel,
parent: ivt.sentinel,
}
}
@ -458,10 +468,10 @@ func (ivt *intervalTree) createIntervalNode(ivl Interval, val interface{}) *inte
// Insert adds a node with the given interval into the tree.
func (ivt *intervalTree) Insert(ivl Interval, val interface{}) {
var y *intervalNode
y := ivt.sentinel
z := ivt.createIntervalNode(ivl, val)
x := ivt.root
for x != nil {
for x != ivt.sentinel {
y = x
if z.iv.Ivl.Begin.Compare(x.iv.Ivl.Begin) < 0 {
x = x.left
@ -471,7 +481,7 @@ func (ivt *intervalTree) Insert(ivl Interval, val interface{}) {
}
z.parent = y
if y == nil {
if y == ivt.sentinel {
ivt.root = z
} else {
if z.iv.Ivl.Begin.Compare(y.iv.Ivl.Begin) < 0 {
@ -479,7 +489,7 @@ func (ivt *intervalTree) Insert(ivl Interval, val interface{}) {
} else {
y.right = z
}
y.updateMax()
y.updateMax(ivt.sentinel)
}
z.c = red
@ -522,10 +532,11 @@ func (ivt *intervalTree) Insert(ivl Interval, val interface{}) {
// 30. T.root.color = BLACK
//
func (ivt *intervalTree) insertFixup(z *intervalNode) {
for z.parent != nil && z.parent.parent != nil && z.parent.color() == red {
for z.parent.color(ivt.sentinel) == red {
if z.parent == z.parent.parent.left { // line 3-15
y := z.parent.parent.right
if y.color() == red {
if y.color(ivt.sentinel) == red {
y.c = black
z.parent.c = black
z.parent.parent.c = red
@ -542,7 +553,7 @@ func (ivt *intervalTree) insertFixup(z *intervalNode) {
} else { // line 16-28
// same as then with left/right exchanged
y := z.parent.parent.left
if y.color() == red {
if y.color(ivt.sentinel) == red {
y.c = black
z.parent.c = black
z.parent.parent.c = red
@ -588,23 +599,27 @@ func (ivt *intervalTree) insertFixup(z *intervalNode) {
// 18. x.p = y
//
func (ivt *intervalTree) rotateLeft(x *intervalNode) {
// rotateLeft x must have right child
if x.right == ivt.sentinel {
return
}
// line 2-3
y := x.right
x.right = y.left
// line 5-6
if y.left != nil {
if y.left != ivt.sentinel {
y.left.parent = x
}
x.updateMax()
x.updateMax(ivt.sentinel)
// line 10-15, 18
ivt.replaceParent(x, y)
// line 17
y.left = x
y.updateMax()
y.updateMax(ivt.sentinel)
}
// rotateRight moves x so it is right of its left child
@ -630,7 +645,8 @@ func (ivt *intervalTree) rotateLeft(x *intervalNode) {
// 18. x.p = y
//
func (ivt *intervalTree) rotateRight(x *intervalNode) {
if x == nil {
// rotateRight x must have left child
if x.left == ivt.sentinel {
return
}
@ -639,24 +655,23 @@ func (ivt *intervalTree) rotateRight(x *intervalNode) {
x.left = y.right
// line 5-6
if y.right != nil {
if y.right != ivt.sentinel {
y.right.parent = x
}
x.updateMax()
x.updateMax(ivt.sentinel)
// line 10-15, 18
ivt.replaceParent(x, y)
// line 17
y.right = x
y.updateMax()
y.updateMax(ivt.sentinel)
}
// replaceParent replaces x's parent with y
func (ivt *intervalTree) replaceParent(x *intervalNode, y *intervalNode) {
y.parent = x.parent
if x.parent == nil {
if x.parent == ivt.sentinel {
ivt.root = y
} else {
if x == x.parent.left {
@ -664,7 +679,7 @@ func (ivt *intervalTree) replaceParent(x *intervalNode, y *intervalNode) {
} else {
x.parent.right = y
}
x.parent.updateMax()
x.parent.updateMax(ivt.sentinel)
}
x.parent = y
}
@ -673,7 +688,7 @@ func (ivt *intervalTree) replaceParent(x *intervalNode, y *intervalNode) {
func (ivt *intervalTree) Len() int { return ivt.count }
// Height is the number of levels in the tree; one node has height 1.
func (ivt *intervalTree) Height() int { return ivt.root.height() }
func (ivt *intervalTree) Height() int { return ivt.root.height(ivt.sentinel) }
// MaxHeight is the expected maximum tree height given the number of nodes
func (ivt *intervalTree) MaxHeight() int {
@ -686,11 +701,12 @@ type IntervalVisitor func(n *IntervalValue) bool
// Visit calls a visitor function on every tree node intersecting the given interval.
// It will visit each interval [x, y) in ascending order sorted on x.
func (ivt *intervalTree) Visit(ivl Interval, ivv IntervalVisitor) {
ivt.root.visit(&ivl, func(n *intervalNode) bool { return ivv(&n.iv) })
ivt.root.visit(&ivl, ivt.sentinel, func(n *intervalNode) bool { return ivv(&n.iv) })
}
// find the exact node for a given interval
func (ivt *intervalTree) find(ivl Interval) (ret *intervalNode) {
func (ivt *intervalTree) find(ivl Interval) *intervalNode {
ret := ivt.sentinel
f := func(n *intervalNode) bool {
if n.iv.Ivl != ivl {
return true
@ -698,14 +714,14 @@ func (ivt *intervalTree) find(ivl Interval) (ret *intervalNode) {
ret = n
return false
}
ivt.root.visit(&ivl, f)
ivt.root.visit(&ivl, ivt.sentinel, f)
return ret
}
// Find gets the IntervalValue for the node matching the given interval
func (ivt *intervalTree) Find(ivl Interval) (ret *IntervalValue) {
n := ivt.find(ivl)
if n == nil {
if n == ivt.sentinel {
return nil
}
return &n.iv
@ -714,14 +730,14 @@ func (ivt *intervalTree) Find(ivl Interval) (ret *IntervalValue) {
// Intersects returns true if there is some tree node intersecting the given interval.
func (ivt *intervalTree) Intersects(iv Interval) bool {
x := ivt.root
for x != nil && iv.Compare(&x.iv.Ivl) != 0 {
if x.left != nil && x.left.max.Compare(iv.Begin) > 0 {
for x != ivt.sentinel && iv.Compare(&x.iv.Ivl) != 0 {
if x.left != ivt.sentinel && x.left.max.Compare(iv.Begin) > 0 {
x = x.left
} else {
x = x.right
}
}
return x != nil
return x != ivt.sentinel
}
// Contains returns true if the interval tree's keys cover the entire given interval.
@ -789,7 +805,7 @@ func (vi visitedInterval) String() string {
// visitLevel traverses tree in level order.
// used for testing
func (ivt *intervalTree) visitLevel() []visitedInterval {
if ivt.root == nil {
if ivt.root == ivt.sentinel {
return nil
}
@ -804,22 +820,21 @@ func (ivt *intervalTree) visitLevel() []visitedInterval {
f := queue[0]
queue = queue[1:]
ivt := visitedInterval{
vi := visitedInterval{
root: f.node.iv.Ivl,
color: f.node.color(),
color: f.node.color(ivt.sentinel),
depth: f.depth,
}
if f.node.left != nil {
ivt.left = f.node.left.iv.Ivl
if f.node.left != ivt.sentinel {
vi.left = f.node.left.iv.Ivl
queue = append(queue, pair{f.node.left, f.depth + 1})
}
if f.node.right != nil {
ivt.right = f.node.right.iv.Ivl
if f.node.right != ivt.sentinel {
vi.right = f.node.right.iv.Ivl
queue = append(queue, pair{f.node.right, f.depth + 1})
}
rs = append(rs, ivt)
rs = append(rs, vi)
}
return rs

View File

@ -298,7 +298,7 @@ func TestIntervalTreeDelete(t *testing.T) {
// / \ /
// [238,239] [292,293] [953,954]
//
t.Logf("level order after deleting '11' expected %v, got %v", expectedAfterDelete11, visitsAfterDelete11)
t.Fatalf("level order after deleting '11' expected %v, got %v", expectedAfterDelete11, visitsAfterDelete11)
}
}