pkg/adt: fix interval tree black-height property based on rbtree

Author: xkey <xk33430@ly.com>
ref. https://github.com/etcd-io/etcd/pull/10978

Signed-off-by: Gyuho Lee <leegyuho@amazon.com>
This commit is contained in:
xkey 2019-08-05 11:40:18 +08:00 committed by Gyuho Lee
parent 9a2af7378a
commit 003362ef8e
2 changed files with 101 additions and 86 deletions

View File

@ -87,39 +87,39 @@ type intervalNode struct {
c rbcolor c rbcolor
} }
func (x *intervalNode) color() rbcolor { func (x *intervalNode) color(sentinel *intervalNode) rbcolor {
if x == nil { if x == sentinel {
return black return black
} }
return x.c return x.c
} }
func (x *intervalNode) height() int { func (x *intervalNode) height(sentinel *intervalNode) int {
if x == nil { if x == sentinel {
return 0 return 0
} }
ld := x.left.height() ld := x.left.height(sentinel)
rd := x.right.height() rd := x.right.height(sentinel)
if ld < rd { if ld < rd {
return rd + 1 return rd + 1
} }
return ld + 1 return ld + 1
} }
func (x *intervalNode) min() *intervalNode { func (x *intervalNode) min(sentinel *intervalNode) *intervalNode {
for x.left != nil { for x.left != sentinel {
x = x.left x = x.left
} }
return x return x
} }
// successor is the next in-order node in the tree // successor is the next in-order node in the tree
func (x *intervalNode) successor() *intervalNode { func (x *intervalNode) successor(sentinel *intervalNode) *intervalNode {
if x.right != nil { if x.right != sentinel {
return x.right.min() return x.right.min(sentinel)
} }
y := x.parent y := x.parent
for y != nil && x == y.right { for y != sentinel && x == y.right {
x = y x = y
y = y.parent y = y.parent
} }
@ -127,14 +127,14 @@ func (x *intervalNode) successor() *intervalNode {
} }
// updateMax updates the maximum values for a node and its ancestors // updateMax updates the maximum values for a node and its ancestors
func (x *intervalNode) updateMax() { func (x *intervalNode) updateMax(sentinel *intervalNode) {
for x != nil { for x != sentinel {
oldmax := x.max oldmax := x.max
max := x.iv.Ivl.End max := x.iv.Ivl.End
if x.left != nil && x.left.max.Compare(max) > 0 { if x.left != sentinel && x.left.max.Compare(max) > 0 {
max = x.left.max max = x.left.max
} }
if x.right != nil && x.right.max.Compare(max) > 0 { if x.right != sentinel && x.right.max.Compare(max) > 0 {
max = x.right.max max = x.right.max
} }
if oldmax.Compare(max) == 0 { if oldmax.Compare(max) == 0 {
@ -148,25 +148,25 @@ func (x *intervalNode) updateMax() {
type nodeVisitor func(n *intervalNode) bool type nodeVisitor func(n *intervalNode) bool
// visit will call a node visitor on each node that overlaps the given interval // visit will call a node visitor on each node that overlaps the given interval
func (x *intervalNode) visit(iv *Interval, nv nodeVisitor) bool { func (x *intervalNode) visit(iv *Interval, sentinel *intervalNode, nv nodeVisitor) bool {
if x == nil { if x == sentinel {
return true return true
} }
v := iv.Compare(&x.iv.Ivl) v := iv.Compare(&x.iv.Ivl)
switch { switch {
case v < 0: case v < 0:
if !x.left.visit(iv, nv) { if !x.left.visit(iv, sentinel, nv) {
return false return false
} }
case v > 0: case v > 0:
maxiv := Interval{x.iv.Ivl.Begin, x.max} maxiv := Interval{x.iv.Ivl.Begin, x.max}
if maxiv.Compare(iv) == 0 { if maxiv.Compare(iv) == 0 {
if !x.left.visit(iv, nv) || !x.right.visit(iv, nv) { if !x.left.visit(iv, sentinel, nv) || !x.right.visit(iv, sentinel, nv) {
return false return false
} }
} }
default: default:
if !x.left.visit(iv, nv) || !nv(x) || !x.right.visit(iv, nv) { if !x.left.visit(iv, sentinel, nv) || !nv(x) || !x.right.visit(iv, sentinel, nv) {
return false return false
} }
} }
@ -211,9 +211,18 @@ type IntervalTree interface {
// NewIntervalTree returns a new interval tree. // NewIntervalTree returns a new interval tree.
func NewIntervalTree() IntervalTree { func NewIntervalTree() IntervalTree {
sentinel := &intervalNode{
iv: IntervalValue{},
max: nil,
left: nil,
right: nil,
parent: nil,
c: black,
}
return &intervalTree{ return &intervalTree{
root: nil, root: sentinel,
count: 0, count: 0,
sentinel: sentinel,
} }
} }
@ -221,9 +230,11 @@ type intervalTree struct {
root *intervalNode root *intervalNode
count int count int
// TODO: use 'sentinel' as a dummy object to simplify boundary conditions // red-black NIL node
// use 'sentinel' as a dummy object to simplify boundary conditions
// use the sentinel to treat a nil child of a node x as an ordinary node whose parent is x // use the sentinel to treat a nil child of a node x as an ordinary node whose parent is x
// use one shared sentinel to represent all nil leaves and the root's parent // use one shared sentinel to represent all nil leaves and the root's parent
sentinel *intervalNode
} }
// TODO: make this consistent with textbook implementation // TODO: make this consistent with textbook implementation
@ -263,24 +274,25 @@ type intervalTree struct {
// true if a node is in fact removed. // true if a node is in fact removed.
func (ivt *intervalTree) Delete(ivl Interval) bool { func (ivt *intervalTree) Delete(ivl Interval) bool {
z := ivt.find(ivl) z := ivt.find(ivl)
if z == nil { if z == ivt.sentinel {
return false return false
} }
y := z y := z
if z.left != nil && z.right != nil { if z.left != ivt.sentinel && z.right != ivt.sentinel {
y = z.successor() y = z.successor(ivt.sentinel)
} }
x := y.left x := ivt.sentinel
if x == nil { if y.left != ivt.sentinel {
x = y.left
} else if y.right != ivt.sentinel {
x = y.right x = y.right
} }
if x != nil {
x.parent = y.parent
}
if y.parent == nil { x.parent = y.parent
if y.parent == ivt.sentinel {
ivt.root = x ivt.root = x
} else { } else {
if y == y.parent.left { if y == y.parent.left {
@ -288,14 +300,14 @@ func (ivt *intervalTree) Delete(ivl Interval) bool {
} else { } else {
y.parent.right = x y.parent.right = x
} }
y.parent.updateMax() y.parent.updateMax(ivt.sentinel)
} }
if y != z { if y != z {
z.iv = y.iv z.iv = y.iv
z.updateMax() z.updateMax(ivt.sentinel)
} }
if y.color() == black && x != nil { if y.color(ivt.sentinel) == black {
ivt.deleteFixup(x) ivt.deleteFixup(x)
} }
@ -348,10 +360,10 @@ func (ivt *intervalTree) Delete(ivl Interval) bool {
// 40. x.color = BLACK // 40. x.color = BLACK
// //
func (ivt *intervalTree) deleteFixup(x *intervalNode) { func (ivt *intervalTree) deleteFixup(x *intervalNode) {
for x != ivt.root && x.color() == black && x.parent != nil { for x != ivt.root && x.color(ivt.sentinel) == black {
if x == x.parent.left { // line 3-20 if x == x.parent.left { // line 3-20
w := x.parent.right w := x.parent.right
if w.color() == red { if w.color(ivt.sentinel) == red {
w.c = black w.c = black
x.parent.c = red x.parent.c = red
ivt.rotateLeft(x.parent) ivt.rotateLeft(x.parent)
@ -360,28 +372,26 @@ func (ivt *intervalTree) deleteFixup(x *intervalNode) {
if w == nil { if w == nil {
break break
} }
if w.left.color() == black && w.right.color() == black { if w.left.color(ivt.sentinel) == black && w.right.color(ivt.sentinel) == black {
w.c = red w.c = red
x = x.parent x = x.parent
} else { } else {
if w.right.color() == black { if w.right.color(ivt.sentinel) == black {
w.left.c = black w.left.c = black
w.c = red w.c = red
ivt.rotateRight(w) ivt.rotateRight(w)
w = x.parent.right w = x.parent.right
} }
w.c = x.parent.color() w.c = x.parent.color(ivt.sentinel)
x.parent.c = black x.parent.c = black
w.right.c = black w.right.c = black
ivt.rotateLeft(x.parent) ivt.rotateLeft(x.parent)
x = ivt.root x = ivt.root
} }
} else { // line 22-38 } else { // line 22-38
// same as above but with left and right exchanged // same as above but with left and right exchanged
w := x.parent.left w := x.parent.left
if w.color() == red { if w.color(ivt.sentinel) == red {
w.c = black w.c = black
x.parent.c = red x.parent.c = red
ivt.rotateRight(x.parent) ivt.rotateRight(x.parent)
@ -390,17 +400,17 @@ func (ivt *intervalTree) deleteFixup(x *intervalNode) {
if w == nil { if w == nil {
break break
} }
if w.left.color() == black && w.right.color() == black { if w.left.color(ivt.sentinel) == black && w.right.color(ivt.sentinel) == black {
w.c = red w.c = red
x = x.parent x = x.parent
} else { } else {
if w.left.color() == black { if w.left.color(ivt.sentinel) == black {
w.right.c = black w.right.c = black
w.c = red w.c = red
ivt.rotateLeft(w) ivt.rotateLeft(w)
w = x.parent.left w = x.parent.left
} }
w.c = x.parent.color() w.c = x.parent.color(ivt.sentinel)
x.parent.c = black x.parent.c = black
w.left.c = black w.left.c = black
ivt.rotateRight(x.parent) ivt.rotateRight(x.parent)
@ -419,9 +429,9 @@ func (ivt *intervalTree) createIntervalNode(ivl Interval, val interface{}) *inte
iv: IntervalValue{ivl, val}, iv: IntervalValue{ivl, val},
max: ivl.End, max: ivl.End,
c: red, c: red,
left: nil, left: ivt.sentinel,
right: nil, right: ivt.sentinel,
parent: nil, parent: ivt.sentinel,
} }
} }
@ -458,10 +468,10 @@ func (ivt *intervalTree) createIntervalNode(ivl Interval, val interface{}) *inte
// Insert adds a node with the given interval into the tree. // Insert adds a node with the given interval into the tree.
func (ivt *intervalTree) Insert(ivl Interval, val interface{}) { func (ivt *intervalTree) Insert(ivl Interval, val interface{}) {
var y *intervalNode y := ivt.sentinel
z := ivt.createIntervalNode(ivl, val) z := ivt.createIntervalNode(ivl, val)
x := ivt.root x := ivt.root
for x != nil { for x != ivt.sentinel {
y = x y = x
if z.iv.Ivl.Begin.Compare(x.iv.Ivl.Begin) < 0 { if z.iv.Ivl.Begin.Compare(x.iv.Ivl.Begin) < 0 {
x = x.left x = x.left
@ -471,7 +481,7 @@ func (ivt *intervalTree) Insert(ivl Interval, val interface{}) {
} }
z.parent = y z.parent = y
if y == nil { if y == ivt.sentinel {
ivt.root = z ivt.root = z
} else { } else {
if z.iv.Ivl.Begin.Compare(y.iv.Ivl.Begin) < 0 { if z.iv.Ivl.Begin.Compare(y.iv.Ivl.Begin) < 0 {
@ -479,7 +489,7 @@ func (ivt *intervalTree) Insert(ivl Interval, val interface{}) {
} else { } else {
y.right = z y.right = z
} }
y.updateMax() y.updateMax(ivt.sentinel)
} }
z.c = red z.c = red
@ -522,10 +532,11 @@ func (ivt *intervalTree) Insert(ivl Interval, val interface{}) {
// 30. T.root.color = BLACK // 30. T.root.color = BLACK
// //
func (ivt *intervalTree) insertFixup(z *intervalNode) { func (ivt *intervalTree) insertFixup(z *intervalNode) {
for z.parent != nil && z.parent.parent != nil && z.parent.color() == red { for z.parent.color(ivt.sentinel) == red {
if z.parent == z.parent.parent.left { // line 3-15 if z.parent == z.parent.parent.left { // line 3-15
y := z.parent.parent.right y := z.parent.parent.right
if y.color() == red { if y.color(ivt.sentinel) == red {
y.c = black y.c = black
z.parent.c = black z.parent.c = black
z.parent.parent.c = red z.parent.parent.c = red
@ -542,7 +553,7 @@ func (ivt *intervalTree) insertFixup(z *intervalNode) {
} else { // line 16-28 } else { // line 16-28
// same as then with left/right exchanged // same as then with left/right exchanged
y := z.parent.parent.left y := z.parent.parent.left
if y.color() == red { if y.color(ivt.sentinel) == red {
y.c = black y.c = black
z.parent.c = black z.parent.c = black
z.parent.parent.c = red z.parent.parent.c = red
@ -588,23 +599,27 @@ func (ivt *intervalTree) insertFixup(z *intervalNode) {
// 18. x.p = y // 18. x.p = y
// //
func (ivt *intervalTree) rotateLeft(x *intervalNode) { func (ivt *intervalTree) rotateLeft(x *intervalNode) {
// rotateLeft x must have right child
if x.right == ivt.sentinel {
return
}
// line 2-3 // line 2-3
y := x.right y := x.right
x.right = y.left x.right = y.left
// line 5-6 // line 5-6
if y.left != nil { if y.left != ivt.sentinel {
y.left.parent = x y.left.parent = x
} }
x.updateMax(ivt.sentinel)
x.updateMax()
// line 10-15, 18 // line 10-15, 18
ivt.replaceParent(x, y) ivt.replaceParent(x, y)
// line 17 // line 17
y.left = x y.left = x
y.updateMax() y.updateMax(ivt.sentinel)
} }
// rotateRight moves x so it is right of its left child // rotateRight moves x so it is right of its left child
@ -630,7 +645,8 @@ func (ivt *intervalTree) rotateLeft(x *intervalNode) {
// 18. x.p = y // 18. x.p = y
// //
func (ivt *intervalTree) rotateRight(x *intervalNode) { func (ivt *intervalTree) rotateRight(x *intervalNode) {
if x == nil { // rotateRight x must have left child
if x.left == ivt.sentinel {
return return
} }
@ -639,24 +655,23 @@ func (ivt *intervalTree) rotateRight(x *intervalNode) {
x.left = y.right x.left = y.right
// line 5-6 // line 5-6
if y.right != nil { if y.right != ivt.sentinel {
y.right.parent = x y.right.parent = x
} }
x.updateMax(ivt.sentinel)
x.updateMax()
// line 10-15, 18 // line 10-15, 18
ivt.replaceParent(x, y) ivt.replaceParent(x, y)
// line 17 // line 17
y.right = x y.right = x
y.updateMax() y.updateMax(ivt.sentinel)
} }
// replaceParent replaces x's parent with y // replaceParent replaces x's parent with y
func (ivt *intervalTree) replaceParent(x *intervalNode, y *intervalNode) { func (ivt *intervalTree) replaceParent(x *intervalNode, y *intervalNode) {
y.parent = x.parent y.parent = x.parent
if x.parent == nil { if x.parent == ivt.sentinel {
ivt.root = y ivt.root = y
} else { } else {
if x == x.parent.left { if x == x.parent.left {
@ -664,7 +679,7 @@ func (ivt *intervalTree) replaceParent(x *intervalNode, y *intervalNode) {
} else { } else {
x.parent.right = y x.parent.right = y
} }
x.parent.updateMax() x.parent.updateMax(ivt.sentinel)
} }
x.parent = y x.parent = y
} }
@ -673,7 +688,7 @@ func (ivt *intervalTree) replaceParent(x *intervalNode, y *intervalNode) {
func (ivt *intervalTree) Len() int { return ivt.count } func (ivt *intervalTree) Len() int { return ivt.count }
// Height is the number of levels in the tree; one node has height 1. // Height is the number of levels in the tree; one node has height 1.
func (ivt *intervalTree) Height() int { return ivt.root.height() } func (ivt *intervalTree) Height() int { return ivt.root.height(ivt.sentinel) }
// MaxHeight is the expected maximum tree height given the number of nodes // MaxHeight is the expected maximum tree height given the number of nodes
func (ivt *intervalTree) MaxHeight() int { func (ivt *intervalTree) MaxHeight() int {
@ -686,11 +701,12 @@ type IntervalVisitor func(n *IntervalValue) bool
// Visit calls a visitor function on every tree node intersecting the given interval. // Visit calls a visitor function on every tree node intersecting the given interval.
// It will visit each interval [x, y) in ascending order sorted on x. // It will visit each interval [x, y) in ascending order sorted on x.
func (ivt *intervalTree) Visit(ivl Interval, ivv IntervalVisitor) { func (ivt *intervalTree) Visit(ivl Interval, ivv IntervalVisitor) {
ivt.root.visit(&ivl, func(n *intervalNode) bool { return ivv(&n.iv) }) ivt.root.visit(&ivl, ivt.sentinel, func(n *intervalNode) bool { return ivv(&n.iv) })
} }
// find the exact node for a given interval // find the exact node for a given interval
func (ivt *intervalTree) find(ivl Interval) (ret *intervalNode) { func (ivt *intervalTree) find(ivl Interval) *intervalNode {
ret := ivt.sentinel
f := func(n *intervalNode) bool { f := func(n *intervalNode) bool {
if n.iv.Ivl != ivl { if n.iv.Ivl != ivl {
return true return true
@ -698,14 +714,14 @@ func (ivt *intervalTree) find(ivl Interval) (ret *intervalNode) {
ret = n ret = n
return false return false
} }
ivt.root.visit(&ivl, f) ivt.root.visit(&ivl, ivt.sentinel, f)
return ret return ret
} }
// Find gets the IntervalValue for the node matching the given interval // Find gets the IntervalValue for the node matching the given interval
func (ivt *intervalTree) Find(ivl Interval) (ret *IntervalValue) { func (ivt *intervalTree) Find(ivl Interval) (ret *IntervalValue) {
n := ivt.find(ivl) n := ivt.find(ivl)
if n == nil { if n == ivt.sentinel {
return nil return nil
} }
return &n.iv return &n.iv
@ -714,14 +730,14 @@ func (ivt *intervalTree) Find(ivl Interval) (ret *IntervalValue) {
// Intersects returns true if there is some tree node intersecting the given interval. // Intersects returns true if there is some tree node intersecting the given interval.
func (ivt *intervalTree) Intersects(iv Interval) bool { func (ivt *intervalTree) Intersects(iv Interval) bool {
x := ivt.root x := ivt.root
for x != nil && iv.Compare(&x.iv.Ivl) != 0 { for x != ivt.sentinel && iv.Compare(&x.iv.Ivl) != 0 {
if x.left != nil && x.left.max.Compare(iv.Begin) > 0 { if x.left != ivt.sentinel && x.left.max.Compare(iv.Begin) > 0 {
x = x.left x = x.left
} else { } else {
x = x.right x = x.right
} }
} }
return x != nil return x != ivt.sentinel
} }
// Contains returns true if the interval tree's keys cover the entire given interval. // Contains returns true if the interval tree's keys cover the entire given interval.
@ -789,7 +805,7 @@ func (vi visitedInterval) String() string {
// visitLevel traverses tree in level order. // visitLevel traverses tree in level order.
// used for testing // used for testing
func (ivt *intervalTree) visitLevel() []visitedInterval { func (ivt *intervalTree) visitLevel() []visitedInterval {
if ivt.root == nil { if ivt.root == ivt.sentinel {
return nil return nil
} }
@ -804,22 +820,21 @@ func (ivt *intervalTree) visitLevel() []visitedInterval {
f := queue[0] f := queue[0]
queue = queue[1:] queue = queue[1:]
ivt := visitedInterval{ vi := visitedInterval{
root: f.node.iv.Ivl, root: f.node.iv.Ivl,
color: f.node.color(), color: f.node.color(ivt.sentinel),
depth: f.depth, depth: f.depth,
} }
if f.node.left != ivt.sentinel {
if f.node.left != nil { vi.left = f.node.left.iv.Ivl
ivt.left = f.node.left.iv.Ivl
queue = append(queue, pair{f.node.left, f.depth + 1}) queue = append(queue, pair{f.node.left, f.depth + 1})
} }
if f.node.right != nil { if f.node.right != ivt.sentinel {
ivt.right = f.node.right.iv.Ivl vi.right = f.node.right.iv.Ivl
queue = append(queue, pair{f.node.right, f.depth + 1}) queue = append(queue, pair{f.node.right, f.depth + 1})
} }
rs = append(rs, ivt) rs = append(rs, vi)
} }
return rs return rs

View File

@ -298,7 +298,7 @@ func TestIntervalTreeDelete(t *testing.T) {
// / \ / // / \ /
// [238,239] [292,293] [953,954] // [238,239] [292,293] [953,954]
// //
t.Logf("level order after deleting '11' expected %v, got %v", expectedAfterDelete11, visitsAfterDelete11) t.Fatalf("level order after deleting '11' expected %v, got %v", expectedAfterDelete11, visitsAfterDelete11)
} }
} }