mirror of
https://github.com/etcd-io/etcd.git
synced 2024-09-27 06:25:44 +00:00
Merge pull request #4614 from heyitsanthony/future-watch-rpc
etcdserver, storage, clientv3: watcher ranges
This commit is contained in:
commit
3a9d532140
@ -157,6 +157,20 @@ func testWatchMultiWatcher(t *testing.T, wctx *watchctx) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestWatchRange tests watcher creates ranges
|
||||
func TestWatchRange(t *testing.T) {
|
||||
runWatchTest(t, testWatchReconnInit)
|
||||
}
|
||||
|
||||
func testWatchRange(t *testing.T, wctx *watchctx) {
|
||||
if wctx.ch = wctx.w.Watch(context.TODO(), "a", clientv3.WithRange("c")); wctx.ch == nil {
|
||||
t.Fatalf("expected non-nil channel")
|
||||
}
|
||||
putAndWatch(t, wctx, "a", "a")
|
||||
putAndWatch(t, wctx, "b", "b")
|
||||
putAndWatch(t, wctx, "bar", "bar")
|
||||
}
|
||||
|
||||
// TestWatchReconnRequest tests the send failure path when requesting a watcher.
|
||||
func TestWatchReconnRequest(t *testing.T) {
|
||||
runWatchTest(t, testWatchReconnRequest)
|
||||
|
@ -15,8 +15,6 @@
|
||||
package clientv3
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
|
||||
"github.com/coreos/etcd/lease"
|
||||
)
|
||||
@ -69,27 +67,6 @@ func (op Op) toRequestUnion() *pb.RequestUnion {
|
||||
}
|
||||
}
|
||||
|
||||
func (op Op) toWatchRequest() *watchRequest {
|
||||
switch op.t {
|
||||
case tRange:
|
||||
key := string(op.key)
|
||||
prefix := ""
|
||||
if op.end != nil {
|
||||
prefix = key
|
||||
key = ""
|
||||
}
|
||||
wr := &watchRequest{
|
||||
key: key,
|
||||
prefix: prefix,
|
||||
rev: op.rev,
|
||||
}
|
||||
return wr
|
||||
|
||||
default:
|
||||
panic("Only for tRange")
|
||||
}
|
||||
}
|
||||
|
||||
func (op Op) isWrite() bool {
|
||||
return op.t != tRange
|
||||
}
|
||||
@ -140,8 +117,6 @@ func opWatch(key string, opts ...OpOption) Op {
|
||||
ret := Op{t: tRange, key: []byte(key)}
|
||||
ret.applyOpts(opts)
|
||||
switch {
|
||||
case ret.end != nil && !reflect.DeepEqual(ret.end, getPrefix(ret.key)):
|
||||
panic("only supports single keys or prefixes")
|
||||
case ret.leaseID != 0:
|
||||
panic("unexpected lease in watch")
|
||||
case ret.limit != 0:
|
||||
|
@ -78,10 +78,10 @@ type watcher struct {
|
||||
|
||||
// watchRequest is issued by the subscriber to start a new watcher
|
||||
type watchRequest struct {
|
||||
ctx context.Context
|
||||
key string
|
||||
prefix string
|
||||
rev int64
|
||||
ctx context.Context
|
||||
key string
|
||||
end string
|
||||
rev int64
|
||||
// retc receives a chan WatchResponse once the watcher is established
|
||||
retc chan chan WatchResponse
|
||||
}
|
||||
@ -129,11 +129,14 @@ func NewWatcher(c *Client) Watcher {
|
||||
func (w *watcher) Watch(ctx context.Context, key string, opts ...OpOption) WatchChan {
|
||||
ow := opWatch(key, opts...)
|
||||
|
||||
wr := ow.toWatchRequest()
|
||||
wr.ctx = ctx
|
||||
|
||||
retc := make(chan chan WatchResponse, 1)
|
||||
wr.retc = retc
|
||||
wr := &watchRequest{
|
||||
ctx: ctx,
|
||||
key: string(ow.key),
|
||||
end: string(ow.end),
|
||||
rev: ow.rev,
|
||||
retc: retc,
|
||||
}
|
||||
|
||||
ok := false
|
||||
|
||||
@ -502,11 +505,10 @@ func (w *watcher) resumeWatchers(wc pb.Watch_WatchClient) error {
|
||||
|
||||
// toPB converts an internal watch request structure to its protobuf message.
|
||||
func (wr *watchRequest) toPB() *pb.WatchRequest {
|
||||
req := &pb.WatchCreateRequest{StartRevision: wr.rev}
|
||||
if wr.key != "" {
|
||||
req.Key = []byte(wr.key)
|
||||
} else {
|
||||
req.Prefix = []byte(wr.prefix)
|
||||
req := &pb.WatchCreateRequest{
|
||||
StartRevision: wr.rev,
|
||||
Key: []byte(wr.key),
|
||||
RangeEnd: []byte(wr.end),
|
||||
}
|
||||
cr := &pb.WatchRequest_CreateRequest{CreateRequest: req}
|
||||
return &pb.WatchRequest{RequestUnion: cr}
|
||||
|
@ -94,35 +94,33 @@ func (sws *serverWatchStream) recvLoop() error {
|
||||
|
||||
switch uv := req.RequestUnion.(type) {
|
||||
case *pb.WatchRequest_CreateRequest:
|
||||
if uv.CreateRequest != nil {
|
||||
creq := uv.CreateRequest
|
||||
var prefix bool
|
||||
toWatch := creq.Key
|
||||
if len(creq.Key) == 0 {
|
||||
toWatch = creq.Prefix
|
||||
prefix = true
|
||||
}
|
||||
if uv.CreateRequest == nil {
|
||||
break
|
||||
}
|
||||
|
||||
rev := creq.StartRevision
|
||||
wsrev := sws.watchStream.Rev()
|
||||
if rev == 0 {
|
||||
// rev 0 watches past the current revision
|
||||
rev = wsrev + 1
|
||||
} else if rev > wsrev { // do not allow watching future revision.
|
||||
sws.ctrlStream <- &pb.WatchResponse{
|
||||
Header: sws.newResponseHeader(wsrev),
|
||||
WatchId: -1,
|
||||
Created: true,
|
||||
Canceled: true,
|
||||
}
|
||||
continue
|
||||
}
|
||||
id := sws.watchStream.Watch(toWatch, prefix, rev)
|
||||
sws.ctrlStream <- &pb.WatchResponse{
|
||||
Header: sws.newResponseHeader(wsrev),
|
||||
WatchId: int64(id),
|
||||
Created: true,
|
||||
}
|
||||
creq := uv.CreateRequest
|
||||
if len(creq.RangeEnd) == 1 && creq.RangeEnd[0] == 0 {
|
||||
// support >= key queries
|
||||
creq.RangeEnd = []byte{}
|
||||
}
|
||||
|
||||
rev := creq.StartRevision
|
||||
wsrev := sws.watchStream.Rev()
|
||||
futureRev := rev > wsrev
|
||||
if rev == 0 {
|
||||
// rev 0 watches past the current revision
|
||||
rev = wsrev + 1
|
||||
}
|
||||
// do not allow future watch revision
|
||||
id := storage.WatchID(-1)
|
||||
if !futureRev {
|
||||
id = sws.watchStream.Watch(creq.Key, creq.RangeEnd, rev)
|
||||
}
|
||||
sws.ctrlStream <- &pb.WatchResponse{
|
||||
Header: sws.newResponseHeader(wsrev),
|
||||
WatchId: int64(id),
|
||||
Created: true,
|
||||
Canceled: futureRev,
|
||||
}
|
||||
case *pb.WatchRequest_CancelRequest:
|
||||
if uv.CancelRequest != nil {
|
||||
|
@ -870,8 +870,9 @@ func _WatchRequest_OneofUnmarshaler(msg proto.Message, tag, wire int, b *proto.B
|
||||
type WatchCreateRequest struct {
|
||||
// the key to be watched
|
||||
Key []byte `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"`
|
||||
// the prefix to be watched.
|
||||
Prefix []byte `protobuf:"bytes,2,opt,name=prefix,proto3" json:"prefix,omitempty"`
|
||||
// if the range_end is given, keys in [key, range_end) are watched
|
||||
// NOTE: only range_end == prefixEnd(key) is accepted now
|
||||
RangeEnd []byte `protobuf:"bytes,2,opt,name=range_end,proto3" json:"range_end,omitempty"`
|
||||
// start_revision is an optional revision (including) to watch from. No start_revision is "now".
|
||||
StartRevision int64 `protobuf:"varint,3,opt,name=start_revision,proto3" json:"start_revision,omitempty"`
|
||||
}
|
||||
@ -2588,12 +2589,12 @@ func (m *WatchCreateRequest) MarshalTo(data []byte) (int, error) {
|
||||
i += copy(data[i:], m.Key)
|
||||
}
|
||||
}
|
||||
if m.Prefix != nil {
|
||||
if len(m.Prefix) > 0 {
|
||||
if m.RangeEnd != nil {
|
||||
if len(m.RangeEnd) > 0 {
|
||||
data[i] = 0x12
|
||||
i++
|
||||
i = encodeVarintRpc(data, i, uint64(len(m.Prefix)))
|
||||
i += copy(data[i:], m.Prefix)
|
||||
i = encodeVarintRpc(data, i, uint64(len(m.RangeEnd)))
|
||||
i += copy(data[i:], m.RangeEnd)
|
||||
}
|
||||
}
|
||||
if m.StartRevision != 0 {
|
||||
@ -3592,8 +3593,8 @@ func (m *WatchCreateRequest) Size() (n int) {
|
||||
n += 1 + l + sovRpc(uint64(l))
|
||||
}
|
||||
}
|
||||
if m.Prefix != nil {
|
||||
l = len(m.Prefix)
|
||||
if m.RangeEnd != nil {
|
||||
l = len(m.RangeEnd)
|
||||
if l > 0 {
|
||||
n += 1 + l + sovRpc(uint64(l))
|
||||
}
|
||||
@ -6004,7 +6005,7 @@ func (m *WatchCreateRequest) Unmarshal(data []byte) error {
|
||||
iNdEx = postIndex
|
||||
case 2:
|
||||
if wireType != 2 {
|
||||
return fmt.Errorf("proto: wrong wireType = %d for field Prefix", wireType)
|
||||
return fmt.Errorf("proto: wrong wireType = %d for field RangeEnd", wireType)
|
||||
}
|
||||
var byteLen int
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
@ -6028,9 +6029,9 @@ func (m *WatchCreateRequest) Unmarshal(data []byte) error {
|
||||
if postIndex > l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
m.Prefix = append(m.Prefix[:0], data[iNdEx:postIndex]...)
|
||||
if m.Prefix == nil {
|
||||
m.Prefix = []byte{}
|
||||
m.RangeEnd = append(m.RangeEnd[:0], data[iNdEx:postIndex]...)
|
||||
if m.RangeEnd == nil {
|
||||
m.RangeEnd = []byte{}
|
||||
}
|
||||
iNdEx = postIndex
|
||||
case 3:
|
||||
|
@ -262,11 +262,11 @@ message WatchRequest {
|
||||
message WatchCreateRequest {
|
||||
// the key to be watched
|
||||
bytes key = 1;
|
||||
// the prefix to be watched.
|
||||
bytes prefix = 2;
|
||||
// if the range_end is given, keys in [key, range_end) are watched
|
||||
// NOTE: only range_end == prefixEnd(key) is accepted now
|
||||
bytes range_end = 2;
|
||||
// start_revision is an optional revision (including) to watch from. No start_revision is "now".
|
||||
int64 start_revision = 3;
|
||||
// TODO: support Range watch?
|
||||
}
|
||||
|
||||
message WatchCancelRequest {
|
||||
|
@ -71,7 +71,8 @@ func TestV3WatchFromCurrentRevision(t *testing.T) {
|
||||
[]string{"fooLong"},
|
||||
&pb.WatchRequest{RequestUnion: &pb.WatchRequest_CreateRequest{
|
||||
CreateRequest: &pb.WatchCreateRequest{
|
||||
Prefix: []byte("foo")}}},
|
||||
Key: []byte("foo"),
|
||||
RangeEnd: []byte("fop")}}},
|
||||
|
||||
[]*pb.WatchResponse{
|
||||
{
|
||||
@ -91,7 +92,8 @@ func TestV3WatchFromCurrentRevision(t *testing.T) {
|
||||
[]string{"foo"},
|
||||
&pb.WatchRequest{RequestUnion: &pb.WatchRequest_CreateRequest{
|
||||
CreateRequest: &pb.WatchCreateRequest{
|
||||
Prefix: []byte("helloworld")}}},
|
||||
Key: []byte("helloworld"),
|
||||
RangeEnd: []byte("helloworle")}}},
|
||||
|
||||
[]*pb.WatchResponse{},
|
||||
},
|
||||
@ -140,7 +142,8 @@ func TestV3WatchFromCurrentRevision(t *testing.T) {
|
||||
[]string{"foo", "foo", "foo"},
|
||||
&pb.WatchRequest{RequestUnion: &pb.WatchRequest_CreateRequest{
|
||||
CreateRequest: &pb.WatchCreateRequest{
|
||||
Prefix: []byte("foo")}}},
|
||||
Key: []byte("foo"),
|
||||
RangeEnd: []byte("fop")}}},
|
||||
|
||||
[]*pb.WatchResponse{
|
||||
{
|
||||
@ -203,6 +206,11 @@ func TestV3WatchFromCurrentRevision(t *testing.T) {
|
||||
t.Errorf("#%d: did not create watchid, got +%v", i, cresp)
|
||||
continue
|
||||
}
|
||||
if cresp.Canceled {
|
||||
// Include the create response in the message; the format string previously
// had no verb for the cresp argument.
t.Errorf("#%d: canceled watcher on create %+v", i, cresp)
|
||||
continue
|
||||
}
|
||||
|
||||
createdWatchId := cresp.WatchId
|
||||
if cresp.Header == nil || cresp.Header.Revision != 1 {
|
||||
t.Errorf("#%d: header revision got +%v, wanted revison 1", i, cresp)
|
||||
@ -353,7 +361,7 @@ func TestV3WatchCurrentPutOverlap(t *testing.T) {
|
||||
progress := make(map[int64]int64)
|
||||
|
||||
wreq := &pb.WatchRequest{RequestUnion: &pb.WatchRequest_CreateRequest{
|
||||
CreateRequest: &pb.WatchCreateRequest{Prefix: []byte("foo")}}}
|
||||
CreateRequest: &pb.WatchCreateRequest{Key: []byte("foo"), RangeEnd: []byte("fop")}}}
|
||||
if err := wStream.Send(wreq); err != nil {
|
||||
t.Fatalf("first watch request failed (%v)", err)
|
||||
}
|
||||
@ -437,7 +445,7 @@ func testV3WatchMultipleWatchers(t *testing.T, startRev int64) {
|
||||
} else {
|
||||
wreq = &pb.WatchRequest{RequestUnion: &pb.WatchRequest_CreateRequest{
|
||||
CreateRequest: &pb.WatchCreateRequest{
|
||||
Prefix: []byte("fo"), StartRevision: startRev}}}
|
||||
Key: []byte("fo"), RangeEnd: []byte("fp"), StartRevision: startRev}}}
|
||||
}
|
||||
if err := wStream.Send(wreq); err != nil {
|
||||
t.Fatalf("wStream.Send error: %v", err)
|
||||
@ -530,7 +538,7 @@ func testV3WatchMultipleEventsTxn(t *testing.T, startRev int64) {
|
||||
|
||||
wreq := &pb.WatchRequest{RequestUnion: &pb.WatchRequest_CreateRequest{
|
||||
CreateRequest: &pb.WatchCreateRequest{
|
||||
Prefix: []byte("foo"), StartRevision: startRev}}}
|
||||
Key: []byte("foo"), RangeEnd: []byte("fop"), StartRevision: startRev}}}
|
||||
if err := wStream.Send(wreq); err != nil {
|
||||
t.Fatalf("wStream.Send error: %v", err)
|
||||
}
|
||||
@ -623,7 +631,7 @@ func TestV3WatchMultipleEventsPutUnsynced(t *testing.T) {
|
||||
|
||||
wreq := &pb.WatchRequest{RequestUnion: &pb.WatchRequest_CreateRequest{
|
||||
CreateRequest: &pb.WatchCreateRequest{
|
||||
Prefix: []byte("foo"), StartRevision: 1}}}
|
||||
Key: []byte("foo"), RangeEnd: []byte("fop"), StartRevision: 1}}}
|
||||
if err := wStream.Send(wreq); err != nil {
|
||||
t.Fatalf("wStream.Send error: %v", err)
|
||||
}
|
||||
|
526
pkg/adt/interval_tree.go
Normal file
526
pkg/adt/interval_tree.go
Normal file
@ -0,0 +1,526 @@
|
||||
// Copyright 2016 CoreOS, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package adt
|
||||
|
||||
import (
|
||||
"math"
|
||||
)
|
||||
|
||||
// Comparable is an interface for trichotomic comparisons.
// The implementations in this file type-assert c to their own concrete type
// without checking, so comparing values of different concrete types panics.
|
||||
type Comparable interface {
|
||||
// Compare gives the result of a 3-way comparison
|
||||
// a.Compare(b) = 1 => a > b
|
||||
// a.Compare(b) = 0 => a == b
|
||||
// a.Compare(b) = -1 => a < b
|
||||
Compare(c Comparable) int
|
||||
}
|
||||
|
||||
// rbcolor is the color stored on each tree node.
type rbcolor bool

// Node colors. Both are untyped boolean constants, so they can be assigned to
// and compared against rbcolor values directly.
const (
	black = true
	red   = false
)
|
||||
|
||||
// Interval implements a Comparable interval [begin, end)
|
||||
// TODO: support different sorts of intervals: (a,b), [a,b], (a, b]
|
||||
type Interval struct {
|
||||
Begin Comparable
|
||||
End Comparable
|
||||
}
|
||||
|
||||
// Compare on an interval gives == if the interval overlaps.
|
||||
func (ivl *Interval) Compare(c Comparable) int {
|
||||
ivl2 := c.(*Interval)
|
||||
ivbCmpBegin := ivl.Begin.Compare(ivl2.Begin)
|
||||
ivbCmpEnd := ivl.Begin.Compare(ivl2.End)
|
||||
iveCmpBegin := ivl.End.Compare(ivl2.Begin)
|
||||
|
||||
// ivl is left of ivl2
|
||||
if ivbCmpBegin < 0 && iveCmpBegin <= 0 {
|
||||
return -1
|
||||
}
|
||||
|
||||
// iv is right of iv2
|
||||
if ivbCmpEnd >= 0 {
|
||||
return 1
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
type intervalNode struct {
|
||||
// iv is the interval-value pair entry.
|
||||
iv IntervalValue
|
||||
// max endpoint of all descendent nodes.
|
||||
max Comparable
|
||||
// left and right are sorted by low endpoint of key interval
|
||||
left, right *intervalNode
|
||||
// parent is the direct ancestor of the node
|
||||
parent *intervalNode
|
||||
c rbcolor
|
||||
}
|
||||
|
||||
func (x *intervalNode) color() rbcolor {
|
||||
if x == nil {
|
||||
return black
|
||||
}
|
||||
return x.c
|
||||
}
|
||||
|
||||
func (n *intervalNode) height() int {
|
||||
if n == nil {
|
||||
return 0
|
||||
}
|
||||
ld := n.left.height()
|
||||
rd := n.right.height()
|
||||
if ld < rd {
|
||||
return rd + 1
|
||||
}
|
||||
return ld + 1
|
||||
}
|
||||
|
||||
func (x *intervalNode) min() *intervalNode {
|
||||
for x.left != nil {
|
||||
x = x.left
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
// successor is the next in-order node in the tree
|
||||
func (x *intervalNode) successor() *intervalNode {
|
||||
if x.right != nil {
|
||||
return x.right.min()
|
||||
}
|
||||
y := x.parent
|
||||
for y != nil && x == y.right {
|
||||
x = y
|
||||
y = y.parent
|
||||
}
|
||||
return y
|
||||
}
|
||||
|
||||
// updateMax updates the maximum values for a node and its ancestors
|
||||
func (x *intervalNode) updateMax() {
|
||||
for x != nil {
|
||||
oldmax := x.max
|
||||
max := x.iv.Ivl.End
|
||||
if x.left != nil && x.left.max.Compare(max) > 0 {
|
||||
max = x.left.max
|
||||
}
|
||||
if x.right != nil && x.right.max.Compare(max) > 0 {
|
||||
max = x.right.max
|
||||
}
|
||||
if oldmax.Compare(max) == 0 {
|
||||
break
|
||||
}
|
||||
x.max = max
|
||||
x = x.parent
|
||||
}
|
||||
}
|
||||
|
||||
type nodeVisitor func(n *intervalNode) bool
|
||||
|
||||
// visit will call a node visitor on each node that overlaps the given interval
|
||||
func (x *intervalNode) visit(iv *Interval, nv nodeVisitor) {
|
||||
if x == nil {
|
||||
return
|
||||
}
|
||||
v := iv.Compare(&x.iv.Ivl)
|
||||
switch {
|
||||
case v < 0:
|
||||
x.left.visit(iv, nv)
|
||||
case v > 0:
|
||||
maxiv := Interval{x.iv.Ivl.Begin, x.max}
|
||||
if maxiv.Compare(iv) == 0 {
|
||||
x.left.visit(iv, nv)
|
||||
x.right.visit(iv, nv)
|
||||
}
|
||||
default:
|
||||
nv(x)
|
||||
x.left.visit(iv, nv)
|
||||
x.right.visit(iv, nv)
|
||||
}
|
||||
}
|
||||
|
||||
// IntervalValue pairs an interval with the value stored for it in the tree.
type IntervalValue struct {
|
||||
// Ivl is the interval the entry covers.
Ivl Interval
|
||||
// Val is the caller-supplied value passed to Insert.
Val interface{}
|
||||
}
|
||||
|
||||
// IntervalTree represents a (mostly) textbook implementation of the
|
||||
// "Introduction to Algorithms" (Cormen et al, 2nd ed.) chapter 13 red-black tree
|
||||
// and chapter 14.3 interval tree with search supporting "stabbing queries".
|
||||
type IntervalTree struct {
|
||||
// root is the top node of the tree; Insert assigns it when the tree has no nodes.
root *intervalNode
|
||||
// count is incremented by Insert and decremented by a successful Delete; Len returns it.
count int
|
||||
}
|
||||
|
||||
// Delete removes the node with the given interval from the tree, returning
|
||||
// true if a node is in fact removed.
|
||||
func (ivt *IntervalTree) Delete(ivl Interval) bool {
|
||||
z := ivt.find(ivl)
|
||||
if z == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
y := z
|
||||
if z.left != nil && z.right != nil {
|
||||
y = z.successor()
|
||||
}
|
||||
|
||||
x := y.left
|
||||
if x == nil {
|
||||
x = y.right
|
||||
}
|
||||
if x != nil {
|
||||
x.parent = y.parent
|
||||
}
|
||||
|
||||
if y.parent == nil {
|
||||
ivt.root = x
|
||||
} else {
|
||||
if y == y.parent.left {
|
||||
y.parent.left = x
|
||||
} else {
|
||||
y.parent.right = x
|
||||
}
|
||||
y.parent.updateMax()
|
||||
}
|
||||
if y != z {
|
||||
z.iv = y.iv
|
||||
z.updateMax()
|
||||
}
|
||||
|
||||
if y.color() == black && x != nil {
|
||||
ivt.deleteFixup(x)
|
||||
}
|
||||
|
||||
ivt.count--
|
||||
return true
|
||||
}
|
||||
|
||||
func (ivt *IntervalTree) deleteFixup(x *intervalNode) {
|
||||
for x != ivt.root && x.color() == black && x.parent != nil {
|
||||
if x == x.parent.left {
|
||||
w := x.parent.right
|
||||
if w.color() == red {
|
||||
w.c = black
|
||||
x.parent.c = red
|
||||
ivt.rotateLeft(x.parent)
|
||||
w = x.parent.right
|
||||
}
|
||||
if w == nil {
|
||||
break
|
||||
}
|
||||
if w.left.color() == black && w.right.color() == black {
|
||||
w.c = red
|
||||
x = x.parent
|
||||
} else {
|
||||
if w.right.color() == black {
|
||||
w.left.c = black
|
||||
w.c = red
|
||||
ivt.rotateRight(w)
|
||||
w = x.parent.right
|
||||
}
|
||||
w.c = x.parent.color()
|
||||
x.parent.c = black
|
||||
w.right.c = black
|
||||
ivt.rotateLeft(x.parent)
|
||||
x = ivt.root
|
||||
}
|
||||
} else {
|
||||
// same as above but with left and right exchanged
|
||||
w := x.parent.left
|
||||
if w.color() == red {
|
||||
w.c = black
|
||||
x.parent.c = red
|
||||
ivt.rotateRight(x.parent)
|
||||
w = x.parent.left
|
||||
}
|
||||
if w == nil {
|
||||
break
|
||||
}
|
||||
if w.left.color() == black && w.right.color() == black {
|
||||
w.c = red
|
||||
x = x.parent
|
||||
} else {
|
||||
if w.left.color() == black {
|
||||
w.right.c = black
|
||||
w.c = red
|
||||
ivt.rotateLeft(w)
|
||||
w = x.parent.left
|
||||
}
|
||||
w.c = x.parent.color()
|
||||
x.parent.c = black
|
||||
w.left.c = black
|
||||
ivt.rotateRight(x.parent)
|
||||
x = ivt.root
|
||||
}
|
||||
}
|
||||
}
|
||||
if x != nil {
|
||||
x.c = black
|
||||
}
|
||||
}
|
||||
|
||||
// Insert adds a node with the given interval into the tree.
|
||||
func (ivt *IntervalTree) Insert(ivl Interval, val interface{}) {
|
||||
var y *intervalNode
|
||||
z := &intervalNode{iv: IntervalValue{ivl, val}, max: ivl.End, c: red}
|
||||
x := ivt.root
|
||||
for x != nil {
|
||||
y = x
|
||||
if z.iv.Ivl.Begin.Compare(x.iv.Ivl.Begin) < 0 {
|
||||
x = x.left
|
||||
} else {
|
||||
x = x.right
|
||||
}
|
||||
}
|
||||
|
||||
z.parent = y
|
||||
if y == nil {
|
||||
ivt.root = z
|
||||
} else {
|
||||
if z.iv.Ivl.Begin.Compare(y.iv.Ivl.Begin) < 0 {
|
||||
y.left = z
|
||||
} else {
|
||||
y.right = z
|
||||
}
|
||||
y.updateMax()
|
||||
}
|
||||
z.c = red
|
||||
ivt.insertFixup(z)
|
||||
ivt.count++
|
||||
}
|
||||
|
||||
func (ivt *IntervalTree) insertFixup(z *intervalNode) {
|
||||
for z.parent != nil && z.parent.parent != nil && z.parent.color() == red {
|
||||
if z.parent == z.parent.parent.left {
|
||||
y := z.parent.parent.right
|
||||
if y.color() == red {
|
||||
y.c = black
|
||||
z.parent.c = black
|
||||
z.parent.parent.c = red
|
||||
z = z.parent.parent
|
||||
} else {
|
||||
if z == z.parent.right {
|
||||
z = z.parent
|
||||
ivt.rotateLeft(z)
|
||||
}
|
||||
z.parent.c = black
|
||||
z.parent.parent.c = red
|
||||
ivt.rotateRight(z.parent.parent)
|
||||
}
|
||||
} else {
|
||||
// same as then with left/right exchanged
|
||||
y := z.parent.parent.left
|
||||
if y.color() == red {
|
||||
y.c = black
|
||||
z.parent.c = black
|
||||
z.parent.parent.c = red
|
||||
z = z.parent.parent
|
||||
} else {
|
||||
if z == z.parent.left {
|
||||
z = z.parent
|
||||
ivt.rotateRight(z)
|
||||
}
|
||||
z.parent.c = black
|
||||
z.parent.parent.c = red
|
||||
ivt.rotateLeft(z.parent.parent)
|
||||
}
|
||||
}
|
||||
}
|
||||
ivt.root.c = black
|
||||
}
|
||||
|
||||
// rotateLeft moves x so it is left of its right child
|
||||
func (ivt *IntervalTree) rotateLeft(x *intervalNode) {
|
||||
y := x.right
|
||||
x.right = y.left
|
||||
if y.left != nil {
|
||||
y.left.parent = x
|
||||
}
|
||||
x.updateMax()
|
||||
ivt.replaceParent(x, y)
|
||||
y.left = x
|
||||
y.updateMax()
|
||||
}
|
||||
|
||||
// rotateLeft moves x so it is right of its left child
|
||||
func (ivt *IntervalTree) rotateRight(x *intervalNode) {
|
||||
if x == nil {
|
||||
return
|
||||
}
|
||||
y := x.left
|
||||
x.left = y.right
|
||||
if y.right != nil {
|
||||
y.right.parent = x
|
||||
}
|
||||
x.updateMax()
|
||||
ivt.replaceParent(x, y)
|
||||
y.right = x
|
||||
y.updateMax()
|
||||
}
|
||||
|
||||
// replaceParent replaces x's parent with y
|
||||
func (ivt *IntervalTree) replaceParent(x *intervalNode, y *intervalNode) {
|
||||
y.parent = x.parent
|
||||
if x.parent == nil {
|
||||
ivt.root = y
|
||||
} else {
|
||||
if x == x.parent.left {
|
||||
x.parent.left = y
|
||||
} else {
|
||||
x.parent.right = y
|
||||
}
|
||||
x.parent.updateMax()
|
||||
}
|
||||
x.parent = y
|
||||
}
|
||||
|
||||
// Len gives the number of elements in the tree
|
||||
func (ivt *IntervalTree) Len() int { return ivt.count }
|
||||
|
||||
// Height is the number of levels in the tree; one node has height 1.
|
||||
func (ivt *IntervalTree) Height() int { return ivt.root.height() }
|
||||
|
||||
// MaxHeight is the expected maximum tree height given the number of nodes
|
||||
func (ivt *IntervalTree) MaxHeight() int {
|
||||
return int((2 * math.Log2(float64(ivt.Len()+1))) + 0.5)
|
||||
}
|
||||
|
||||
// InternalVisitor is used on tree searchs; return false to stop searching.
|
||||
type IntervalVisitor func(n *IntervalValue) bool
|
||||
|
||||
// Visit calls a visitor function on every tree node intersecting the given interval.
|
||||
func (ivt *IntervalTree) Visit(ivl Interval, ivv IntervalVisitor) {
|
||||
ivt.root.visit(&ivl, func(n *intervalNode) bool { return ivv(&n.iv) })
|
||||
}
|
||||
|
||||
// find the exact node for a given interval
|
||||
func (ivt *IntervalTree) find(ivl Interval) (ret *intervalNode) {
|
||||
f := func(n *intervalNode) bool {
|
||||
if n.iv.Ivl != ivl {
|
||||
return true
|
||||
}
|
||||
ret = n
|
||||
return false
|
||||
}
|
||||
ivt.root.visit(&ivl, f)
|
||||
return ret
|
||||
}
|
||||
|
||||
// Find gets the IntervalValue for the node matching the given interval
|
||||
func (ivt *IntervalTree) Find(ivl Interval) (ret *IntervalValue) {
|
||||
n := ivt.find(ivl)
|
||||
if n == nil {
|
||||
return nil
|
||||
}
|
||||
return &n.iv
|
||||
}
|
||||
|
||||
// Contains returns true if there is some tree node intersecting the given interval.
|
||||
func (ivt *IntervalTree) Contains(iv Interval) bool {
|
||||
x := ivt.root
|
||||
for x != nil && iv.Compare(&x.iv.Ivl) != 0 {
|
||||
if x.left != nil && x.left.max.Compare(iv.Begin) > 0 {
|
||||
x = x.left
|
||||
} else {
|
||||
x = x.right
|
||||
}
|
||||
}
|
||||
return x != nil
|
||||
}
|
||||
|
||||
// Stab returns a slice with all elements in the tree intersecting the interval.
|
||||
func (ivt *IntervalTree) Stab(iv Interval) (ivs []*IntervalValue) {
|
||||
f := func(n *IntervalValue) bool { ivs = append(ivs, n); return true }
|
||||
ivt.Visit(iv, f)
|
||||
return ivs
|
||||
}
|
||||
|
||||
type StringComparable string
|
||||
|
||||
func (s StringComparable) Compare(c Comparable) int {
|
||||
sc := c.(StringComparable)
|
||||
if s < sc {
|
||||
return -1
|
||||
}
|
||||
if s > sc {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func NewStringInterval(begin, end string) Interval {
|
||||
return Interval{StringComparable(begin), StringComparable(end)}
|
||||
}
|
||||
|
||||
func NewStringPoint(s string) Interval {
|
||||
return Interval{StringComparable(s), StringComparable(s + "\x00")}
|
||||
}
|
||||
|
||||
// StringAffineComparable treats "" as > all other strings
|
||||
type StringAffineComparable string
|
||||
|
||||
func (s StringAffineComparable) Compare(c Comparable) int {
|
||||
sc := c.(StringAffineComparable)
|
||||
|
||||
if len(s) == 0 {
|
||||
if len(sc) == 0 {
|
||||
return 0
|
||||
}
|
||||
return 1
|
||||
}
|
||||
if len(sc) == 0 {
|
||||
return -1
|
||||
}
|
||||
|
||||
if s < sc {
|
||||
return -1
|
||||
}
|
||||
if s > sc {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func NewStringAffineInterval(begin, end string) Interval {
|
||||
return Interval{StringAffineComparable(begin), StringAffineComparable(end)}
|
||||
}
|
||||
func NewStringAffinePoint(s string) Interval {
|
||||
return NewStringAffineInterval(s, s+"\x00")
|
||||
}
|
||||
|
||||
func NewInt64Interval(a int64, b int64) Interval {
|
||||
return Interval{Int64Comparable(a), Int64Comparable(b)}
|
||||
}
|
||||
|
||||
func NewInt64Point(a int64) Interval {
|
||||
return Interval{Int64Comparable(a), Int64Comparable(a + 1)}
|
||||
}
|
||||
|
||||
type Int64Comparable int64
|
||||
|
||||
func (v Int64Comparable) Compare(c Comparable) int {
|
||||
vc := c.(Int64Comparable)
|
||||
cmp := v - vc
|
||||
if cmp < 0 {
|
||||
return -1
|
||||
}
|
||||
if cmp > 0 {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
138
pkg/adt/interval_tree_test.go
Normal file
138
pkg/adt/interval_tree_test.go
Normal file
@ -0,0 +1,138 @@
|
||||
// Copyright 2016 CoreOS, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package adt
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestIntervalTreeContains(t *testing.T) {
|
||||
ivt := &IntervalTree{}
|
||||
ivt.Insert(NewStringInterval("1", "3"), 123)
|
||||
|
||||
if ivt.Contains(NewStringPoint("0")) {
|
||||
t.Errorf("contains 0")
|
||||
}
|
||||
if !ivt.Contains(NewStringPoint("1")) {
|
||||
t.Errorf("missing 1")
|
||||
}
|
||||
if !ivt.Contains(NewStringPoint("11")) {
|
||||
t.Errorf("missing 11")
|
||||
}
|
||||
if !ivt.Contains(NewStringPoint("2")) {
|
||||
t.Errorf("missing 2")
|
||||
}
|
||||
if ivt.Contains(NewStringPoint("3")) {
|
||||
t.Errorf("contains 3")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntervalTreeStringAffine(t *testing.T) {
|
||||
ivt := &IntervalTree{}
|
||||
ivt.Insert(NewStringAffineInterval("8", ""), 123)
|
||||
if !ivt.Contains(NewStringAffinePoint("9")) {
|
||||
t.Errorf("missing 9")
|
||||
}
|
||||
if ivt.Contains(NewStringAffinePoint("7")) {
|
||||
t.Errorf("contains 7")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntervalTreeStab(t *testing.T) {
|
||||
ivt := &IntervalTree{}
|
||||
ivt.Insert(NewStringInterval("0", "1"), 123)
|
||||
ivt.Insert(NewStringInterval("0", "2"), 456)
|
||||
ivt.Insert(NewStringInterval("5", "6"), 789)
|
||||
ivt.Insert(NewStringInterval("6", "8"), 999)
|
||||
ivt.Insert(NewStringInterval("0", "3"), 0)
|
||||
|
||||
if ivt.root.max.Compare(StringComparable("8")) != 0 {
|
||||
t.Fatalf("wrong root max got %v, expected 8", ivt.root.max)
|
||||
}
|
||||
if x := len(ivt.Stab(NewStringPoint("0"))); x != 3 {
|
||||
t.Errorf("got %d, expected 3", x)
|
||||
}
|
||||
if x := len(ivt.Stab(NewStringPoint("1"))); x != 2 {
|
||||
t.Errorf("got %d, expected 2", x)
|
||||
}
|
||||
if x := len(ivt.Stab(NewStringPoint("2"))); x != 1 {
|
||||
t.Errorf("got %d, expected 1", x)
|
||||
}
|
||||
if x := len(ivt.Stab(NewStringPoint("3"))); x != 0 {
|
||||
t.Errorf("got %d, expected 0", x)
|
||||
}
|
||||
if x := len(ivt.Stab(NewStringPoint("5"))); x != 1 {
|
||||
t.Errorf("got %d, expected 1", x)
|
||||
}
|
||||
if x := len(ivt.Stab(NewStringPoint("55"))); x != 1 {
|
||||
t.Errorf("got %d, expected 1", x)
|
||||
}
|
||||
if x := len(ivt.Stab(NewStringPoint("6"))); x != 1 {
|
||||
t.Errorf("got %d, expected 1", x)
|
||||
}
|
||||
}
|
||||
|
||||
type xy struct {
|
||||
x int64
|
||||
y int64
|
||||
}
|
||||
|
||||
func TestIntervalTreeRandom(t *testing.T) {
|
||||
// generate unique intervals
|
||||
ivs := make(map[xy]struct{})
|
||||
ivt := &IntervalTree{}
|
||||
maxv := 128
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
|
||||
for i := rand.Intn(maxv) + 1; i != 0; i-- {
|
||||
x, y := int64(rand.Intn(maxv)), int64(rand.Intn(maxv))
|
||||
if x > y {
|
||||
t := x
|
||||
x = y
|
||||
y = t
|
||||
} else if x == y {
|
||||
y++
|
||||
}
|
||||
iv := xy{x, y}
|
||||
if _, ok := ivs[iv]; ok {
|
||||
// don't double insert
|
||||
continue
|
||||
}
|
||||
ivt.Insert(NewInt64Interval(x, y), 123)
|
||||
ivs[iv] = struct{}{}
|
||||
}
|
||||
|
||||
for ab := range ivs {
|
||||
for xy := range ivs {
|
||||
v := xy.x + int64(rand.Intn(int(xy.y-xy.x)))
|
||||
if slen := len(ivt.Stab(NewInt64Point(v))); slen == 0 {
|
||||
t.Fatalf("expected %v stab non-zero for [%+v)", v, xy)
|
||||
}
|
||||
if !ivt.Contains(NewInt64Point(v)) {
|
||||
t.Fatalf("did not get %d as expected for [%+v)", v, xy)
|
||||
}
|
||||
}
|
||||
if !ivt.Delete(NewInt64Interval(ab.x, ab.y)) {
|
||||
t.Errorf("did not delete %v as expected", ab)
|
||||
}
|
||||
delete(ivs, ab)
|
||||
}
|
||||
|
||||
if ivt.Len() != 0 {
|
||||
t.Errorf("got ivt.Len() = %v, expected 0", ivt.Len())
|
||||
}
|
||||
}
|
@ -722,13 +722,10 @@ func TestWatchableKVWatch(t *testing.T) {
|
||||
w := s.NewWatchStream()
|
||||
defer w.Close()
|
||||
|
||||
wid := w.Watch([]byte("foo"), true, 0)
|
||||
wid := w.Watch([]byte("foo"), []byte("fop"), 0)
|
||||
|
||||
s.Put([]byte("foo"), []byte("bar"), 1)
|
||||
select {
|
||||
case resp := <-w.Chan():
|
||||
wev := storagepb.Event{
|
||||
Type: storagepb.PUT,
|
||||
wev := []storagepb.Event{
|
||||
{Type: storagepb.PUT,
|
||||
Kv: &storagepb.KeyValue{
|
||||
Key: []byte("foo"),
|
||||
Value: []byte("bar"),
|
||||
@ -737,23 +734,8 @@ func TestWatchableKVWatch(t *testing.T) {
|
||||
Version: 1,
|
||||
Lease: 1,
|
||||
},
|
||||
}
|
||||
if resp.WatchID != wid {
|
||||
t.Errorf("resp.WatchID got = %d, want = %d", resp.WatchID, wid)
|
||||
}
|
||||
ev := resp.Events[0]
|
||||
if !reflect.DeepEqual(ev, wev) {
|
||||
t.Errorf("watched event = %+v, want %+v", ev, wev)
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
// CPU might be too slow, and the routine is not able to switch around
|
||||
testutil.FatalStack(t, "failed to watch the event")
|
||||
}
|
||||
|
||||
s.Put([]byte("foo1"), []byte("bar1"), 2)
|
||||
select {
|
||||
case resp := <-w.Chan():
|
||||
wev := storagepb.Event{
|
||||
},
|
||||
{
|
||||
Type: storagepb.PUT,
|
||||
Kv: &storagepb.KeyValue{
|
||||
Key: []byte("foo1"),
|
||||
@ -763,49 +745,8 @@ func TestWatchableKVWatch(t *testing.T) {
|
||||
Version: 1,
|
||||
Lease: 2,
|
||||
},
|
||||
}
|
||||
if resp.WatchID != wid {
|
||||
t.Errorf("resp.WatchID got = %d, want = %d", resp.WatchID, wid)
|
||||
}
|
||||
ev := resp.Events[0]
|
||||
if !reflect.DeepEqual(ev, wev) {
|
||||
t.Errorf("watched event = %+v, want %+v", ev, wev)
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
testutil.FatalStack(t, "failed to watch the event")
|
||||
}
|
||||
|
||||
w = s.NewWatchStream()
|
||||
wid = w.Watch([]byte("foo1"), false, 1)
|
||||
|
||||
select {
|
||||
case resp := <-w.Chan():
|
||||
wev := storagepb.Event{
|
||||
Type: storagepb.PUT,
|
||||
Kv: &storagepb.KeyValue{
|
||||
Key: []byte("foo1"),
|
||||
Value: []byte("bar1"),
|
||||
CreateRevision: 3,
|
||||
ModRevision: 3,
|
||||
Version: 1,
|
||||
Lease: 2,
|
||||
},
|
||||
}
|
||||
if resp.WatchID != wid {
|
||||
t.Errorf("resp.WatchID got = %d, want = %d", resp.WatchID, wid)
|
||||
}
|
||||
ev := resp.Events[0]
|
||||
if !reflect.DeepEqual(ev, wev) {
|
||||
t.Errorf("watched event = %+v, want %+v", ev, wev)
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
testutil.FatalStack(t, "failed to watch the event")
|
||||
}
|
||||
|
||||
s.Put([]byte("foo1"), []byte("bar11"), 3)
|
||||
select {
|
||||
case resp := <-w.Chan():
|
||||
wev := storagepb.Event{
|
||||
},
|
||||
{
|
||||
Type: storagepb.PUT,
|
||||
Kv: &storagepb.KeyValue{
|
||||
Key: []byte("foo1"),
|
||||
@ -815,13 +756,63 @@ func TestWatchableKVWatch(t *testing.T) {
|
||||
Version: 2,
|
||||
Lease: 3,
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
s.Put([]byte("foo"), []byte("bar"), 1)
|
||||
select {
|
||||
case resp := <-w.Chan():
|
||||
if resp.WatchID != wid {
|
||||
t.Errorf("resp.WatchID got = %d, want = %d", resp.WatchID, wid)
|
||||
}
|
||||
ev := resp.Events[0]
|
||||
if !reflect.DeepEqual(ev, wev) {
|
||||
t.Errorf("watched event = %+v, want %+v", ev, wev)
|
||||
if !reflect.DeepEqual(ev, wev[0]) {
|
||||
t.Errorf("watched event = %+v, want %+v", ev, wev[0])
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
// CPU might be too slow, and the routine is not able to switch around
|
||||
testutil.FatalStack(t, "failed to watch the event")
|
||||
}
|
||||
|
||||
s.Put([]byte("foo1"), []byte("bar1"), 2)
|
||||
select {
|
||||
case resp := <-w.Chan():
|
||||
if resp.WatchID != wid {
|
||||
t.Errorf("resp.WatchID got = %d, want = %d", resp.WatchID, wid)
|
||||
}
|
||||
ev := resp.Events[0]
|
||||
if !reflect.DeepEqual(ev, wev[1]) {
|
||||
t.Errorf("watched event = %+v, want %+v", ev, wev[1])
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
testutil.FatalStack(t, "failed to watch the event")
|
||||
}
|
||||
|
||||
w = s.NewWatchStream()
|
||||
wid = w.Watch([]byte("foo1"), []byte("foo2"), 3)
|
||||
|
||||
select {
|
||||
case resp := <-w.Chan():
|
||||
if resp.WatchID != wid {
|
||||
t.Errorf("resp.WatchID got = %d, want = %d", resp.WatchID, wid)
|
||||
}
|
||||
ev := resp.Events[0]
|
||||
if !reflect.DeepEqual(ev, wev[1]) {
|
||||
t.Errorf("watched event = %+v, want %+v", ev, wev[1])
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
testutil.FatalStack(t, "failed to watch the event")
|
||||
}
|
||||
|
||||
s.Put([]byte("foo1"), []byte("bar11"), 3)
|
||||
select {
|
||||
case resp := <-w.Chan():
|
||||
if resp.WatchID != wid {
|
||||
t.Errorf("resp.WatchID got = %d, want = %d", resp.WatchID, wid)
|
||||
}
|
||||
ev := resp.Events[0]
|
||||
if !reflect.DeepEqual(ev, wev[2]) {
|
||||
t.Errorf("watched event = %+v, want %+v", ev, wev[2])
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
testutil.FatalStack(t, "failed to watch the event")
|
||||
|
@ -16,8 +16,6 @@ package storage
|
||||
|
||||
import (
|
||||
"log"
|
||||
"math"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@ -34,103 +32,8 @@ const (
|
||||
chanBufLen = 1024
|
||||
)
|
||||
|
||||
var (
|
||||
// watchBatchMaxRevs is the maximum distinct revisions that
|
||||
// may be sent to an unsynced watcher at a time. Declared as
|
||||
// var instead of const for testing purposes.
|
||||
watchBatchMaxRevs = 1000
|
||||
)
|
||||
|
||||
type eventBatch struct {
|
||||
// evs is a batch of revision-ordered events
|
||||
evs []storagepb.Event
|
||||
// revs is the minimum unique revisions observed for this batch
|
||||
revs int
|
||||
// moreRev is first revision with more events following this batch
|
||||
moreRev int64
|
||||
}
|
||||
|
||||
type (
|
||||
watcherSetByKey map[string]watcherSet
|
||||
watcherSet map[*watcher]struct{}
|
||||
watcherBatch map[*watcher]*eventBatch
|
||||
)
|
||||
|
||||
func (eb *eventBatch) add(ev storagepb.Event) {
|
||||
if eb.revs > watchBatchMaxRevs {
|
||||
// maxed out batch size
|
||||
return
|
||||
}
|
||||
|
||||
if len(eb.evs) == 0 {
|
||||
// base case
|
||||
eb.revs = 1
|
||||
eb.evs = append(eb.evs, ev)
|
||||
return
|
||||
}
|
||||
|
||||
// revision accounting
|
||||
ebRev := eb.evs[len(eb.evs)-1].Kv.ModRevision
|
||||
evRev := ev.Kv.ModRevision
|
||||
if evRev > ebRev {
|
||||
eb.revs++
|
||||
if eb.revs > watchBatchMaxRevs {
|
||||
eb.moreRev = evRev
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
eb.evs = append(eb.evs, ev)
|
||||
}
|
||||
|
||||
func (wb watcherBatch) add(w *watcher, ev storagepb.Event) {
|
||||
eb := wb[w]
|
||||
if eb == nil {
|
||||
eb = &eventBatch{}
|
||||
wb[w] = eb
|
||||
}
|
||||
eb.add(ev)
|
||||
}
|
||||
|
||||
func (w watcherSet) add(wa *watcher) {
|
||||
if _, ok := w[wa]; ok {
|
||||
panic("add watcher twice!")
|
||||
}
|
||||
w[wa] = struct{}{}
|
||||
}
|
||||
|
||||
func (w watcherSetByKey) add(wa *watcher) {
|
||||
set := w[string(wa.key)]
|
||||
if set == nil {
|
||||
set = make(watcherSet)
|
||||
w[string(wa.key)] = set
|
||||
}
|
||||
set.add(wa)
|
||||
}
|
||||
|
||||
func (w watcherSetByKey) getSetByKey(key string) (watcherSet, bool) {
|
||||
set, ok := w[key]
|
||||
return set, ok
|
||||
}
|
||||
|
||||
func (w watcherSetByKey) delete(wa *watcher) bool {
|
||||
k := string(wa.key)
|
||||
if v, ok := w[k]; ok {
|
||||
if _, ok := v[wa]; ok {
|
||||
delete(v, wa)
|
||||
// if there is nothing in the set,
|
||||
// remove the set
|
||||
if len(v) == 0 {
|
||||
delete(w, k)
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type watchable interface {
|
||||
watch(key []byte, prefix bool, startRev int64, id WatchID, ch chan<- WatchResponse) (*watcher, cancelFunc)
|
||||
watch(key, end []byte, startRev int64, id WatchID, ch chan<- WatchResponse) (*watcher, cancelFunc)
|
||||
rev() int64
|
||||
}
|
||||
|
||||
@ -140,11 +43,11 @@ type watchableStore struct {
|
||||
*store
|
||||
|
||||
// contains all unsynced watchers that needs to sync with events that have happened
|
||||
unsynced watcherSetByKey
|
||||
unsynced watcherGroup
|
||||
|
||||
// contains all synced watchers that are in sync with the progress of the store.
|
||||
// The key of the map is the key that the watcher watches on.
|
||||
synced watcherSetByKey
|
||||
synced watcherGroup
|
||||
|
||||
stopc chan struct{}
|
||||
wg sync.WaitGroup
|
||||
@ -157,8 +60,8 @@ type cancelFunc func()
|
||||
func newWatchableStore(b backend.Backend, le lease.Lessor) *watchableStore {
|
||||
s := &watchableStore{
|
||||
store: NewStore(b, le),
|
||||
unsynced: make(watcherSetByKey),
|
||||
synced: make(watcherSetByKey),
|
||||
unsynced: newWatcherGroup(),
|
||||
synced: newWatcherGroup(),
|
||||
stopc: make(chan struct{}),
|
||||
}
|
||||
if s.le != nil {
|
||||
@ -268,16 +171,16 @@ func (s *watchableStore) NewWatchStream() WatchStream {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *watchableStore) watch(key []byte, prefix bool, startRev int64, id WatchID, ch chan<- WatchResponse) (*watcher, cancelFunc) {
|
||||
func (s *watchableStore) watch(key, end []byte, startRev int64, id WatchID, ch chan<- WatchResponse) (*watcher, cancelFunc) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
wa := &watcher{
|
||||
key: key,
|
||||
prefix: prefix,
|
||||
cur: startRev,
|
||||
id: id,
|
||||
ch: ch,
|
||||
key: key,
|
||||
end: end,
|
||||
cur: startRev,
|
||||
id: id,
|
||||
ch: ch,
|
||||
}
|
||||
|
||||
s.store.mu.Lock()
|
||||
@ -342,15 +245,16 @@ func (s *watchableStore) syncWatchers() {
|
||||
s.store.mu.Lock()
|
||||
defer s.store.mu.Unlock()
|
||||
|
||||
if len(s.unsynced) == 0 {
|
||||
if s.unsynced.size() == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// in order to find key-value pairs from unsynced watchers, we need to
|
||||
// find min revision index, and these revisions can be used to
|
||||
// query the backend store of key-value pairs
|
||||
prefixes, minRev := s.scanUnsync()
|
||||
curRev := s.store.currentRev.main
|
||||
compactionRev := s.store.compactMainRev
|
||||
minRev := s.unsynced.scanMinRev(curRev, compactionRev)
|
||||
minBytes, maxBytes := newRevBytes(), newRevBytes()
|
||||
revToBytes(revision{main: minRev}, minBytes)
|
||||
revToBytes(revision{main: curRev + 1}, maxBytes)
|
||||
@ -360,10 +264,10 @@ func (s *watchableStore) syncWatchers() {
|
||||
tx := s.store.b.BatchTx()
|
||||
tx.Lock()
|
||||
revs, vs := tx.UnsafeRange(keyBucketName, minBytes, maxBytes, 0)
|
||||
evs := kvsToEvents(revs, vs, s.unsynced, prefixes)
|
||||
evs := kvsToEvents(&s.unsynced, revs, vs)
|
||||
tx.Unlock()
|
||||
|
||||
for w, eb := range newWatcherBatch(s.unsynced, evs) {
|
||||
for w, eb := range newWatcherBatch(&s.unsynced, evs) {
|
||||
select {
|
||||
// s.store.Rev also uses Lock, so just return directly
|
||||
case w.ch <- WatchResponse{WatchID: w.id, Events: eb.evs, Revision: s.store.currentRev.main}:
|
||||
@ -383,56 +287,18 @@ func (s *watchableStore) syncWatchers() {
|
||||
s.unsynced.delete(w)
|
||||
}
|
||||
|
||||
slowWatcherGauge.Set(float64(len(s.unsynced)))
|
||||
}
|
||||
|
||||
func (s *watchableStore) scanUnsync() (prefixes map[string]struct{}, minRev int64) {
|
||||
curRev := s.store.currentRev.main
|
||||
compactionRev := s.store.compactMainRev
|
||||
|
||||
prefixes = make(map[string]struct{})
|
||||
minRev = int64(math.MaxInt64)
|
||||
for _, set := range s.unsynced {
|
||||
for w := range set {
|
||||
k := string(w.key)
|
||||
|
||||
if w.cur > curRev {
|
||||
panic("watcher current revision should not exceed current revision")
|
||||
}
|
||||
|
||||
if w.cur < compactionRev {
|
||||
select {
|
||||
case w.ch <- WatchResponse{WatchID: w.id, CompactRevision: compactionRev}:
|
||||
s.unsynced.delete(w)
|
||||
default:
|
||||
// retry next time
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if minRev > w.cur {
|
||||
minRev = w.cur
|
||||
}
|
||||
|
||||
if w.prefix {
|
||||
prefixes[k] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return prefixes, minRev
|
||||
slowWatcherGauge.Set(float64(s.unsynced.size()))
|
||||
}
|
||||
|
||||
// kvsToEvents gets all events for the watchers from all key-value pairs
|
||||
func kvsToEvents(revs, vals [][]byte, wsk watcherSetByKey, pfxs map[string]struct{}) (evs []storagepb.Event) {
|
||||
func kvsToEvents(wg *watcherGroup, revs, vals [][]byte) (evs []storagepb.Event) {
|
||||
for i, v := range vals {
|
||||
var kv storagepb.KeyValue
|
||||
if err := kv.Unmarshal(v); err != nil {
|
||||
log.Panicf("storage: cannot unmarshal event: %v", err)
|
||||
}
|
||||
|
||||
k := string(kv.Key)
|
||||
if _, ok := wsk.getSetByKey(k); !ok && !matchPrefix(k, pfxs) {
|
||||
if !wg.contains(string(kv.Key)) {
|
||||
continue
|
||||
}
|
||||
|
||||
@ -450,26 +316,19 @@ func kvsToEvents(revs, vals [][]byte, wsk watcherSetByKey, pfxs map[string]struc
|
||||
// notify notifies the fact that given event at the given rev just happened to
|
||||
// watchers that watch on the key of the event.
|
||||
func (s *watchableStore) notify(rev int64, evs []storagepb.Event) {
|
||||
we := newWatcherBatch(s.synced, evs)
|
||||
for _, wm := range s.synced {
|
||||
for w := range wm {
|
||||
eb, ok := we[w]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if eb.revs != 1 {
|
||||
panic("unexpected multiple revisions in notification")
|
||||
}
|
||||
select {
|
||||
case w.ch <- WatchResponse{WatchID: w.id, Events: eb.evs, Revision: s.Rev()}:
|
||||
pendingEventsGauge.Add(float64(len(eb.evs)))
|
||||
default:
|
||||
// move slow watcher to unsynced
|
||||
w.cur = rev
|
||||
s.unsynced.add(w)
|
||||
delete(wm, w)
|
||||
slowWatcherGauge.Inc()
|
||||
}
|
||||
for w, eb := range newWatcherBatch(&s.synced, evs) {
|
||||
if eb.revs != 1 {
|
||||
panic("unexpected multiple revisions in notification")
|
||||
}
|
||||
select {
|
||||
case w.ch <- WatchResponse{WatchID: w.id, Events: eb.evs, Revision: s.Rev()}:
|
||||
pendingEventsGauge.Add(float64(len(eb.evs)))
|
||||
default:
|
||||
// move slow watcher to unsynced
|
||||
w.cur = rev
|
||||
s.unsynced.add(w)
|
||||
s.synced.delete(w)
|
||||
slowWatcherGauge.Inc()
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -479,9 +338,9 @@ func (s *watchableStore) rev() int64 { return s.store.Rev() }
|
||||
type watcher struct {
|
||||
// the watcher key
|
||||
key []byte
|
||||
// prefix indicates if watcher is on a key or a prefix.
|
||||
// If prefix is true, the watcher is on a prefix.
|
||||
prefix bool
|
||||
// end indicates the end of the range to watch.
|
||||
// If end is set, the watcher is on a range.
|
||||
end []byte
|
||||
// cur is the current watcher revision.
|
||||
// If cur is behind the current revision of the KV,
|
||||
// watcher is unsynced and needs to catch up.
|
||||
@ -492,42 +351,3 @@ type watcher struct {
|
||||
// The chan might be shared with other watchers.
|
||||
ch chan<- WatchResponse
|
||||
}
|
||||
|
||||
// newWatcherBatch maps watchers to their matched events. It enables quick
|
||||
// events look up by watcher.
|
||||
func newWatcherBatch(sm watcherSetByKey, evs []storagepb.Event) watcherBatch {
|
||||
wb := make(watcherBatch)
|
||||
for _, ev := range evs {
|
||||
key := string(ev.Kv.Key)
|
||||
|
||||
// check all prefixes of the key to notify all corresponded watchers
|
||||
for i := 0; i <= len(key); i++ {
|
||||
for w := range sm[key[:i]] {
|
||||
// don't double notify
|
||||
if ev.Kv.ModRevision < w.cur {
|
||||
continue
|
||||
}
|
||||
|
||||
// the watcher needs to be notified when either it watches prefix or
|
||||
// the key is exactly matched.
|
||||
if !w.prefix && i != len(ev.Kv.Key) {
|
||||
continue
|
||||
}
|
||||
wb.add(w, ev)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return wb
|
||||
}
|
||||
|
||||
// matchPrefix returns true if key has any matching prefix
|
||||
// from prefixes map.
|
||||
func matchPrefix(key string, prefixes map[string]struct{}) bool {
|
||||
for p := range prefixes {
|
||||
if strings.HasPrefix(key, p) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
@ -40,11 +40,11 @@ func BenchmarkWatchableStoreUnsyncedCancel(b *testing.B) {
|
||||
// in unsynced for this benchmark.
|
||||
ws := &watchableStore{
|
||||
store: s,
|
||||
unsynced: make(watcherSetByKey),
|
||||
unsynced: newWatcherGroup(),
|
||||
|
||||
// to make the test not crash from assigning to nil map.
|
||||
// 'synced' doesn't get populated in this test.
|
||||
synced: make(watcherSetByKey),
|
||||
synced: newWatcherGroup(),
|
||||
}
|
||||
|
||||
defer func() {
|
||||
@ -69,7 +69,7 @@ func BenchmarkWatchableStoreUnsyncedCancel(b *testing.B) {
|
||||
watchIDs := make([]WatchID, watcherN)
|
||||
for i := 0; i < watcherN; i++ {
|
||||
// non-0 value to keep watchers in unsynced
|
||||
watchIDs[i] = w.Watch(testKey, true, 1)
|
||||
watchIDs[i] = w.Watch(testKey, nil, 1)
|
||||
}
|
||||
|
||||
// random-cancel N watchers to make it not biased towards
|
||||
@ -109,7 +109,7 @@ func BenchmarkWatchableStoreSyncedCancel(b *testing.B) {
|
||||
watchIDs := make([]WatchID, watcherN)
|
||||
for i := 0; i < watcherN; i++ {
|
||||
// 0 for startRev to keep watchers in synced
|
||||
watchIDs[i] = w.Watch(testKey, true, 0)
|
||||
watchIDs[i] = w.Watch(testKey, nil, 0)
|
||||
}
|
||||
|
||||
// randomly cancel watchers to make it not biased towards
|
||||
|
@ -40,11 +40,11 @@ func TestWatch(t *testing.T) {
|
||||
s.Put(testKey, testValue, lease.NoLease)
|
||||
|
||||
w := s.NewWatchStream()
|
||||
w.Watch(testKey, true, 0)
|
||||
w.Watch(testKey, nil, 0)
|
||||
|
||||
if _, ok := s.synced[string(testKey)]; !ok {
|
||||
if !s.synced.contains(string(testKey)) {
|
||||
// the key must have had an entry in synced
|
||||
t.Errorf("existence = %v, want true", ok)
|
||||
t.Errorf("existence = false, want true")
|
||||
}
|
||||
}
|
||||
|
||||
@ -61,15 +61,15 @@ func TestNewWatcherCancel(t *testing.T) {
|
||||
s.Put(testKey, testValue, lease.NoLease)
|
||||
|
||||
w := s.NewWatchStream()
|
||||
wt := w.Watch(testKey, true, 0)
|
||||
wt := w.Watch(testKey, nil, 0)
|
||||
|
||||
if err := w.Cancel(wt); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if _, ok := s.synced[string(testKey)]; ok {
|
||||
if s.synced.contains(string(testKey)) {
|
||||
// the key should have been deleted
|
||||
t.Errorf("existence = %v, want false", ok)
|
||||
t.Errorf("existence = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
@ -83,11 +83,11 @@ func TestCancelUnsynced(t *testing.T) {
|
||||
// in unsynced to test if syncWatchers works as expected.
|
||||
s := &watchableStore{
|
||||
store: NewStore(b, &lease.FakeLessor{}),
|
||||
unsynced: make(watcherSetByKey),
|
||||
unsynced: newWatcherGroup(),
|
||||
|
||||
// to make the test not crash from assigning to nil map.
|
||||
// 'synced' doesn't get populated in this test.
|
||||
synced: make(watcherSetByKey),
|
||||
synced: newWatcherGroup(),
|
||||
}
|
||||
|
||||
defer func() {
|
||||
@ -112,7 +112,7 @@ func TestCancelUnsynced(t *testing.T) {
|
||||
watchIDs := make([]WatchID, watcherN)
|
||||
for i := 0; i < watcherN; i++ {
|
||||
// use 1 to keep watchers in unsynced
|
||||
watchIDs[i] = w.Watch(testKey, true, 1)
|
||||
watchIDs[i] = w.Watch(testKey, nil, 1)
|
||||
}
|
||||
|
||||
for _, idx := range watchIDs {
|
||||
@ -125,8 +125,8 @@ func TestCancelUnsynced(t *testing.T) {
|
||||
//
|
||||
// unsynced should be empty
|
||||
// because cancel removes watcher from unsynced
|
||||
if len(s.unsynced) != 0 {
|
||||
t.Errorf("unsynced size = %d, want 0", len(s.unsynced))
|
||||
if size := s.unsynced.size(); size != 0 {
|
||||
t.Errorf("unsynced size = %d, want 0", size)
|
||||
}
|
||||
}
|
||||
|
||||
@ -138,8 +138,8 @@ func TestSyncWatchers(t *testing.T) {
|
||||
|
||||
s := &watchableStore{
|
||||
store: NewStore(b, &lease.FakeLessor{}),
|
||||
unsynced: make(watcherSetByKey),
|
||||
synced: make(watcherSetByKey),
|
||||
unsynced: newWatcherGroup(),
|
||||
synced: newWatcherGroup(),
|
||||
}
|
||||
|
||||
defer func() {
|
||||
@ -158,13 +158,13 @@ func TestSyncWatchers(t *testing.T) {
|
||||
|
||||
for i := 0; i < watcherN; i++ {
|
||||
// specify rev as 1 to keep watchers in unsynced
|
||||
w.Watch(testKey, true, 1)
|
||||
w.Watch(testKey, nil, 1)
|
||||
}
|
||||
|
||||
// Before running s.syncWatchers() synced should be empty because we manually
|
||||
// populate unsynced only
|
||||
sws, _ := s.synced.getSetByKey(string(testKey))
|
||||
uws, _ := s.unsynced.getSetByKey(string(testKey))
|
||||
sws := s.synced.watcherSetByKey(string(testKey))
|
||||
uws := s.unsynced.watcherSetByKey(string(testKey))
|
||||
|
||||
if len(sws) != 0 {
|
||||
t.Fatalf("synced[string(testKey)] size = %d, want 0", len(sws))
|
||||
@ -177,8 +177,8 @@ func TestSyncWatchers(t *testing.T) {
|
||||
// this should move all unsynced watchers to synced ones
|
||||
s.syncWatchers()
|
||||
|
||||
sws, _ = s.synced.getSetByKey(string(testKey))
|
||||
uws, _ = s.unsynced.getSetByKey(string(testKey))
|
||||
sws = s.synced.watcherSetByKey(string(testKey))
|
||||
uws = s.unsynced.watcherSetByKey(string(testKey))
|
||||
|
||||
// After running s.syncWatchers(), synced should not be empty because syncwatchers
|
||||
// populates synced in this test case
|
||||
@ -240,7 +240,7 @@ func TestWatchCompacted(t *testing.T) {
|
||||
}
|
||||
|
||||
w := s.NewWatchStream()
|
||||
wt := w.Watch(testKey, true, compactRev-1)
|
||||
wt := w.Watch(testKey, nil, compactRev-1)
|
||||
|
||||
select {
|
||||
case resp := <-w.Chan():
|
||||
@ -275,7 +275,7 @@ func TestWatchBatchUnsynced(t *testing.T) {
|
||||
}
|
||||
|
||||
w := s.NewWatchStream()
|
||||
w.Watch(v, false, 1)
|
||||
w.Watch(v, nil, 1)
|
||||
for i := 0; i < batches; i++ {
|
||||
if resp := <-w.Chan(); len(resp.Events) != watchBatchMaxRevs {
|
||||
t.Fatalf("len(events) = %d, want %d", len(resp.Events), watchBatchMaxRevs)
|
||||
@ -284,8 +284,8 @@ func TestWatchBatchUnsynced(t *testing.T) {
|
||||
|
||||
s.store.mu.Lock()
|
||||
defer s.store.mu.Unlock()
|
||||
if len(s.synced) != 1 {
|
||||
t.Errorf("synced size = %d, want 1", len(s.synced))
|
||||
if size := s.synced.size(); size != 1 {
|
||||
t.Errorf("synced size = %d, want 1", size)
|
||||
}
|
||||
}
|
||||
|
||||
@ -311,14 +311,14 @@ func TestNewMapwatcherToEventMap(t *testing.T) {
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
sync watcherSetByKey
|
||||
sync []*watcher
|
||||
evs []storagepb.Event
|
||||
|
||||
wwe map[*watcher][]storagepb.Event
|
||||
}{
|
||||
// no watcher in sync, some events should return empty wwe
|
||||
{
|
||||
watcherSetByKey{},
|
||||
nil,
|
||||
evs,
|
||||
map[*watcher][]storagepb.Event{},
|
||||
},
|
||||
@ -326,9 +326,7 @@ func TestNewMapwatcherToEventMap(t *testing.T) {
|
||||
// one watcher in sync, one event that does not match the key of that
|
||||
// watcher should return empty wwe
|
||||
{
|
||||
watcherSetByKey{
|
||||
string(k2): {ws[2]: struct{}{}},
|
||||
},
|
||||
[]*watcher{ws[2]},
|
||||
evs[:1],
|
||||
map[*watcher][]storagepb.Event{},
|
||||
},
|
||||
@ -336,9 +334,7 @@ func TestNewMapwatcherToEventMap(t *testing.T) {
|
||||
// one watcher in sync, one event that matches the key of that
|
||||
// watcher should return wwe with that matching watcher
|
||||
{
|
||||
watcherSetByKey{
|
||||
string(k1): {ws[1]: struct{}{}},
|
||||
},
|
||||
[]*watcher{ws[1]},
|
||||
evs[1:2],
|
||||
map[*watcher][]storagepb.Event{
|
||||
ws[1]: evs[1:2],
|
||||
@ -349,10 +345,7 @@ func TestNewMapwatcherToEventMap(t *testing.T) {
|
||||
// that matches the key of only one of the watcher should return wwe
|
||||
// with the matching watcher
|
||||
{
|
||||
watcherSetByKey{
|
||||
string(k0): {ws[0]: struct{}{}},
|
||||
string(k2): {ws[2]: struct{}{}},
|
||||
},
|
||||
[]*watcher{ws[0], ws[2]},
|
||||
evs[2:],
|
||||
map[*watcher][]storagepb.Event{
|
||||
ws[2]: evs[2:],
|
||||
@ -362,10 +355,7 @@ func TestNewMapwatcherToEventMap(t *testing.T) {
|
||||
// two watchers in sync that watches the same key, two events that
|
||||
// match the keys should return wwe with those two watchers
|
||||
{
|
||||
watcherSetByKey{
|
||||
string(k0): {ws[0]: struct{}{}},
|
||||
string(k1): {ws[1]: struct{}{}},
|
||||
},
|
||||
[]*watcher{ws[0], ws[1]},
|
||||
evs[:2],
|
||||
map[*watcher][]storagepb.Event{
|
||||
ws[0]: evs[:1],
|
||||
@ -375,7 +365,12 @@ func TestNewMapwatcherToEventMap(t *testing.T) {
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
gwe := newWatcherBatch(tt.sync, tt.evs)
|
||||
wg := newWatcherGroup()
|
||||
for _, w := range tt.sync {
|
||||
wg.add(w)
|
||||
}
|
||||
|
||||
gwe := newWatcherBatch(&wg, tt.evs)
|
||||
if len(gwe) != len(tt.wwe) {
|
||||
t.Errorf("#%d: len(gwe) got = %d, want = %d", i, len(gwe), len(tt.wwe))
|
||||
}
|
||||
|
@ -29,16 +29,15 @@ type WatchID int64
|
||||
|
||||
type WatchStream interface {
|
||||
// Watch creates a watcher. The watcher watches the events happening or
|
||||
// happened on the given key or key prefix from the given startRev.
|
||||
// happened on the given key or range [key, end) from the given startRev.
|
||||
//
|
||||
// The whole event history can be watched unless compacted.
|
||||
// If `prefix` is true, watch observes all events whose key prefix could be the given `key`.
|
||||
// If `startRev` <=0, watch observes events after currentRev.
|
||||
//
|
||||
// The returned `id` is the ID of this watcher. It appears as WatchID
|
||||
// in events that are sent to the created watcher through stream channel.
|
||||
//
|
||||
Watch(key []byte, prefix bool, startRev int64) WatchID
|
||||
Watch(key, end []byte, startRev int64) WatchID
|
||||
|
||||
// Chan returns a chan. All watch response will be sent to the returned chan.
|
||||
Chan() <-chan WatchResponse
|
||||
@ -87,7 +86,7 @@ type watchStream struct {
|
||||
|
||||
// Watch creates a new watcher in the stream and returns its WatchID.
|
||||
// TODO: return error if ws is closed?
|
||||
func (ws *watchStream) Watch(key []byte, prefix bool, startRev int64) WatchID {
|
||||
func (ws *watchStream) Watch(key, end []byte, startRev int64) WatchID {
|
||||
ws.mu.Lock()
|
||||
defer ws.mu.Unlock()
|
||||
if ws.closed {
|
||||
@ -97,7 +96,7 @@ func (ws *watchStream) Watch(key []byte, prefix bool, startRev int64) WatchID {
|
||||
id := ws.nextID
|
||||
ws.nextID++
|
||||
|
||||
_, c := ws.watchable.watch(key, prefix, startRev, id, ws.ch)
|
||||
_, c := ws.watchable.watch(key, end, startRev, id, ws.ch)
|
||||
|
||||
ws.cancels[id] = c
|
||||
return id
|
||||
|
@ -33,6 +33,6 @@ func BenchmarkKVWatcherMemoryUsage(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.StartTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
w.Watch([]byte(fmt.Sprint("foo", i)), false, 0)
|
||||
w.Watch([]byte(fmt.Sprint("foo", i)), nil, 0)
|
||||
}
|
||||
}
|
||||
|
269
storage/watcher_group.go
Normal file
269
storage/watcher_group.go
Normal file
@ -0,0 +1,269 @@
|
||||
// Copyright 2016 CoreOS, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package storage
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/coreos/etcd/pkg/adt"
|
||||
"github.com/coreos/etcd/storage/storagepb"
|
||||
)
|
||||
|
||||
// watchBatchMaxRevs is the maximum number of distinct revisions that may
// be sent to an unsynced watcher at a time. It is declared as a variable
// rather than a constant so that tests can override it.
var watchBatchMaxRevs = 1000
|
||||
|
||||
type eventBatch struct {
|
||||
// evs is a batch of revision-ordered events
|
||||
evs []storagepb.Event
|
||||
// revs is the minimum unique revisions observed for this batch
|
||||
revs int
|
||||
// moreRev is first revision with more events following this batch
|
||||
moreRev int64
|
||||
}
|
||||
|
||||
func (eb *eventBatch) add(ev storagepb.Event) {
|
||||
if eb.revs > watchBatchMaxRevs {
|
||||
// maxed out batch size
|
||||
return
|
||||
}
|
||||
|
||||
if len(eb.evs) == 0 {
|
||||
// base case
|
||||
eb.revs = 1
|
||||
eb.evs = append(eb.evs, ev)
|
||||
return
|
||||
}
|
||||
|
||||
// revision accounting
|
||||
ebRev := eb.evs[len(eb.evs)-1].Kv.ModRevision
|
||||
evRev := ev.Kv.ModRevision
|
||||
if evRev > ebRev {
|
||||
eb.revs++
|
||||
if eb.revs > watchBatchMaxRevs {
|
||||
eb.moreRev = evRev
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
eb.evs = append(eb.evs, ev)
|
||||
}
|
||||
|
||||
type watcherBatch map[*watcher]*eventBatch
|
||||
|
||||
func (wb watcherBatch) add(w *watcher, ev storagepb.Event) {
|
||||
eb := wb[w]
|
||||
if eb == nil {
|
||||
eb = &eventBatch{}
|
||||
wb[w] = eb
|
||||
}
|
||||
eb.add(ev)
|
||||
}
|
||||
|
||||
// newWatcherBatch maps watchers to their matched events. It enables quick
|
||||
// events look up by watcher.
|
||||
func newWatcherBatch(wg *watcherGroup, evs []storagepb.Event) watcherBatch {
|
||||
wb := make(watcherBatch)
|
||||
for _, ev := range evs {
|
||||
for w := range wg.watcherSetByKey(string(ev.Kv.Key)) {
|
||||
if ev.Kv.ModRevision >= w.cur {
|
||||
// don't double notify
|
||||
wb.add(w, ev)
|
||||
}
|
||||
}
|
||||
}
|
||||
return wb
|
||||
}
|
||||
|
||||
type watcherSet map[*watcher]struct{}
|
||||
|
||||
func (w watcherSet) add(wa *watcher) {
|
||||
if _, ok := w[wa]; ok {
|
||||
panic("add watcher twice!")
|
||||
}
|
||||
w[wa] = struct{}{}
|
||||
}
|
||||
|
||||
func (w watcherSet) union(ws watcherSet) {
|
||||
for wa := range ws {
|
||||
w.add(wa)
|
||||
}
|
||||
}
|
||||
|
||||
func (w watcherSet) delete(wa *watcher) {
|
||||
if _, ok := w[wa]; !ok {
|
||||
panic("removing missing watcher!")
|
||||
}
|
||||
delete(w, wa)
|
||||
}
|
||||
|
||||
type watcherSetByKey map[string]watcherSet
|
||||
|
||||
func (w watcherSetByKey) add(wa *watcher) {
|
||||
set := w[string(wa.key)]
|
||||
if set == nil {
|
||||
set = make(watcherSet)
|
||||
w[string(wa.key)] = set
|
||||
}
|
||||
set.add(wa)
|
||||
}
|
||||
|
||||
func (w watcherSetByKey) delete(wa *watcher) bool {
|
||||
k := string(wa.key)
|
||||
if v, ok := w[k]; ok {
|
||||
if _, ok := v[wa]; ok {
|
||||
delete(v, wa)
|
||||
if len(v) == 0 {
|
||||
// remove the set; nothing left
|
||||
delete(w, k)
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// interval is a pair of strings marking the begin and end of a range.
type interval struct {
	begin, end string
}
|
||||
|
||||
// watcherSetByInterval maps an interval to a watcherSet.
// NOTE(review): no code visible in this file references this type; confirm
// it is still needed.
type watcherSetByInterval map[interval]watcherSet
|
||||
|
||||
// watcherGroup is a collection of watchers organized by their ranges
|
||||
type watcherGroup struct {
|
||||
// keyWatchers has the watchers that watch on a single key
|
||||
keyWatchers watcherSetByKey
|
||||
// ranges has the watchers that watch a range; it is sorted by interval
|
||||
ranges adt.IntervalTree
|
||||
// watchers is the set of all watchers
|
||||
watchers watcherSet
|
||||
}
|
||||
|
||||
func newWatcherGroup() watcherGroup {
|
||||
return watcherGroup{
|
||||
keyWatchers: make(watcherSetByKey),
|
||||
watchers: make(watcherSet),
|
||||
}
|
||||
}
|
||||
|
||||
// add puts a watcher in the group.
|
||||
func (wg *watcherGroup) add(wa *watcher) {
|
||||
wg.watchers.add(wa)
|
||||
if wa.end == nil {
|
||||
wg.keyWatchers.add(wa)
|
||||
return
|
||||
}
|
||||
|
||||
// interval already registered?
|
||||
ivl := adt.NewStringAffineInterval(string(wa.key), string(wa.end))
|
||||
if iv := wg.ranges.Find(ivl); iv != nil {
|
||||
iv.Val.(watcherSet).add(wa)
|
||||
return
|
||||
}
|
||||
|
||||
// not registered, put in interval tree
|
||||
ws := make(watcherSet)
|
||||
ws.add(wa)
|
||||
wg.ranges.Insert(ivl, ws)
|
||||
}
|
||||
|
||||
// contains is whether the given key has a watcher in the group.
|
||||
func (wg *watcherGroup) contains(key string) bool {
|
||||
_, ok := wg.keyWatchers[key]
|
||||
return ok || wg.ranges.Contains(adt.NewStringAffinePoint(key))
|
||||
}
|
||||
|
||||
// size gives the number of unique watchers in the group.
|
||||
func (wg *watcherGroup) size() int { return len(wg.watchers) }
|
||||
|
||||
// delete removes a watcher from the group.
|
||||
func (wg *watcherGroup) delete(wa *watcher) bool {
|
||||
if _, ok := wg.watchers[wa]; !ok {
|
||||
return false
|
||||
}
|
||||
wg.watchers.delete(wa)
|
||||
if wa.end == nil {
|
||||
wg.keyWatchers.delete(wa)
|
||||
return true
|
||||
}
|
||||
|
||||
ivl := adt.NewStringAffineInterval(string(wa.key), string(wa.end))
|
||||
iv := wg.ranges.Find(ivl)
|
||||
if iv == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
ws := iv.Val.(watcherSet)
|
||||
delete(ws, wa)
|
||||
if len(ws) == 0 {
|
||||
// remove interval missing watchers
|
||||
if ok := wg.ranges.Delete(ivl); !ok {
|
||||
panic("could not remove watcher from interval tree")
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (wg *watcherGroup) scanMinRev(curRev int64, compactRev int64) int64 {
|
||||
minRev := int64(math.MaxInt64)
|
||||
for w := range wg.watchers {
|
||||
if w.cur > curRev {
|
||||
panic("watcher current revision should not exceed current revision")
|
||||
}
|
||||
if w.cur < compactRev {
|
||||
select {
|
||||
case w.ch <- WatchResponse{WatchID: w.id, CompactRevision: compactRev}:
|
||||
wg.delete(w)
|
||||
default:
|
||||
// retry next time
|
||||
}
|
||||
continue
|
||||
}
|
||||
if minRev > w.cur {
|
||||
minRev = w.cur
|
||||
}
|
||||
}
|
||||
return minRev
|
||||
}
|
||||
|
||||
// watcherSetByKey gets the set of watchers that recieve events on the given key.
|
||||
func (wg *watcherGroup) watcherSetByKey(key string) watcherSet {
|
||||
wkeys := wg.keyWatchers[key]
|
||||
wranges := wg.ranges.Stab(adt.NewStringAffinePoint(key))
|
||||
|
||||
// zero-copy cases
|
||||
switch {
|
||||
case len(wranges) == 0:
|
||||
// no need to merge ranges or copy; reuse single-key set
|
||||
return wkeys
|
||||
case len(wranges) == 0 && len(wkeys) == 0:
|
||||
return nil
|
||||
case len(wranges) == 1 && len(wkeys) == 0:
|
||||
return wranges[0].Val.(watcherSet)
|
||||
}
|
||||
|
||||
// copy case
|
||||
ret := make(watcherSet)
|
||||
ret.union(wg.keyWatchers[key])
|
||||
for _, item := range wranges {
|
||||
ret.union(item.Val.(watcherSet))
|
||||
}
|
||||
return ret
|
||||
}
|
@ -35,7 +35,7 @@ func TestWatcherWatchID(t *testing.T) {
|
||||
idm := make(map[WatchID]struct{})
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
id := w.Watch([]byte("foo"), false, 0)
|
||||
id := w.Watch([]byte("foo"), nil, 0)
|
||||
if _, ok := idm[id]; ok {
|
||||
t.Errorf("#%d: id %d exists", i, id)
|
||||
}
|
||||
@ -57,7 +57,7 @@ func TestWatcherWatchID(t *testing.T) {
|
||||
|
||||
// unsynced watchers
|
||||
for i := 10; i < 20; i++ {
|
||||
id := w.Watch([]byte("foo2"), false, 1)
|
||||
id := w.Watch([]byte("foo2"), nil, 1)
|
||||
if _, ok := idm[id]; ok {
|
||||
t.Errorf("#%d: id %d exists", i, id)
|
||||
}
|
||||
@ -86,12 +86,11 @@ func TestWatcherWatchPrefix(t *testing.T) {
|
||||
|
||||
idm := make(map[WatchID]struct{})
|
||||
|
||||
prefixMatch := true
|
||||
val := []byte("bar")
|
||||
keyWatch, keyPut := []byte("foo"), []byte("foobar")
|
||||
keyWatch, keyEnd, keyPut := []byte("foo"), []byte("fop"), []byte("foobar")
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
id := w.Watch(keyWatch, prefixMatch, 0)
|
||||
id := w.Watch(keyWatch, keyEnd, 0)
|
||||
if _, ok := idm[id]; ok {
|
||||
t.Errorf("#%d: unexpected duplicated id %x", i, id)
|
||||
}
|
||||
@ -118,12 +117,12 @@ func TestWatcherWatchPrefix(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
keyWatch1, keyPut1 := []byte("foo1"), []byte("foo1bar")
|
||||
keyWatch1, keyEnd1, keyPut1 := []byte("foo1"), []byte("foo2"), []byte("foo1bar")
|
||||
s.Put(keyPut1, val, lease.NoLease)
|
||||
|
||||
// unsynced watchers
|
||||
for i := 10; i < 15; i++ {
|
||||
id := w.Watch(keyWatch1, prefixMatch, 1)
|
||||
id := w.Watch(keyWatch1, keyEnd1, 1)
|
||||
if _, ok := idm[id]; ok {
|
||||
t.Errorf("#%d: id %d exists", i, id)
|
||||
}
|
||||
@ -159,7 +158,7 @@ func TestWatchStreamCancelWatcherByID(t *testing.T) {
|
||||
w := s.NewWatchStream()
|
||||
defer w.Close()
|
||||
|
||||
id := w.Watch([]byte("foo"), false, 0)
|
||||
id := w.Watch([]byte("foo"), nil, 0)
|
||||
|
||||
tests := []struct {
|
||||
cancelID WatchID
|
||||
|
Loading…
x
Reference in New Issue
Block a user