mirror of
https://github.com/etcd-io/etcd.git
synced 2024-09-27 06:25:44 +00:00
pkg/adt: fix interval tree black-height property based on rbtree
Author: xkey <xk33430@ly.com> ref. https://github.com/etcd-io/etcd/pull/10978 Signed-off-by: Gyuho Lee <leegyuho@amazon.com>
This commit is contained in:
parent
9ff86fe516
commit
bb7df24af4
@ -87,39 +87,39 @@ type intervalNode struct {
|
||||
c rbcolor
|
||||
}
|
||||
|
||||
func (x *intervalNode) color() rbcolor {
|
||||
if x == nil {
|
||||
func (x *intervalNode) color(sentinel *intervalNode) rbcolor {
|
||||
if x == sentinel {
|
||||
return black
|
||||
}
|
||||
return x.c
|
||||
}
|
||||
|
||||
func (x *intervalNode) height() int {
|
||||
if x == nil {
|
||||
func (x *intervalNode) height(sentinel *intervalNode) int {
|
||||
if x == sentinel {
|
||||
return 0
|
||||
}
|
||||
ld := x.left.height()
|
||||
rd := x.right.height()
|
||||
ld := x.left.height(sentinel)
|
||||
rd := x.right.height(sentinel)
|
||||
if ld < rd {
|
||||
return rd + 1
|
||||
}
|
||||
return ld + 1
|
||||
}
|
||||
|
||||
func (x *intervalNode) min() *intervalNode {
|
||||
for x.left != nil {
|
||||
func (x *intervalNode) min(sentinel *intervalNode) *intervalNode {
|
||||
for x.left != sentinel {
|
||||
x = x.left
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
// successor is the next in-order node in the tree
|
||||
func (x *intervalNode) successor() *intervalNode {
|
||||
if x.right != nil {
|
||||
return x.right.min()
|
||||
func (x *intervalNode) successor(sentinel *intervalNode) *intervalNode {
|
||||
if x.right != sentinel {
|
||||
return x.right.min(sentinel)
|
||||
}
|
||||
y := x.parent
|
||||
for y != nil && x == y.right {
|
||||
for y != sentinel && x == y.right {
|
||||
x = y
|
||||
y = y.parent
|
||||
}
|
||||
@ -127,14 +127,14 @@ func (x *intervalNode) successor() *intervalNode {
|
||||
}
|
||||
|
||||
// updateMax updates the maximum values for a node and its ancestors
|
||||
func (x *intervalNode) updateMax() {
|
||||
for x != nil {
|
||||
func (x *intervalNode) updateMax(sentinel *intervalNode) {
|
||||
for x != sentinel {
|
||||
oldmax := x.max
|
||||
max := x.iv.Ivl.End
|
||||
if x.left != nil && x.left.max.Compare(max) > 0 {
|
||||
if x.left != sentinel && x.left.max.Compare(max) > 0 {
|
||||
max = x.left.max
|
||||
}
|
||||
if x.right != nil && x.right.max.Compare(max) > 0 {
|
||||
if x.right != sentinel && x.right.max.Compare(max) > 0 {
|
||||
max = x.right.max
|
||||
}
|
||||
if oldmax.Compare(max) == 0 {
|
||||
@ -148,25 +148,25 @@ func (x *intervalNode) updateMax() {
|
||||
type nodeVisitor func(n *intervalNode) bool
|
||||
|
||||
// visit will call a node visitor on each node that overlaps the given interval
|
||||
func (x *intervalNode) visit(iv *Interval, nv nodeVisitor) bool {
|
||||
if x == nil {
|
||||
func (x *intervalNode) visit(iv *Interval, sentinel *intervalNode, nv nodeVisitor) bool {
|
||||
if x == sentinel {
|
||||
return true
|
||||
}
|
||||
v := iv.Compare(&x.iv.Ivl)
|
||||
switch {
|
||||
case v < 0:
|
||||
if !x.left.visit(iv, nv) {
|
||||
if !x.left.visit(iv, sentinel, nv) {
|
||||
return false
|
||||
}
|
||||
case v > 0:
|
||||
maxiv := Interval{x.iv.Ivl.Begin, x.max}
|
||||
if maxiv.Compare(iv) == 0 {
|
||||
if !x.left.visit(iv, nv) || !x.right.visit(iv, nv) {
|
||||
if !x.left.visit(iv, sentinel, nv) || !x.right.visit(iv, sentinel, nv) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
default:
|
||||
if !x.left.visit(iv, nv) || !nv(x) || !x.right.visit(iv, nv) {
|
||||
if !x.left.visit(iv, sentinel, nv) || !nv(x) || !x.right.visit(iv, sentinel, nv) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
@ -211,9 +211,18 @@ type IntervalTree interface {
|
||||
|
||||
// NewIntervalTree returns a new interval tree.
|
||||
func NewIntervalTree() IntervalTree {
|
||||
sentinel := &intervalNode{
|
||||
iv: IntervalValue{},
|
||||
max: nil,
|
||||
left: nil,
|
||||
right: nil,
|
||||
parent: nil,
|
||||
c: black,
|
||||
}
|
||||
return &intervalTree{
|
||||
root: nil,
|
||||
count: 0,
|
||||
root: sentinel,
|
||||
count: 0,
|
||||
sentinel: sentinel,
|
||||
}
|
||||
}
|
||||
|
||||
@ -221,9 +230,11 @@ type intervalTree struct {
|
||||
root *intervalNode
|
||||
count int
|
||||
|
||||
// TODO: use 'sentinel' as a dummy object to simplify boundary conditions
|
||||
// red-black NIL node
|
||||
// use 'sentinel' as a dummy object to simplify boundary conditions
|
||||
// use the sentinel to treat a nil child of a node x as an ordinary node whose parent is x
|
||||
// use one shared sentinel to represent all nil leaves and the root's parent
|
||||
sentinel *intervalNode
|
||||
}
|
||||
|
||||
// TODO: make this consistent with textbook implementation
|
||||
@ -263,24 +274,25 @@ type intervalTree struct {
|
||||
// true if a node is in fact removed.
|
||||
func (ivt *intervalTree) Delete(ivl Interval) bool {
|
||||
z := ivt.find(ivl)
|
||||
if z == nil {
|
||||
if z == ivt.sentinel {
|
||||
return false
|
||||
}
|
||||
|
||||
y := z
|
||||
if z.left != nil && z.right != nil {
|
||||
y = z.successor()
|
||||
if z.left != ivt.sentinel && z.right != ivt.sentinel {
|
||||
y = z.successor(ivt.sentinel)
|
||||
}
|
||||
|
||||
x := y.left
|
||||
if x == nil {
|
||||
x := ivt.sentinel
|
||||
if y.left != ivt.sentinel {
|
||||
x = y.left
|
||||
} else if y.right != ivt.sentinel {
|
||||
x = y.right
|
||||
}
|
||||
if x != nil {
|
||||
x.parent = y.parent
|
||||
}
|
||||
|
||||
if y.parent == nil {
|
||||
x.parent = y.parent
|
||||
|
||||
if y.parent == ivt.sentinel {
|
||||
ivt.root = x
|
||||
} else {
|
||||
if y == y.parent.left {
|
||||
@ -288,14 +300,14 @@ func (ivt *intervalTree) Delete(ivl Interval) bool {
|
||||
} else {
|
||||
y.parent.right = x
|
||||
}
|
||||
y.parent.updateMax()
|
||||
y.parent.updateMax(ivt.sentinel)
|
||||
}
|
||||
if y != z {
|
||||
z.iv = y.iv
|
||||
z.updateMax()
|
||||
z.updateMax(ivt.sentinel)
|
||||
}
|
||||
|
||||
if y.color() == black && x != nil {
|
||||
if y.color(ivt.sentinel) == black {
|
||||
ivt.deleteFixup(x)
|
||||
}
|
||||
|
||||
@ -348,10 +360,10 @@ func (ivt *intervalTree) Delete(ivl Interval) bool {
|
||||
// 40. x.color = BLACK
|
||||
//
|
||||
func (ivt *intervalTree) deleteFixup(x *intervalNode) {
|
||||
for x != ivt.root && x.color() == black && x.parent != nil {
|
||||
for x != ivt.root && x.color(ivt.sentinel) == black {
|
||||
if x == x.parent.left { // line 3-20
|
||||
w := x.parent.right
|
||||
if w.color() == red {
|
||||
if w.color(ivt.sentinel) == red {
|
||||
w.c = black
|
||||
x.parent.c = red
|
||||
ivt.rotateLeft(x.parent)
|
||||
@ -360,28 +372,26 @@ func (ivt *intervalTree) deleteFixup(x *intervalNode) {
|
||||
if w == nil {
|
||||
break
|
||||
}
|
||||
if w.left.color() == black && w.right.color() == black {
|
||||
if w.left.color(ivt.sentinel) == black && w.right.color(ivt.sentinel) == black {
|
||||
w.c = red
|
||||
x = x.parent
|
||||
} else {
|
||||
if w.right.color() == black {
|
||||
if w.right.color(ivt.sentinel) == black {
|
||||
w.left.c = black
|
||||
w.c = red
|
||||
ivt.rotateRight(w)
|
||||
w = x.parent.right
|
||||
}
|
||||
w.c = x.parent.color()
|
||||
w.c = x.parent.color(ivt.sentinel)
|
||||
x.parent.c = black
|
||||
w.right.c = black
|
||||
ivt.rotateLeft(x.parent)
|
||||
x = ivt.root
|
||||
}
|
||||
|
||||
} else { // line 22-38
|
||||
|
||||
// same as above but with left and right exchanged
|
||||
w := x.parent.left
|
||||
if w.color() == red {
|
||||
if w.color(ivt.sentinel) == red {
|
||||
w.c = black
|
||||
x.parent.c = red
|
||||
ivt.rotateRight(x.parent)
|
||||
@ -390,17 +400,17 @@ func (ivt *intervalTree) deleteFixup(x *intervalNode) {
|
||||
if w == nil {
|
||||
break
|
||||
}
|
||||
if w.left.color() == black && w.right.color() == black {
|
||||
if w.left.color(ivt.sentinel) == black && w.right.color(ivt.sentinel) == black {
|
||||
w.c = red
|
||||
x = x.parent
|
||||
} else {
|
||||
if w.left.color() == black {
|
||||
if w.left.color(ivt.sentinel) == black {
|
||||
w.right.c = black
|
||||
w.c = red
|
||||
ivt.rotateLeft(w)
|
||||
w = x.parent.left
|
||||
}
|
||||
w.c = x.parent.color()
|
||||
w.c = x.parent.color(ivt.sentinel)
|
||||
x.parent.c = black
|
||||
w.left.c = black
|
||||
ivt.rotateRight(x.parent)
|
||||
@ -419,9 +429,9 @@ func (ivt *intervalTree) createIntervalNode(ivl Interval, val interface{}) *inte
|
||||
iv: IntervalValue{ivl, val},
|
||||
max: ivl.End,
|
||||
c: red,
|
||||
left: nil,
|
||||
right: nil,
|
||||
parent: nil,
|
||||
left: ivt.sentinel,
|
||||
right: ivt.sentinel,
|
||||
parent: ivt.sentinel,
|
||||
}
|
||||
}
|
||||
|
||||
@ -458,10 +468,10 @@ func (ivt *intervalTree) createIntervalNode(ivl Interval, val interface{}) *inte
|
||||
|
||||
// Insert adds a node with the given interval into the tree.
|
||||
func (ivt *intervalTree) Insert(ivl Interval, val interface{}) {
|
||||
var y *intervalNode
|
||||
y := ivt.sentinel
|
||||
z := ivt.createIntervalNode(ivl, val)
|
||||
x := ivt.root
|
||||
for x != nil {
|
||||
for x != ivt.sentinel {
|
||||
y = x
|
||||
if z.iv.Ivl.Begin.Compare(x.iv.Ivl.Begin) < 0 {
|
||||
x = x.left
|
||||
@ -471,7 +481,7 @@ func (ivt *intervalTree) Insert(ivl Interval, val interface{}) {
|
||||
}
|
||||
|
||||
z.parent = y
|
||||
if y == nil {
|
||||
if y == ivt.sentinel {
|
||||
ivt.root = z
|
||||
} else {
|
||||
if z.iv.Ivl.Begin.Compare(y.iv.Ivl.Begin) < 0 {
|
||||
@ -479,7 +489,7 @@ func (ivt *intervalTree) Insert(ivl Interval, val interface{}) {
|
||||
} else {
|
||||
y.right = z
|
||||
}
|
||||
y.updateMax()
|
||||
y.updateMax(ivt.sentinel)
|
||||
}
|
||||
z.c = red
|
||||
|
||||
@ -522,10 +532,11 @@ func (ivt *intervalTree) Insert(ivl Interval, val interface{}) {
|
||||
// 30. T.root.color = BLACK
|
||||
//
|
||||
func (ivt *intervalTree) insertFixup(z *intervalNode) {
|
||||
for z.parent != nil && z.parent.parent != nil && z.parent.color() == red {
|
||||
for z.parent.color(ivt.sentinel) == red {
|
||||
if z.parent == z.parent.parent.left { // line 3-15
|
||||
|
||||
y := z.parent.parent.right
|
||||
if y.color() == red {
|
||||
if y.color(ivt.sentinel) == red {
|
||||
y.c = black
|
||||
z.parent.c = black
|
||||
z.parent.parent.c = red
|
||||
@ -542,7 +553,7 @@ func (ivt *intervalTree) insertFixup(z *intervalNode) {
|
||||
} else { // line 16-28
|
||||
// same as then with left/right exchanged
|
||||
y := z.parent.parent.left
|
||||
if y.color() == red {
|
||||
if y.color(ivt.sentinel) == red {
|
||||
y.c = black
|
||||
z.parent.c = black
|
||||
z.parent.parent.c = red
|
||||
@ -588,23 +599,27 @@ func (ivt *intervalTree) insertFixup(z *intervalNode) {
|
||||
// 18. x.p = y
|
||||
//
|
||||
func (ivt *intervalTree) rotateLeft(x *intervalNode) {
|
||||
// rotateLeft x must have right child
|
||||
if x.right == ivt.sentinel {
|
||||
return
|
||||
}
|
||||
|
||||
// line 2-3
|
||||
y := x.right
|
||||
x.right = y.left
|
||||
|
||||
// line 5-6
|
||||
if y.left != nil {
|
||||
if y.left != ivt.sentinel {
|
||||
y.left.parent = x
|
||||
}
|
||||
|
||||
x.updateMax()
|
||||
x.updateMax(ivt.sentinel)
|
||||
|
||||
// line 10-15, 18
|
||||
ivt.replaceParent(x, y)
|
||||
|
||||
// line 17
|
||||
y.left = x
|
||||
y.updateMax()
|
||||
y.updateMax(ivt.sentinel)
|
||||
}
|
||||
|
||||
// rotateRight moves x so it is right of its left child
|
||||
@ -630,7 +645,8 @@ func (ivt *intervalTree) rotateLeft(x *intervalNode) {
|
||||
// 18. x.p = y
|
||||
//
|
||||
func (ivt *intervalTree) rotateRight(x *intervalNode) {
|
||||
if x == nil {
|
||||
// rotateRight x must have left child
|
||||
if x.left == ivt.sentinel {
|
||||
return
|
||||
}
|
||||
|
||||
@ -639,24 +655,23 @@ func (ivt *intervalTree) rotateRight(x *intervalNode) {
|
||||
x.left = y.right
|
||||
|
||||
// line 5-6
|
||||
if y.right != nil {
|
||||
if y.right != ivt.sentinel {
|
||||
y.right.parent = x
|
||||
}
|
||||
|
||||
x.updateMax()
|
||||
x.updateMax(ivt.sentinel)
|
||||
|
||||
// line 10-15, 18
|
||||
ivt.replaceParent(x, y)
|
||||
|
||||
// line 17
|
||||
y.right = x
|
||||
y.updateMax()
|
||||
y.updateMax(ivt.sentinel)
|
||||
}
|
||||
|
||||
// replaceParent replaces x's parent with y
|
||||
func (ivt *intervalTree) replaceParent(x *intervalNode, y *intervalNode) {
|
||||
y.parent = x.parent
|
||||
if x.parent == nil {
|
||||
if x.parent == ivt.sentinel {
|
||||
ivt.root = y
|
||||
} else {
|
||||
if x == x.parent.left {
|
||||
@ -664,7 +679,7 @@ func (ivt *intervalTree) replaceParent(x *intervalNode, y *intervalNode) {
|
||||
} else {
|
||||
x.parent.right = y
|
||||
}
|
||||
x.parent.updateMax()
|
||||
x.parent.updateMax(ivt.sentinel)
|
||||
}
|
||||
x.parent = y
|
||||
}
|
||||
@ -673,7 +688,7 @@ func (ivt *intervalTree) replaceParent(x *intervalNode, y *intervalNode) {
|
||||
func (ivt *intervalTree) Len() int { return ivt.count }
|
||||
|
||||
// Height is the number of levels in the tree; one node has height 1.
|
||||
func (ivt *intervalTree) Height() int { return ivt.root.height() }
|
||||
func (ivt *intervalTree) Height() int { return ivt.root.height(ivt.sentinel) }
|
||||
|
||||
// MaxHeight is the expected maximum tree height given the number of nodes
|
||||
func (ivt *intervalTree) MaxHeight() int {
|
||||
@ -686,11 +701,12 @@ type IntervalVisitor func(n *IntervalValue) bool
|
||||
// Visit calls a visitor function on every tree node intersecting the given interval.
|
||||
// It will visit each interval [x, y) in ascending order sorted on x.
|
||||
func (ivt *intervalTree) Visit(ivl Interval, ivv IntervalVisitor) {
|
||||
ivt.root.visit(&ivl, func(n *intervalNode) bool { return ivv(&n.iv) })
|
||||
ivt.root.visit(&ivl, ivt.sentinel, func(n *intervalNode) bool { return ivv(&n.iv) })
|
||||
}
|
||||
|
||||
// find the exact node for a given interval
|
||||
func (ivt *intervalTree) find(ivl Interval) (ret *intervalNode) {
|
||||
func (ivt *intervalTree) find(ivl Interval) *intervalNode {
|
||||
ret := ivt.sentinel
|
||||
f := func(n *intervalNode) bool {
|
||||
if n.iv.Ivl != ivl {
|
||||
return true
|
||||
@ -698,14 +714,14 @@ func (ivt *intervalTree) find(ivl Interval) (ret *intervalNode) {
|
||||
ret = n
|
||||
return false
|
||||
}
|
||||
ivt.root.visit(&ivl, f)
|
||||
ivt.root.visit(&ivl, ivt.sentinel, f)
|
||||
return ret
|
||||
}
|
||||
|
||||
// Find gets the IntervalValue for the node matching the given interval
|
||||
func (ivt *intervalTree) Find(ivl Interval) (ret *IntervalValue) {
|
||||
n := ivt.find(ivl)
|
||||
if n == nil {
|
||||
if n == ivt.sentinel {
|
||||
return nil
|
||||
}
|
||||
return &n.iv
|
||||
@ -714,14 +730,14 @@ func (ivt *intervalTree) Find(ivl Interval) (ret *IntervalValue) {
|
||||
// Intersects returns true if there is some tree node intersecting the given interval.
|
||||
func (ivt *intervalTree) Intersects(iv Interval) bool {
|
||||
x := ivt.root
|
||||
for x != nil && iv.Compare(&x.iv.Ivl) != 0 {
|
||||
if x.left != nil && x.left.max.Compare(iv.Begin) > 0 {
|
||||
for x != ivt.sentinel && iv.Compare(&x.iv.Ivl) != 0 {
|
||||
if x.left != ivt.sentinel && x.left.max.Compare(iv.Begin) > 0 {
|
||||
x = x.left
|
||||
} else {
|
||||
x = x.right
|
||||
}
|
||||
}
|
||||
return x != nil
|
||||
return x != ivt.sentinel
|
||||
}
|
||||
|
||||
// Contains returns true if the interval tree's keys cover the entire given interval.
|
||||
@ -789,7 +805,7 @@ func (vi visitedInterval) String() string {
|
||||
// visitLevel traverses tree in level order.
|
||||
// used for testing
|
||||
func (ivt *intervalTree) visitLevel() []visitedInterval {
|
||||
if ivt.root == nil {
|
||||
if ivt.root == ivt.sentinel {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -804,22 +820,21 @@ func (ivt *intervalTree) visitLevel() []visitedInterval {
|
||||
f := queue[0]
|
||||
queue = queue[1:]
|
||||
|
||||
ivt := visitedInterval{
|
||||
vi := visitedInterval{
|
||||
root: f.node.iv.Ivl,
|
||||
color: f.node.color(),
|
||||
color: f.node.color(ivt.sentinel),
|
||||
depth: f.depth,
|
||||
}
|
||||
|
||||
if f.node.left != nil {
|
||||
ivt.left = f.node.left.iv.Ivl
|
||||
if f.node.left != ivt.sentinel {
|
||||
vi.left = f.node.left.iv.Ivl
|
||||
queue = append(queue, pair{f.node.left, f.depth + 1})
|
||||
}
|
||||
if f.node.right != nil {
|
||||
ivt.right = f.node.right.iv.Ivl
|
||||
if f.node.right != ivt.sentinel {
|
||||
vi.right = f.node.right.iv.Ivl
|
||||
queue = append(queue, pair{f.node.right, f.depth + 1})
|
||||
}
|
||||
|
||||
rs = append(rs, ivt)
|
||||
rs = append(rs, vi)
|
||||
}
|
||||
|
||||
return rs
|
||||
|
@ -298,7 +298,7 @@ func TestIntervalTreeDelete(t *testing.T) {
|
||||
// / \ /
|
||||
// [238,239] [292,293] [953,954]
|
||||
//
|
||||
t.Logf("level order after deleting '11' expected %v, got %v", expectedAfterDelete11, visitsAfterDelete11)
|
||||
t.Fatalf("level order after deleting '11' expected %v, got %v", expectedAfterDelete11, visitsAfterDelete11)
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user