Merge pull request #4614 from heyitsanthony/future-watch-rpc

etcdserver, storage, clientv3: watcher ranges
Anthony Romano 2016-02-29 15:59:18 -08:00
commit 3a9d532140
17 changed files with 1164 additions and 429 deletions
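
This change replaces the single-key/prefix watch with explicit [key, range_end) ranges end to end. As a rough usage sketch of the resulting client API (the endpoint address, bootstrap, and event-printing loop are assumptions; clientv3.WithRange, NewWatcher, and the storagepb event fields are taken from the diff below):

package main

import (
	"fmt"

	"github.com/coreos/etcd/clientv3"
	"golang.org/x/net/context"
)

func main() {
	// assumption: endpoint address is illustrative only
	cli, err := clientv3.New(clientv3.Config{Endpoints: []string{"localhost:2379"}})
	if err != nil {
		panic(err)
	}
	defer cli.Close()

	w := clientv3.NewWatcher(cli)
	defer w.Close()

	// watch every key in ["a", "c"); keys "a", "b", "bar" match, "c" does not
	wch := w.Watch(context.TODO(), "a", clientv3.WithRange("c"))
	for resp := range wch {
		for _, ev := range resp.Events {
			fmt.Printf("%s %q -> %q\n", ev.Type, ev.Kv.Key, ev.Kv.Value)
		}
	}
}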


@ -157,6 +157,20 @@ func testWatchMultiWatcher(t *testing.T, wctx *watchctx) {
}
}
// TestWatchRange tests watchers over a key range
func TestWatchRange(t *testing.T) {
runWatchTest(t, testWatchRange)
}
func testWatchRange(t *testing.T, wctx *watchctx) {
if wctx.ch = wctx.w.Watch(context.TODO(), "a", clientv3.WithRange("c")); wctx.ch == nil {
t.Fatalf("expected non-nil channel")
}
putAndWatch(t, wctx, "a", "a")
putAndWatch(t, wctx, "b", "b")
putAndWatch(t, wctx, "bar", "bar")
}
// TestWatchReconnRequest tests the send failure path when requesting a watcher.
func TestWatchReconnRequest(t *testing.T) {
runWatchTest(t, testWatchReconnRequest)


@ -15,8 +15,6 @@
package clientv3
import (
"reflect"
pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
"github.com/coreos/etcd/lease"
)
@ -69,27 +67,6 @@ func (op Op) toRequestUnion() *pb.RequestUnion {
}
}
func (op Op) toWatchRequest() *watchRequest {
switch op.t {
case tRange:
key := string(op.key)
prefix := ""
if op.end != nil {
prefix = key
key = ""
}
wr := &watchRequest{
key: key,
prefix: prefix,
rev: op.rev,
}
return wr
default:
panic("Only for tRange")
}
}
func (op Op) isWrite() bool {
return op.t != tRange
}
@ -140,8 +117,6 @@ func opWatch(key string, opts ...OpOption) Op {
ret := Op{t: tRange, key: []byte(key)}
ret.applyOpts(opts)
switch {
case ret.end != nil && !reflect.DeepEqual(ret.end, getPrefix(ret.key)):
panic("only supports single keys or prefixes")
case ret.leaseID != 0:
panic("unexpected lease in watch")
case ret.limit != 0:


@ -78,10 +78,10 @@ type watcher struct {
// watchRequest is issued by the subscriber to start a new watcher
type watchRequest struct {
ctx context.Context
key string
prefix string
rev int64
ctx context.Context
key string
end string
rev int64
// retc receives a chan WatchResponse once the watcher is established
retc chan chan WatchResponse
}
@ -129,11 +129,14 @@ func NewWatcher(c *Client) Watcher {
func (w *watcher) Watch(ctx context.Context, key string, opts ...OpOption) WatchChan {
ow := opWatch(key, opts...)
wr := ow.toWatchRequest()
wr.ctx = ctx
retc := make(chan chan WatchResponse, 1)
wr.retc = retc
wr := &watchRequest{
ctx: ctx,
key: string(ow.key),
end: string(ow.end),
rev: ow.rev,
retc: retc,
}
ok := false
@ -502,11 +505,10 @@ func (w *watcher) resumeWatchers(wc pb.Watch_WatchClient) error {
// toPB converts an internal watch request structure to its protobuf message
func (wr *watchRequest) toPB() *pb.WatchRequest {
req := &pb.WatchCreateRequest{StartRevision: wr.rev}
if wr.key != "" {
req.Key = []byte(wr.key)
} else {
req.Prefix = []byte(wr.prefix)
req := &pb.WatchCreateRequest{
StartRevision: wr.rev,
Key: []byte(wr.key),
RangeEnd: []byte(wr.end),
}
cr := &pb.WatchRequest_CreateRequest{CreateRequest: req}
return &pb.WatchRequest{RequestUnion: cr}


@ -94,35 +94,33 @@ func (sws *serverWatchStream) recvLoop() error {
switch uv := req.RequestUnion.(type) {
case *pb.WatchRequest_CreateRequest:
if uv.CreateRequest != nil {
creq := uv.CreateRequest
var prefix bool
toWatch := creq.Key
if len(creq.Key) == 0 {
toWatch = creq.Prefix
prefix = true
}
if uv.CreateRequest == nil {
break
}
rev := creq.StartRevision
wsrev := sws.watchStream.Rev()
if rev == 0 {
// rev 0 watches past the current revision
rev = wsrev + 1
} else if rev > wsrev { // do not allow watching future revision.
sws.ctrlStream <- &pb.WatchResponse{
Header: sws.newResponseHeader(wsrev),
WatchId: -1,
Created: true,
Canceled: true,
}
continue
}
id := sws.watchStream.Watch(toWatch, prefix, rev)
sws.ctrlStream <- &pb.WatchResponse{
Header: sws.newResponseHeader(wsrev),
WatchId: int64(id),
Created: true,
}
creq := uv.CreateRequest
if len(creq.RangeEnd) == 1 && creq.RangeEnd[0] == 0 {
// support >= key queries
creq.RangeEnd = []byte{}
}
rev := creq.StartRevision
wsrev := sws.watchStream.Rev()
futureRev := rev > wsrev
if rev == 0 {
// rev 0 watches past the current revision
rev = wsrev + 1
}
// do not allow future watch revision
id := storage.WatchID(-1)
if !futureRev {
id = sws.watchStream.Watch(creq.Key, creq.RangeEnd, rev)
}
sws.ctrlStream <- &pb.WatchResponse{
Header: sws.newResponseHeader(wsrev),
WatchId: int64(id),
Created: true,
Canceled: futureRev,
}
case *pb.WatchRequest_CancelRequest:
if uv.CancelRequest != nil {

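A sketch of what the recvLoop above accepts (the helper and package main wrapper are illustrative, not part of the change): a create request now carries an explicit [key, range_end) pair, and a range_end of a single zero byte is rewritten by the server into the unbounded range, i.e. a ">= key" watch.

package main

import pb "github.com/coreos/etcd/etcdserver/etcdserverpb"

// rangeWatchReq is a hypothetical helper for building create requests.
func rangeWatchReq(key, end []byte, rev int64) *pb.WatchRequest {
	return &pb.WatchRequest{RequestUnion: &pb.WatchRequest_CreateRequest{
		CreateRequest: &pb.WatchCreateRequest{
			Key:           key,
			RangeEnd:      end,
			StartRevision: rev,
		}}}
}

func main() {
	// watch the prefix "foo", i.e. keys in ["foo", "fop"), from the current revision
	_ = rangeWatchReq([]byte("foo"), []byte("fop"), 0)
	// watch every key >= "foo"; the single zero byte range_end is cleared by recvLoop
	_ = rangeWatchReq([]byte("foo"), []byte{0}, 0)
}
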

@ -870,8 +870,9 @@ func _WatchRequest_OneofUnmarshaler(msg proto.Message, tag, wire int, b *proto.B
type WatchCreateRequest struct {
// the key to be watched
Key []byte `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"`
// the prefix to be watched.
Prefix []byte `protobuf:"bytes,2,opt,name=prefix,proto3" json:"prefix,omitempty"`
// if the range_end is given, keys in [key, range_end) are watched
// NOTE: only range_end == prefixEnd(key) is accepted now
RangeEnd []byte `protobuf:"bytes,2,opt,name=range_end,proto3" json:"range_end,omitempty"`
// start_revision is an optional revision (inclusive) to watch from. No start_revision is "now".
StartRevision int64 `protobuf:"varint,3,opt,name=start_revision,proto3" json:"start_revision,omitempty"`
}
@ -2588,12 +2589,12 @@ func (m *WatchCreateRequest) MarshalTo(data []byte) (int, error) {
i += copy(data[i:], m.Key)
}
}
if m.Prefix != nil {
if len(m.Prefix) > 0 {
if m.RangeEnd != nil {
if len(m.RangeEnd) > 0 {
data[i] = 0x12
i++
i = encodeVarintRpc(data, i, uint64(len(m.Prefix)))
i += copy(data[i:], m.Prefix)
i = encodeVarintRpc(data, i, uint64(len(m.RangeEnd)))
i += copy(data[i:], m.RangeEnd)
}
}
if m.StartRevision != 0 {
@ -3592,8 +3593,8 @@ func (m *WatchCreateRequest) Size() (n int) {
n += 1 + l + sovRpc(uint64(l))
}
}
if m.Prefix != nil {
l = len(m.Prefix)
if m.RangeEnd != nil {
l = len(m.RangeEnd)
if l > 0 {
n += 1 + l + sovRpc(uint64(l))
}
@ -6004,7 +6005,7 @@ func (m *WatchCreateRequest) Unmarshal(data []byte) error {
iNdEx = postIndex
case 2:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field Prefix", wireType)
return fmt.Errorf("proto: wrong wireType = %d for field RangeEnd", wireType)
}
var byteLen int
for shift := uint(0); ; shift += 7 {
@ -6028,9 +6029,9 @@ func (m *WatchCreateRequest) Unmarshal(data []byte) error {
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.Prefix = append(m.Prefix[:0], data[iNdEx:postIndex]...)
if m.Prefix == nil {
m.Prefix = []byte{}
m.RangeEnd = append(m.RangeEnd[:0], data[iNdEx:postIndex]...)
if m.RangeEnd == nil {
m.RangeEnd = []byte{}
}
iNdEx = postIndex
case 3:


@ -262,11 +262,11 @@ message WatchRequest {
message WatchCreateRequest {
// the key to be watched
bytes key = 1;
// the prefix to be watched.
bytes prefix = 2;
// if the range_end is given, keys in [key, range_end) are watched
// NOTE: only range_end == prefixEnd(key) is accepted now
bytes range_end = 2;
// start_revision is an optional revision (inclusive) to watch from. No start_revision is "now".
int64 start_revision = 3;
// TODO: support Range watch?
}
message WatchCancelRequest {

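The NOTE above only accepts range_end == prefixEnd(key) for now; the sketch below is an assumption of how that mapping works (mirroring the client's getPrefix helper, with the all-0xff edge case elided). It yields exactly the pairs used in the tests that follow, e.g. "foo" -> "fop".

// prefixEnd sketches how a prefix becomes the half-open range [prefix, end):
// increment the last byte that can be incremented and truncate after it.
func prefixEnd(prefix []byte) []byte {
	end := make([]byte, len(prefix))
	copy(end, prefix)
	for i := len(end) - 1; i >= 0; i-- {
		if end[i] < 0xff {
			end[i]++
			return end[:i+1]
		}
	}
	// all bytes are 0xff: no finite upper bound (handling elided in this sketch)
	return nil
}

// prefixEnd([]byte("foo")) == []byte("fop")
// prefixEnd([]byte("fo"))  == []byte("fp")
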

@ -71,7 +71,8 @@ func TestV3WatchFromCurrentRevision(t *testing.T) {
[]string{"fooLong"},
&pb.WatchRequest{RequestUnion: &pb.WatchRequest_CreateRequest{
CreateRequest: &pb.WatchCreateRequest{
Prefix: []byte("foo")}}},
Key: []byte("foo"),
RangeEnd: []byte("fop")}}},
[]*pb.WatchResponse{
{
@ -91,7 +92,8 @@ func TestV3WatchFromCurrentRevision(t *testing.T) {
[]string{"foo"},
&pb.WatchRequest{RequestUnion: &pb.WatchRequest_CreateRequest{
CreateRequest: &pb.WatchCreateRequest{
Prefix: []byte("helloworld")}}},
Key: []byte("helloworld"),
RangeEnd: []byte("helloworle")}}},
[]*pb.WatchResponse{},
},
@ -140,7 +142,8 @@ func TestV3WatchFromCurrentRevision(t *testing.T) {
[]string{"foo", "foo", "foo"},
&pb.WatchRequest{RequestUnion: &pb.WatchRequest_CreateRequest{
CreateRequest: &pb.WatchCreateRequest{
Prefix: []byte("foo")}}},
Key: []byte("foo"),
RangeEnd: []byte("fop")}}},
[]*pb.WatchResponse{
{
@ -203,6 +206,11 @@ func TestV3WatchFromCurrentRevision(t *testing.T) {
t.Errorf("#%d: did not create watchid, got +%v", i, cresp)
continue
}
if cresp.Canceled {
t.Errorf("#%d: canceled watcher on create", i, cresp)
continue
}
createdWatchId := cresp.WatchId
if cresp.Header == nil || cresp.Header.Revision != 1 {
t.Errorf("#%d: header revision got +%v, wanted revison 1", i, cresp)
@ -353,7 +361,7 @@ func TestV3WatchCurrentPutOverlap(t *testing.T) {
progress := make(map[int64]int64)
wreq := &pb.WatchRequest{RequestUnion: &pb.WatchRequest_CreateRequest{
CreateRequest: &pb.WatchCreateRequest{Prefix: []byte("foo")}}}
CreateRequest: &pb.WatchCreateRequest{Key: []byte("foo"), RangeEnd: []byte("fop")}}}
if err := wStream.Send(wreq); err != nil {
t.Fatalf("first watch request failed (%v)", err)
}
@ -437,7 +445,7 @@ func testV3WatchMultipleWatchers(t *testing.T, startRev int64) {
} else {
wreq = &pb.WatchRequest{RequestUnion: &pb.WatchRequest_CreateRequest{
CreateRequest: &pb.WatchCreateRequest{
Prefix: []byte("fo"), StartRevision: startRev}}}
Key: []byte("fo"), RangeEnd: []byte("fp"), StartRevision: startRev}}}
}
if err := wStream.Send(wreq); err != nil {
t.Fatalf("wStream.Send error: %v", err)
@ -530,7 +538,7 @@ func testV3WatchMultipleEventsTxn(t *testing.T, startRev int64) {
wreq := &pb.WatchRequest{RequestUnion: &pb.WatchRequest_CreateRequest{
CreateRequest: &pb.WatchCreateRequest{
Prefix: []byte("foo"), StartRevision: startRev}}}
Key: []byte("foo"), RangeEnd: []byte("fop"), StartRevision: startRev}}}
if err := wStream.Send(wreq); err != nil {
t.Fatalf("wStream.Send error: %v", err)
}
@ -623,7 +631,7 @@ func TestV3WatchMultipleEventsPutUnsynced(t *testing.T) {
wreq := &pb.WatchRequest{RequestUnion: &pb.WatchRequest_CreateRequest{
CreateRequest: &pb.WatchCreateRequest{
Prefix: []byte("foo"), StartRevision: 1}}}
Key: []byte("foo"), RangeEnd: []byte("fop"), StartRevision: 1}}}
if err := wStream.Send(wreq); err != nil {
t.Fatalf("wStream.Send error: %v", err)
}

pkg/adt/interval_tree.go Normal file

@ -0,0 +1,526 @@
// Copyright 2016 CoreOS, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package adt
import (
"math"
)
// Comparable is an interface for trichotomic comparisons.
type Comparable interface {
// Compare gives the result of a 3-way comparison
// a.Compare(b) = 1 => a > b
// a.Compare(b) = 0 => a == b
// a.Compare(b) = -1 => a < b
Compare(c Comparable) int
}
type rbcolor bool
const black = true
const red = false
// Interval implements a Comparable interval [begin, end)
// TODO: support different sorts of intervals: (a,b), [a,b], (a, b]
type Interval struct {
Begin Comparable
End Comparable
}
// Compare on an interval gives == if the intervals overlap.
func (ivl *Interval) Compare(c Comparable) int {
ivl2 := c.(*Interval)
ivbCmpBegin := ivl.Begin.Compare(ivl2.Begin)
ivbCmpEnd := ivl.Begin.Compare(ivl2.End)
iveCmpBegin := ivl.End.Compare(ivl2.Begin)
// ivl is left of ivl2
if ivbCmpBegin < 0 && iveCmpBegin <= 0 {
return -1
}
// ivl is right of ivl2
if ivbCmpEnd >= 0 {
return 1
}
return 0
}
type intervalNode struct {
// iv is the interval-value pair entry.
iv IntervalValue
// max endpoint of all descendant nodes.
max Comparable
// left and right are sorted by low endpoint of key interval
left, right *intervalNode
// parent is the direct ancestor of the node
parent *intervalNode
c rbcolor
}
func (x *intervalNode) color() rbcolor {
if x == nil {
return black
}
return x.c
}
func (n *intervalNode) height() int {
if n == nil {
return 0
}
ld := n.left.height()
rd := n.right.height()
if ld < rd {
return rd + 1
}
return ld + 1
}
func (x *intervalNode) min() *intervalNode {
for x.left != nil {
x = x.left
}
return x
}
// successor is the next in-order node in the tree
func (x *intervalNode) successor() *intervalNode {
if x.right != nil {
return x.right.min()
}
y := x.parent
for y != nil && x == y.right {
x = y
y = y.parent
}
return y
}
// updateMax updates the maximum values for a node and its ancestors
func (x *intervalNode) updateMax() {
for x != nil {
oldmax := x.max
max := x.iv.Ivl.End
if x.left != nil && x.left.max.Compare(max) > 0 {
max = x.left.max
}
if x.right != nil && x.right.max.Compare(max) > 0 {
max = x.right.max
}
if oldmax.Compare(max) == 0 {
break
}
x.max = max
x = x.parent
}
}
type nodeVisitor func(n *intervalNode) bool
// visit will call a node visitor on each node that overlaps the given interval
func (x *intervalNode) visit(iv *Interval, nv nodeVisitor) {
if x == nil {
return
}
v := iv.Compare(&x.iv.Ivl)
switch {
case v < 0:
x.left.visit(iv, nv)
case v > 0:
maxiv := Interval{x.iv.Ivl.Begin, x.max}
if maxiv.Compare(iv) == 0 {
x.left.visit(iv, nv)
x.right.visit(iv, nv)
}
default:
nv(x)
x.left.visit(iv, nv)
x.right.visit(iv, nv)
}
}
type IntervalValue struct {
Ivl Interval
Val interface{}
}
// IntervalTree represents a (mostly) textbook implementation of the
// "Introduction to Algorithms" (Cormen et al, 2nd ed.) chapter 13 red-black tree
// and chapter 14.3 interval tree with search supporting "stabbing queries".
type IntervalTree struct {
root *intervalNode
count int
}
// Delete removes the node with the given interval from the tree, returning
// true if a node is in fact removed.
func (ivt *IntervalTree) Delete(ivl Interval) bool {
z := ivt.find(ivl)
if z == nil {
return false
}
y := z
if z.left != nil && z.right != nil {
y = z.successor()
}
x := y.left
if x == nil {
x = y.right
}
if x != nil {
x.parent = y.parent
}
if y.parent == nil {
ivt.root = x
} else {
if y == y.parent.left {
y.parent.left = x
} else {
y.parent.right = x
}
y.parent.updateMax()
}
if y != z {
z.iv = y.iv
z.updateMax()
}
if y.color() == black && x != nil {
ivt.deleteFixup(x)
}
ivt.count--
return true
}
func (ivt *IntervalTree) deleteFixup(x *intervalNode) {
for x != ivt.root && x.color() == black && x.parent != nil {
if x == x.parent.left {
w := x.parent.right
if w.color() == red {
w.c = black
x.parent.c = red
ivt.rotateLeft(x.parent)
w = x.parent.right
}
if w == nil {
break
}
if w.left.color() == black && w.right.color() == black {
w.c = red
x = x.parent
} else {
if w.right.color() == black {
w.left.c = black
w.c = red
ivt.rotateRight(w)
w = x.parent.right
}
w.c = x.parent.color()
x.parent.c = black
w.right.c = black
ivt.rotateLeft(x.parent)
x = ivt.root
}
} else {
// same as above but with left and right exchanged
w := x.parent.left
if w.color() == red {
w.c = black
x.parent.c = red
ivt.rotateRight(x.parent)
w = x.parent.left
}
if w == nil {
break
}
if w.left.color() == black && w.right.color() == black {
w.c = red
x = x.parent
} else {
if w.left.color() == black {
w.right.c = black
w.c = red
ivt.rotateLeft(w)
w = x.parent.left
}
w.c = x.parent.color()
x.parent.c = black
w.left.c = black
ivt.rotateRight(x.parent)
x = ivt.root
}
}
}
if x != nil {
x.c = black
}
}
// Insert adds a node with the given interval into the tree.
func (ivt *IntervalTree) Insert(ivl Interval, val interface{}) {
var y *intervalNode
z := &intervalNode{iv: IntervalValue{ivl, val}, max: ivl.End, c: red}
x := ivt.root
for x != nil {
y = x
if z.iv.Ivl.Begin.Compare(x.iv.Ivl.Begin) < 0 {
x = x.left
} else {
x = x.right
}
}
z.parent = y
if y == nil {
ivt.root = z
} else {
if z.iv.Ivl.Begin.Compare(y.iv.Ivl.Begin) < 0 {
y.left = z
} else {
y.right = z
}
y.updateMax()
}
z.c = red
ivt.insertFixup(z)
ivt.count++
}
func (ivt *IntervalTree) insertFixup(z *intervalNode) {
for z.parent != nil && z.parent.parent != nil && z.parent.color() == red {
if z.parent == z.parent.parent.left {
y := z.parent.parent.right
if y.color() == red {
y.c = black
z.parent.c = black
z.parent.parent.c = red
z = z.parent.parent
} else {
if z == z.parent.right {
z = z.parent
ivt.rotateLeft(z)
}
z.parent.c = black
z.parent.parent.c = red
ivt.rotateRight(z.parent.parent)
}
} else {
// same as then with left/right exchanged
y := z.parent.parent.left
if y.color() == red {
y.c = black
z.parent.c = black
z.parent.parent.c = red
z = z.parent.parent
} else {
if z == z.parent.left {
z = z.parent
ivt.rotateRight(z)
}
z.parent.c = black
z.parent.parent.c = red
ivt.rotateLeft(z.parent.parent)
}
}
}
ivt.root.c = black
}
// rotateLeft moves x so it is left of its right child
func (ivt *IntervalTree) rotateLeft(x *intervalNode) {
y := x.right
x.right = y.left
if y.left != nil {
y.left.parent = x
}
x.updateMax()
ivt.replaceParent(x, y)
y.left = x
y.updateMax()
}
// rotateRight moves x so it is right of its left child
func (ivt *IntervalTree) rotateRight(x *intervalNode) {
if x == nil {
return
}
y := x.left
x.left = y.right
if y.right != nil {
y.right.parent = x
}
x.updateMax()
ivt.replaceParent(x, y)
y.right = x
y.updateMax()
}
// replaceParent replaces x's parent with y
func (ivt *IntervalTree) replaceParent(x *intervalNode, y *intervalNode) {
y.parent = x.parent
if x.parent == nil {
ivt.root = y
} else {
if x == x.parent.left {
x.parent.left = y
} else {
x.parent.right = y
}
x.parent.updateMax()
}
x.parent = y
}
// Len gives the number of elements in the tree
func (ivt *IntervalTree) Len() int { return ivt.count }
// Height is the number of levels in the tree; one node has height 1.
func (ivt *IntervalTree) Height() int { return ivt.root.height() }
// MaxHeight is the expected maximum tree height given the number of nodes
func (ivt *IntervalTree) MaxHeight() int {
return int((2 * math.Log2(float64(ivt.Len()+1))) + 0.5)
}
// IntervalVisitor is used on tree searches; return false to stop searching.
type IntervalVisitor func(n *IntervalValue) bool
// Visit calls a visitor function on every tree node intersecting the given interval.
func (ivt *IntervalTree) Visit(ivl Interval, ivv IntervalVisitor) {
ivt.root.visit(&ivl, func(n *intervalNode) bool { return ivv(&n.iv) })
}
// find the exact node for a given interval
func (ivt *IntervalTree) find(ivl Interval) (ret *intervalNode) {
f := func(n *intervalNode) bool {
if n.iv.Ivl != ivl {
return true
}
ret = n
return false
}
ivt.root.visit(&ivl, f)
return ret
}
// Find gets the IntervalValue for the node matching the given interval
func (ivt *IntervalTree) Find(ivl Interval) (ret *IntervalValue) {
n := ivt.find(ivl)
if n == nil {
return nil
}
return &n.iv
}
// Contains returns true if there is some tree node intersecting the given interval.
func (ivt *IntervalTree) Contains(iv Interval) bool {
x := ivt.root
for x != nil && iv.Compare(&x.iv.Ivl) != 0 {
if x.left != nil && x.left.max.Compare(iv.Begin) > 0 {
x = x.left
} else {
x = x.right
}
}
return x != nil
}
// Stab returns a slice with all elements in the tree intersecting the interval.
func (ivt *IntervalTree) Stab(iv Interval) (ivs []*IntervalValue) {
f := func(n *IntervalValue) bool { ivs = append(ivs, n); return true }
ivt.Visit(iv, f)
return ivs
}
type StringComparable string
func (s StringComparable) Compare(c Comparable) int {
sc := c.(StringComparable)
if s < sc {
return -1
}
if s > sc {
return 1
}
return 0
}
func NewStringInterval(begin, end string) Interval {
return Interval{StringComparable(begin), StringComparable(end)}
}
func NewStringPoint(s string) Interval {
return Interval{StringComparable(s), StringComparable(s + "\x00")}
}
// StringAffineComparable treats "" as > all other strings
type StringAffineComparable string
func (s StringAffineComparable) Compare(c Comparable) int {
sc := c.(StringAffineComparable)
if len(s) == 0 {
if len(sc) == 0 {
return 0
}
return 1
}
if len(sc) == 0 {
return -1
}
if s < sc {
return -1
}
if s > sc {
return 1
}
return 0
}
func NewStringAffineInterval(begin, end string) Interval {
return Interval{StringAffineComparable(begin), StringAffineComparable(end)}
}
func NewStringAffinePoint(s string) Interval {
return NewStringAffineInterval(s, s+"\x00")
}
func NewInt64Interval(a int64, b int64) Interval {
return Interval{Int64Comparable(a), Int64Comparable(b)}
}
func NewInt64Point(a int64) Interval {
return Interval{Int64Comparable(a), Int64Comparable(a + 1)}
}
type Int64Comparable int64
func (v Int64Comparable) Compare(c Comparable) int {
vc := c.(Int64Comparable)
cmp := v - vc
if cmp < 0 {
return -1
}
if cmp > 0 {
return 1
}
return 0
}

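A short usage sketch of the tree above, showing the stabbing queries the watcher code relies on. StringAffineComparable treats "" as larger than every key, so an interval ending in "" is unbounded; the values stored here are illustrative only.

package main

import (
	"fmt"

	"github.com/coreos/etcd/pkg/adt"
)

func main() {
	ivt := &adt.IntervalTree{}
	// a prefix watch on "foo" covers the half-open range ["foo", "fop")
	ivt.Insert(adt.NewStringAffineInterval("foo", "fop"), "prefix foo")
	// an end of "" sorts above every key, so ["foo", "") means every key >= "foo"
	ivt.Insert(adt.NewStringAffineInterval("foo", ""), ">= foo")

	fmt.Println(ivt.Contains(adt.NewStringAffinePoint("foobar"))) // true; both intervals cover it
	fmt.Println(len(ivt.Stab(adt.NewStringAffinePoint("fop"))))   // 1; only the unbounded interval
}
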

@ -0,0 +1,138 @@
// Copyright 2016 CoreOS, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package adt
import (
"math/rand"
"testing"
"time"
)
func TestIntervalTreeContains(t *testing.T) {
ivt := &IntervalTree{}
ivt.Insert(NewStringInterval("1", "3"), 123)
if ivt.Contains(NewStringPoint("0")) {
t.Errorf("contains 0")
}
if !ivt.Contains(NewStringPoint("1")) {
t.Errorf("missing 1")
}
if !ivt.Contains(NewStringPoint("11")) {
t.Errorf("missing 11")
}
if !ivt.Contains(NewStringPoint("2")) {
t.Errorf("missing 2")
}
if ivt.Contains(NewStringPoint("3")) {
t.Errorf("contains 3")
}
}
func TestIntervalTreeStringAffine(t *testing.T) {
ivt := &IntervalTree{}
ivt.Insert(NewStringAffineInterval("8", ""), 123)
if !ivt.Contains(NewStringAffinePoint("9")) {
t.Errorf("missing 9")
}
if ivt.Contains(NewStringAffinePoint("7")) {
t.Errorf("contains 7")
}
}
func TestIntervalTreeStab(t *testing.T) {
ivt := &IntervalTree{}
ivt.Insert(NewStringInterval("0", "1"), 123)
ivt.Insert(NewStringInterval("0", "2"), 456)
ivt.Insert(NewStringInterval("5", "6"), 789)
ivt.Insert(NewStringInterval("6", "8"), 999)
ivt.Insert(NewStringInterval("0", "3"), 0)
if ivt.root.max.Compare(StringComparable("8")) != 0 {
t.Fatalf("wrong root max got %v, expected 8", ivt.root.max)
}
if x := len(ivt.Stab(NewStringPoint("0"))); x != 3 {
t.Errorf("got %d, expected 3", x)
}
if x := len(ivt.Stab(NewStringPoint("1"))); x != 2 {
t.Errorf("got %d, expected 2", x)
}
if x := len(ivt.Stab(NewStringPoint("2"))); x != 1 {
t.Errorf("got %d, expected 1", x)
}
if x := len(ivt.Stab(NewStringPoint("3"))); x != 0 {
t.Errorf("got %d, expected 0", x)
}
if x := len(ivt.Stab(NewStringPoint("5"))); x != 1 {
t.Errorf("got %d, expected 1", x)
}
if x := len(ivt.Stab(NewStringPoint("55"))); x != 1 {
t.Errorf("got %d, expected 1", x)
}
if x := len(ivt.Stab(NewStringPoint("6"))); x != 1 {
t.Errorf("got %d, expected 1", x)
}
}
type xy struct {
x int64
y int64
}
func TestIntervalTreeRandom(t *testing.T) {
// generate unique intervals
ivs := make(map[xy]struct{})
ivt := &IntervalTree{}
maxv := 128
rand.Seed(time.Now().UnixNano())
for i := rand.Intn(maxv) + 1; i != 0; i-- {
x, y := int64(rand.Intn(maxv)), int64(rand.Intn(maxv))
if x > y {
x, y = y, x
} else if x == y {
y++
}
iv := xy{x, y}
if _, ok := ivs[iv]; ok {
// don't double insert
continue
}
ivt.Insert(NewInt64Interval(x, y), 123)
ivs[iv] = struct{}{}
}
for ab := range ivs {
for xy := range ivs {
v := xy.x + int64(rand.Intn(int(xy.y-xy.x)))
if slen := len(ivt.Stab(NewInt64Point(v))); slen == 0 {
t.Fatalf("expected %v stab non-zero for [%+v)", v, xy)
}
if !ivt.Contains(NewInt64Point(v)) {
t.Fatalf("did not get %d as expected for [%+v)", v, xy)
}
}
if !ivt.Delete(NewInt64Interval(ab.x, ab.y)) {
t.Errorf("did not delete %v as expected", ab)
}
delete(ivs, ab)
}
if ivt.Len() != 0 {
t.Errorf("got ivt.Len() = %v, expected 0", ivt.Len())
}
}


@ -722,13 +722,10 @@ func TestWatchableKVWatch(t *testing.T) {
w := s.NewWatchStream()
defer w.Close()
wid := w.Watch([]byte("foo"), true, 0)
wid := w.Watch([]byte("foo"), []byte("fop"), 0)
s.Put([]byte("foo"), []byte("bar"), 1)
select {
case resp := <-w.Chan():
wev := storagepb.Event{
Type: storagepb.PUT,
wev := []storagepb.Event{
{Type: storagepb.PUT,
Kv: &storagepb.KeyValue{
Key: []byte("foo"),
Value: []byte("bar"),
@ -737,23 +734,8 @@ func TestWatchableKVWatch(t *testing.T) {
Version: 1,
Lease: 1,
},
}
if resp.WatchID != wid {
t.Errorf("resp.WatchID got = %d, want = %d", resp.WatchID, wid)
}
ev := resp.Events[0]
if !reflect.DeepEqual(ev, wev) {
t.Errorf("watched event = %+v, want %+v", ev, wev)
}
case <-time.After(5 * time.Second):
// CPU might be too slow, and the routine is not able to switch around
testutil.FatalStack(t, "failed to watch the event")
}
s.Put([]byte("foo1"), []byte("bar1"), 2)
select {
case resp := <-w.Chan():
wev := storagepb.Event{
},
{
Type: storagepb.PUT,
Kv: &storagepb.KeyValue{
Key: []byte("foo1"),
@ -763,49 +745,8 @@ func TestWatchableKVWatch(t *testing.T) {
Version: 1,
Lease: 2,
},
}
if resp.WatchID != wid {
t.Errorf("resp.WatchID got = %d, want = %d", resp.WatchID, wid)
}
ev := resp.Events[0]
if !reflect.DeepEqual(ev, wev) {
t.Errorf("watched event = %+v, want %+v", ev, wev)
}
case <-time.After(5 * time.Second):
testutil.FatalStack(t, "failed to watch the event")
}
w = s.NewWatchStream()
wid = w.Watch([]byte("foo1"), false, 1)
select {
case resp := <-w.Chan():
wev := storagepb.Event{
Type: storagepb.PUT,
Kv: &storagepb.KeyValue{
Key: []byte("foo1"),
Value: []byte("bar1"),
CreateRevision: 3,
ModRevision: 3,
Version: 1,
Lease: 2,
},
}
if resp.WatchID != wid {
t.Errorf("resp.WatchID got = %d, want = %d", resp.WatchID, wid)
}
ev := resp.Events[0]
if !reflect.DeepEqual(ev, wev) {
t.Errorf("watched event = %+v, want %+v", ev, wev)
}
case <-time.After(5 * time.Second):
testutil.FatalStack(t, "failed to watch the event")
}
s.Put([]byte("foo1"), []byte("bar11"), 3)
select {
case resp := <-w.Chan():
wev := storagepb.Event{
},
{
Type: storagepb.PUT,
Kv: &storagepb.KeyValue{
Key: []byte("foo1"),
@ -815,13 +756,63 @@ func TestWatchableKVWatch(t *testing.T) {
Version: 2,
Lease: 3,
},
}
},
}
s.Put([]byte("foo"), []byte("bar"), 1)
select {
case resp := <-w.Chan():
if resp.WatchID != wid {
t.Errorf("resp.WatchID got = %d, want = %d", resp.WatchID, wid)
}
ev := resp.Events[0]
if !reflect.DeepEqual(ev, wev) {
t.Errorf("watched event = %+v, want %+v", ev, wev)
if !reflect.DeepEqual(ev, wev[0]) {
t.Errorf("watched event = %+v, want %+v", ev, wev[0])
}
case <-time.After(5 * time.Second):
// CPU might be too slow, and the routine is not able to switch around
testutil.FatalStack(t, "failed to watch the event")
}
s.Put([]byte("foo1"), []byte("bar1"), 2)
select {
case resp := <-w.Chan():
if resp.WatchID != wid {
t.Errorf("resp.WatchID got = %d, want = %d", resp.WatchID, wid)
}
ev := resp.Events[0]
if !reflect.DeepEqual(ev, wev[1]) {
t.Errorf("watched event = %+v, want %+v", ev, wev[1])
}
case <-time.After(5 * time.Second):
testutil.FatalStack(t, "failed to watch the event")
}
w = s.NewWatchStream()
wid = w.Watch([]byte("foo1"), []byte("foo2"), 3)
select {
case resp := <-w.Chan():
if resp.WatchID != wid {
t.Errorf("resp.WatchID got = %d, want = %d", resp.WatchID, wid)
}
ev := resp.Events[0]
if !reflect.DeepEqual(ev, wev[1]) {
t.Errorf("watched event = %+v, want %+v", ev, wev[1])
}
case <-time.After(5 * time.Second):
testutil.FatalStack(t, "failed to watch the event")
}
s.Put([]byte("foo1"), []byte("bar11"), 3)
select {
case resp := <-w.Chan():
if resp.WatchID != wid {
t.Errorf("resp.WatchID got = %d, want = %d", resp.WatchID, wid)
}
ev := resp.Events[0]
if !reflect.DeepEqual(ev, wev[2]) {
t.Errorf("watched event = %+v, want %+v", ev, wev[2])
}
case <-time.After(5 * time.Second):
testutil.FatalStack(t, "failed to watch the event")


@ -16,8 +16,6 @@ package storage
import (
"log"
"math"
"strings"
"sync"
"time"
@ -34,103 +32,8 @@ const (
chanBufLen = 1024
)
var (
// watchBatchMaxRevs is the maximum distinct revisions that
// may be sent to an unsynced watcher at a time. Declared as
// var instead of const for testing purposes.
watchBatchMaxRevs = 1000
)
type eventBatch struct {
// evs is a batch of revision-ordered events
evs []storagepb.Event
// revs is the minimum unique revisions observed for this batch
revs int
// moreRev is first revision with more events following this batch
moreRev int64
}
type (
watcherSetByKey map[string]watcherSet
watcherSet map[*watcher]struct{}
watcherBatch map[*watcher]*eventBatch
)
func (eb *eventBatch) add(ev storagepb.Event) {
if eb.revs > watchBatchMaxRevs {
// maxed out batch size
return
}
if len(eb.evs) == 0 {
// base case
eb.revs = 1
eb.evs = append(eb.evs, ev)
return
}
// revision accounting
ebRev := eb.evs[len(eb.evs)-1].Kv.ModRevision
evRev := ev.Kv.ModRevision
if evRev > ebRev {
eb.revs++
if eb.revs > watchBatchMaxRevs {
eb.moreRev = evRev
return
}
}
eb.evs = append(eb.evs, ev)
}
func (wb watcherBatch) add(w *watcher, ev storagepb.Event) {
eb := wb[w]
if eb == nil {
eb = &eventBatch{}
wb[w] = eb
}
eb.add(ev)
}
func (w watcherSet) add(wa *watcher) {
if _, ok := w[wa]; ok {
panic("add watcher twice!")
}
w[wa] = struct{}{}
}
func (w watcherSetByKey) add(wa *watcher) {
set := w[string(wa.key)]
if set == nil {
set = make(watcherSet)
w[string(wa.key)] = set
}
set.add(wa)
}
func (w watcherSetByKey) getSetByKey(key string) (watcherSet, bool) {
set, ok := w[key]
return set, ok
}
func (w watcherSetByKey) delete(wa *watcher) bool {
k := string(wa.key)
if v, ok := w[k]; ok {
if _, ok := v[wa]; ok {
delete(v, wa)
// if there is nothing in the set,
// remove the set
if len(v) == 0 {
delete(w, k)
}
return true
}
}
return false
}
type watchable interface {
watch(key []byte, prefix bool, startRev int64, id WatchID, ch chan<- WatchResponse) (*watcher, cancelFunc)
watch(key, end []byte, startRev int64, id WatchID, ch chan<- WatchResponse) (*watcher, cancelFunc)
rev() int64
}
@ -140,11 +43,11 @@ type watchableStore struct {
*store
// contains all unsynced watchers that need to sync with events that have happened
unsynced watcherSetByKey
unsynced watcherGroup
// contains all synced watchers that are in sync with the progress of the store.
// The key of the map is the key that the watcher watches on.
synced watcherSetByKey
synced watcherGroup
stopc chan struct{}
wg sync.WaitGroup
@ -157,8 +60,8 @@ type cancelFunc func()
func newWatchableStore(b backend.Backend, le lease.Lessor) *watchableStore {
s := &watchableStore{
store: NewStore(b, le),
unsynced: make(watcherSetByKey),
synced: make(watcherSetByKey),
unsynced: newWatcherGroup(),
synced: newWatcherGroup(),
stopc: make(chan struct{}),
}
if s.le != nil {
@ -268,16 +171,16 @@ func (s *watchableStore) NewWatchStream() WatchStream {
}
}
func (s *watchableStore) watch(key []byte, prefix bool, startRev int64, id WatchID, ch chan<- WatchResponse) (*watcher, cancelFunc) {
func (s *watchableStore) watch(key, end []byte, startRev int64, id WatchID, ch chan<- WatchResponse) (*watcher, cancelFunc) {
s.mu.Lock()
defer s.mu.Unlock()
wa := &watcher{
key: key,
prefix: prefix,
cur: startRev,
id: id,
ch: ch,
key: key,
end: end,
cur: startRev,
id: id,
ch: ch,
}
s.store.mu.Lock()
@ -342,15 +245,16 @@ func (s *watchableStore) syncWatchers() {
s.store.mu.Lock()
defer s.store.mu.Unlock()
if len(s.unsynced) == 0 {
if s.unsynced.size() == 0 {
return
}
// in order to find key-value pairs from unsynced watchers, we need to
// find min revision index, and these revisions can be used to
// query the backend store of key-value pairs
prefixes, minRev := s.scanUnsync()
curRev := s.store.currentRev.main
compactionRev := s.store.compactMainRev
minRev := s.unsynced.scanMinRev(curRev, compactionRev)
minBytes, maxBytes := newRevBytes(), newRevBytes()
revToBytes(revision{main: minRev}, minBytes)
revToBytes(revision{main: curRev + 1}, maxBytes)
@ -360,10 +264,10 @@ func (s *watchableStore) syncWatchers() {
tx := s.store.b.BatchTx()
tx.Lock()
revs, vs := tx.UnsafeRange(keyBucketName, minBytes, maxBytes, 0)
evs := kvsToEvents(revs, vs, s.unsynced, prefixes)
evs := kvsToEvents(&s.unsynced, revs, vs)
tx.Unlock()
for w, eb := range newWatcherBatch(s.unsynced, evs) {
for w, eb := range newWatcherBatch(&s.unsynced, evs) {
select {
// s.store.Rev also uses Lock, so just return directly
case w.ch <- WatchResponse{WatchID: w.id, Events: eb.evs, Revision: s.store.currentRev.main}:
@ -383,56 +287,18 @@ func (s *watchableStore) syncWatchers() {
s.unsynced.delete(w)
}
slowWatcherGauge.Set(float64(len(s.unsynced)))
}
func (s *watchableStore) scanUnsync() (prefixes map[string]struct{}, minRev int64) {
curRev := s.store.currentRev.main
compactionRev := s.store.compactMainRev
prefixes = make(map[string]struct{})
minRev = int64(math.MaxInt64)
for _, set := range s.unsynced {
for w := range set {
k := string(w.key)
if w.cur > curRev {
panic("watcher current revision should not exceed current revision")
}
if w.cur < compactionRev {
select {
case w.ch <- WatchResponse{WatchID: w.id, CompactRevision: compactionRev}:
s.unsynced.delete(w)
default:
// retry next time
}
continue
}
if minRev > w.cur {
minRev = w.cur
}
if w.prefix {
prefixes[k] = struct{}{}
}
}
}
return prefixes, minRev
slowWatcherGauge.Set(float64(s.unsynced.size()))
}
// kvsToEvents gets all events for the watchers from all key-value pairs
func kvsToEvents(revs, vals [][]byte, wsk watcherSetByKey, pfxs map[string]struct{}) (evs []storagepb.Event) {
func kvsToEvents(wg *watcherGroup, revs, vals [][]byte) (evs []storagepb.Event) {
for i, v := range vals {
var kv storagepb.KeyValue
if err := kv.Unmarshal(v); err != nil {
log.Panicf("storage: cannot unmarshal event: %v", err)
}
k := string(kv.Key)
if _, ok := wsk.getSetByKey(k); !ok && !matchPrefix(k, pfxs) {
if !wg.contains(string(kv.Key)) {
continue
}
@ -450,26 +316,19 @@ func kvsToEvents(revs, vals [][]byte, wsk watcherSetByKey, pfxs map[string]struc
// notify notifies the fact that given event at the given rev just happened to
// watchers that watch on the key of the event.
func (s *watchableStore) notify(rev int64, evs []storagepb.Event) {
we := newWatcherBatch(s.synced, evs)
for _, wm := range s.synced {
for w := range wm {
eb, ok := we[w]
if !ok {
continue
}
if eb.revs != 1 {
panic("unexpected multiple revisions in notification")
}
select {
case w.ch <- WatchResponse{WatchID: w.id, Events: eb.evs, Revision: s.Rev()}:
pendingEventsGauge.Add(float64(len(eb.evs)))
default:
// move slow watcher to unsynced
w.cur = rev
s.unsynced.add(w)
delete(wm, w)
slowWatcherGauge.Inc()
}
for w, eb := range newWatcherBatch(&s.synced, evs) {
if eb.revs != 1 {
panic("unexpected multiple revisions in notification")
}
select {
case w.ch <- WatchResponse{WatchID: w.id, Events: eb.evs, Revision: s.Rev()}:
pendingEventsGauge.Add(float64(len(eb.evs)))
default:
// move slow watcher to unsynced
w.cur = rev
s.unsynced.add(w)
s.synced.delete(w)
slowWatcherGauge.Inc()
}
}
}
@ -479,9 +338,9 @@ func (s *watchableStore) rev() int64 { return s.store.Rev() }
type watcher struct {
// the watcher key
key []byte
// prefix indicates if watcher is on a key or a prefix.
// If prefix is true, the watcher is on a prefix.
prefix bool
// end indicates the end of the range to watch.
// If end is set, the watcher is on a range.
end []byte
// cur is the current watcher revision.
// If cur is behind the current revision of the KV,
// watcher is unsynced and needs to catch up.
@ -492,42 +351,3 @@ type watcher struct {
// The chan might be shared with other watchers.
ch chan<- WatchResponse
}
// newWatcherBatch maps watchers to their matched events. It enables quick
// events look up by watcher.
func newWatcherBatch(sm watcherSetByKey, evs []storagepb.Event) watcherBatch {
wb := make(watcherBatch)
for _, ev := range evs {
key := string(ev.Kv.Key)
// check all prefixes of the key to notify all corresponded watchers
for i := 0; i <= len(key); i++ {
for w := range sm[key[:i]] {
// don't double notify
if ev.Kv.ModRevision < w.cur {
continue
}
// the watcher needs to be notified when either it watches prefix or
// the key is exactly matched.
if !w.prefix && i != len(ev.Kv.Key) {
continue
}
wb.add(w, ev)
}
}
}
return wb
}
// matchPrefix returns true if key has any matching prefix
// from prefixes map.
func matchPrefix(key string, prefixes map[string]struct{}) bool {
for p := range prefixes {
if strings.HasPrefix(key, p) {
return true
}
}
return false
}


@ -40,11 +40,11 @@ func BenchmarkWatchableStoreUnsyncedCancel(b *testing.B) {
// in unsynced for this benchmark.
ws := &watchableStore{
store: s,
unsynced: make(watcherSetByKey),
unsynced: newWatcherGroup(),
// to make the test not crash from assigning to nil map.
// 'synced' doesn't get populated in this test.
synced: make(watcherSetByKey),
synced: newWatcherGroup(),
}
defer func() {
@ -69,7 +69,7 @@ func BenchmarkWatchableStoreUnsyncedCancel(b *testing.B) {
watchIDs := make([]WatchID, watcherN)
for i := 0; i < watcherN; i++ {
// non-0 value to keep watchers in unsynced
watchIDs[i] = w.Watch(testKey, true, 1)
watchIDs[i] = w.Watch(testKey, nil, 1)
}
// random-cancel N watchers to make it not biased towards
@ -109,7 +109,7 @@ func BenchmarkWatchableStoreSyncedCancel(b *testing.B) {
watchIDs := make([]WatchID, watcherN)
for i := 0; i < watcherN; i++ {
// 0 for startRev to keep watchers in synced
watchIDs[i] = w.Watch(testKey, true, 0)
watchIDs[i] = w.Watch(testKey, nil, 0)
}
// randomly cancel watchers to make it not biased towards


@ -40,11 +40,11 @@ func TestWatch(t *testing.T) {
s.Put(testKey, testValue, lease.NoLease)
w := s.NewWatchStream()
w.Watch(testKey, true, 0)
w.Watch(testKey, nil, 0)
if _, ok := s.synced[string(testKey)]; !ok {
if !s.synced.contains(string(testKey)) {
// the key must have had an entry in synced
t.Errorf("existence = %v, want true", ok)
t.Errorf("existence = false, want true")
}
}
@ -61,15 +61,15 @@ func TestNewWatcherCancel(t *testing.T) {
s.Put(testKey, testValue, lease.NoLease)
w := s.NewWatchStream()
wt := w.Watch(testKey, true, 0)
wt := w.Watch(testKey, nil, 0)
if err := w.Cancel(wt); err != nil {
t.Error(err)
}
if _, ok := s.synced[string(testKey)]; ok {
if s.synced.contains(string(testKey)) {
// the key should have been deleted
t.Errorf("existence = %v, want false", ok)
t.Errorf("existence = true, want false")
}
}
@ -83,11 +83,11 @@ func TestCancelUnsynced(t *testing.T) {
// in unsynced to test if syncWatchers works as expected.
s := &watchableStore{
store: NewStore(b, &lease.FakeLessor{}),
unsynced: make(watcherSetByKey),
unsynced: newWatcherGroup(),
// to make the test not crash from assigning to nil map.
// 'synced' doesn't get populated in this test.
synced: make(watcherSetByKey),
synced: newWatcherGroup(),
}
defer func() {
@ -112,7 +112,7 @@ func TestCancelUnsynced(t *testing.T) {
watchIDs := make([]WatchID, watcherN)
for i := 0; i < watcherN; i++ {
// use 1 to keep watchers in unsynced
watchIDs[i] = w.Watch(testKey, true, 1)
watchIDs[i] = w.Watch(testKey, nil, 1)
}
for _, idx := range watchIDs {
@ -125,8 +125,8 @@ func TestCancelUnsynced(t *testing.T) {
//
// unsynced should be empty
// because cancel removes watcher from unsynced
if len(s.unsynced) != 0 {
t.Errorf("unsynced size = %d, want 0", len(s.unsynced))
if size := s.unsynced.size(); size != 0 {
t.Errorf("unsynced size = %d, want 0", size)
}
}
@ -138,8 +138,8 @@ func TestSyncWatchers(t *testing.T) {
s := &watchableStore{
store: NewStore(b, &lease.FakeLessor{}),
unsynced: make(watcherSetByKey),
synced: make(watcherSetByKey),
unsynced: newWatcherGroup(),
synced: newWatcherGroup(),
}
defer func() {
@ -158,13 +158,13 @@ func TestSyncWatchers(t *testing.T) {
for i := 0; i < watcherN; i++ {
// specify rev as 1 to keep watchers in unsynced
w.Watch(testKey, true, 1)
w.Watch(testKey, nil, 1)
}
// Before running s.syncWatchers() synced should be empty because we manually
// populate unsynced only
sws, _ := s.synced.getSetByKey(string(testKey))
uws, _ := s.unsynced.getSetByKey(string(testKey))
sws := s.synced.watcherSetByKey(string(testKey))
uws := s.unsynced.watcherSetByKey(string(testKey))
if len(sws) != 0 {
t.Fatalf("synced[string(testKey)] size = %d, want 0", len(sws))
@ -177,8 +177,8 @@ func TestSyncWatchers(t *testing.T) {
// this should move all unsynced watchers to synced ones
s.syncWatchers()
sws, _ = s.synced.getSetByKey(string(testKey))
uws, _ = s.unsynced.getSetByKey(string(testKey))
sws = s.synced.watcherSetByKey(string(testKey))
uws = s.unsynced.watcherSetByKey(string(testKey))
// After running s.syncWatchers(), synced should not be empty because syncwatchers
// populates synced in this test case
@ -240,7 +240,7 @@ func TestWatchCompacted(t *testing.T) {
}
w := s.NewWatchStream()
wt := w.Watch(testKey, true, compactRev-1)
wt := w.Watch(testKey, nil, compactRev-1)
select {
case resp := <-w.Chan():
@ -275,7 +275,7 @@ func TestWatchBatchUnsynced(t *testing.T) {
}
w := s.NewWatchStream()
w.Watch(v, false, 1)
w.Watch(v, nil, 1)
for i := 0; i < batches; i++ {
if resp := <-w.Chan(); len(resp.Events) != watchBatchMaxRevs {
t.Fatalf("len(events) = %d, want %d", len(resp.Events), watchBatchMaxRevs)
@ -284,8 +284,8 @@ func TestWatchBatchUnsynced(t *testing.T) {
s.store.mu.Lock()
defer s.store.mu.Unlock()
if len(s.synced) != 1 {
t.Errorf("synced size = %d, want 1", len(s.synced))
if size := s.synced.size(); size != 1 {
t.Errorf("synced size = %d, want 1", size)
}
}
@ -311,14 +311,14 @@ func TestNewMapwatcherToEventMap(t *testing.T) {
}
tests := []struct {
sync watcherSetByKey
sync []*watcher
evs []storagepb.Event
wwe map[*watcher][]storagepb.Event
}{
// no watcher in sync, some events should return empty wwe
{
watcherSetByKey{},
nil,
evs,
map[*watcher][]storagepb.Event{},
},
@ -326,9 +326,7 @@ func TestNewMapwatcherToEventMap(t *testing.T) {
// one watcher in sync, one event that does not match the key of that
// watcher should return empty wwe
{
watcherSetByKey{
string(k2): {ws[2]: struct{}{}},
},
[]*watcher{ws[2]},
evs[:1],
map[*watcher][]storagepb.Event{},
},
@ -336,9 +334,7 @@ func TestNewMapwatcherToEventMap(t *testing.T) {
// one watcher in sync, one event that matches the key of that
// watcher should return wwe with that matching watcher
{
watcherSetByKey{
string(k1): {ws[1]: struct{}{}},
},
[]*watcher{ws[1]},
evs[1:2],
map[*watcher][]storagepb.Event{
ws[1]: evs[1:2],
@ -349,10 +345,7 @@ func TestNewMapwatcherToEventMap(t *testing.T) {
// that matches the key of only one of the watcher should return wwe
// with the matching watcher
{
watcherSetByKey{
string(k0): {ws[0]: struct{}{}},
string(k2): {ws[2]: struct{}{}},
},
[]*watcher{ws[0], ws[2]},
evs[2:],
map[*watcher][]storagepb.Event{
ws[2]: evs[2:],
@ -362,10 +355,7 @@ func TestNewMapwatcherToEventMap(t *testing.T) {
// two watchers in sync that watches the same key, two events that
// match the keys should return wwe with those two watchers
{
watcherSetByKey{
string(k0): {ws[0]: struct{}{}},
string(k1): {ws[1]: struct{}{}},
},
[]*watcher{ws[0], ws[1]},
evs[:2],
map[*watcher][]storagepb.Event{
ws[0]: evs[:1],
@ -375,7 +365,12 @@ func TestNewMapwatcherToEventMap(t *testing.T) {
}
for i, tt := range tests {
gwe := newWatcherBatch(tt.sync, tt.evs)
wg := newWatcherGroup()
for _, w := range tt.sync {
wg.add(w)
}
gwe := newWatcherBatch(&wg, tt.evs)
if len(gwe) != len(tt.wwe) {
t.Errorf("#%d: len(gwe) got = %d, want = %d", i, len(gwe), len(tt.wwe))
}


@ -29,16 +29,15 @@ type WatchID int64
type WatchStream interface {
// Watch creates a watcher. The watcher watches the events happening or
// happened on the given key or key prefix from the given startRev.
// happened on the given key or range [key, end) from the given startRev.
//
// The whole event history can be watched unless compacted.
// If `prefix` is true, watch observes all events whose key prefix could be the given `key`.
// If `startRev` <=0, watch observes events after currentRev.
//
// The returned `id` is the ID of this watcher. It appears as WatchID
// in events that are sent to the created watcher through stream channel.
//
Watch(key []byte, prefix bool, startRev int64) WatchID
Watch(key, end []byte, startRev int64) WatchID
// Chan returns a chan. All watch response will be sent to the returned chan.
Chan() <-chan WatchResponse
@ -87,7 +86,7 @@ type watchStream struct {
// Watch creates a new watcher in the stream and returns its WatchID.
// TODO: return error if ws is closed?
func (ws *watchStream) Watch(key []byte, prefix bool, startRev int64) WatchID {
func (ws *watchStream) Watch(key, end []byte, startRev int64) WatchID {
ws.mu.Lock()
defer ws.mu.Unlock()
if ws.closed {
@ -97,7 +96,7 @@ func (ws *watchStream) Watch(key []byte, prefix bool, startRev int64) WatchID {
id := ws.nextID
ws.nextID++
_, c := ws.watchable.watch(key, prefix, startRev, id, ws.ch)
_, c := ws.watchable.watch(key, end, startRev, id, ws.ch)
ws.cancels[id] = c
return id

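A sketch of the updated WatchStream interface above; the helper name and package wrapper are assumptions, while the Watch/Chan/Close calls and the "fop" range end come from the surrounding tests. A nil end watches a single key; a non-nil end watches the range [key, end).

package example

import "github.com/coreos/etcd/storage"

// watchFooSketch is a hypothetical helper; w would come from a watchable
// store's NewWatchStream(), as in the tests around this change.
func watchFooSketch(w storage.WatchStream) {
	defer w.Close()

	keyID := w.Watch([]byte("foo"), nil, 0)             // watch only the key "foo"
	rangeID := w.Watch([]byte("foo"), []byte("fop"), 0) // watch every key in ["foo", "fop")

	for resp := range w.Chan() {
		switch resp.WatchID {
		case keyID:
			// an event on exactly "foo"
		case rangeID:
			// an event on some key with prefix "foo"
		}
	}
}
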

@ -33,6 +33,6 @@ func BenchmarkKVWatcherMemoryUsage(b *testing.B) {
b.ReportAllocs()
b.StartTimer()
for i := 0; i < b.N; i++ {
w.Watch([]byte(fmt.Sprint("foo", i)), false, 0)
w.Watch([]byte(fmt.Sprint("foo", i)), nil, 0)
}
}

storage/watcher_group.go Normal file

@ -0,0 +1,269 @@
// Copyright 2016 CoreOS, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"math"
"github.com/coreos/etcd/pkg/adt"
"github.com/coreos/etcd/storage/storagepb"
)
var (
// watchBatchMaxRevs is the maximum distinct revisions that
// may be sent to an unsynced watcher at a time. Declared as
// var instead of const for testing purposes.
watchBatchMaxRevs = 1000
)
type eventBatch struct {
// evs is a batch of revision-ordered events
evs []storagepb.Event
// revs is the minimum unique revisions observed for this batch
revs int
// moreRev is first revision with more events following this batch
moreRev int64
}
func (eb *eventBatch) add(ev storagepb.Event) {
if eb.revs > watchBatchMaxRevs {
// maxed out batch size
return
}
if len(eb.evs) == 0 {
// base case
eb.revs = 1
eb.evs = append(eb.evs, ev)
return
}
// revision accounting
ebRev := eb.evs[len(eb.evs)-1].Kv.ModRevision
evRev := ev.Kv.ModRevision
if evRev > ebRev {
eb.revs++
if eb.revs > watchBatchMaxRevs {
eb.moreRev = evRev
return
}
}
eb.evs = append(eb.evs, ev)
}
type watcherBatch map[*watcher]*eventBatch
func (wb watcherBatch) add(w *watcher, ev storagepb.Event) {
eb := wb[w]
if eb == nil {
eb = &eventBatch{}
wb[w] = eb
}
eb.add(ev)
}
// newWatcherBatch maps watchers to their matched events. It enables quick
// events look up by watcher.
func newWatcherBatch(wg *watcherGroup, evs []storagepb.Event) watcherBatch {
wb := make(watcherBatch)
for _, ev := range evs {
for w := range wg.watcherSetByKey(string(ev.Kv.Key)) {
if ev.Kv.ModRevision >= w.cur {
// don't double notify
wb.add(w, ev)
}
}
}
return wb
}
type watcherSet map[*watcher]struct{}
func (w watcherSet) add(wa *watcher) {
if _, ok := w[wa]; ok {
panic("add watcher twice!")
}
w[wa] = struct{}{}
}
func (w watcherSet) union(ws watcherSet) {
for wa := range ws {
w.add(wa)
}
}
func (w watcherSet) delete(wa *watcher) {
if _, ok := w[wa]; !ok {
panic("removing missing watcher!")
}
delete(w, wa)
}
type watcherSetByKey map[string]watcherSet
func (w watcherSetByKey) add(wa *watcher) {
set := w[string(wa.key)]
if set == nil {
set = make(watcherSet)
w[string(wa.key)] = set
}
set.add(wa)
}
func (w watcherSetByKey) delete(wa *watcher) bool {
k := string(wa.key)
if v, ok := w[k]; ok {
if _, ok := v[wa]; ok {
delete(v, wa)
if len(v) == 0 {
// remove the set; nothing left
delete(w, k)
}
return true
}
}
return false
}
type interval struct {
begin string
end string
}
type watcherSetByInterval map[interval]watcherSet
// watcherGroup is a collection of watchers organized by their ranges
type watcherGroup struct {
// keyWatchers has the watchers that watch on a single key
keyWatchers watcherSetByKey
// ranges has the watchers that watch a range; it is sorted by interval
ranges adt.IntervalTree
// watchers is the set of all watchers
watchers watcherSet
}
func newWatcherGroup() watcherGroup {
return watcherGroup{
keyWatchers: make(watcherSetByKey),
watchers: make(watcherSet),
}
}
// add puts a watcher in the group.
func (wg *watcherGroup) add(wa *watcher) {
wg.watchers.add(wa)
if wa.end == nil {
wg.keyWatchers.add(wa)
return
}
// interval already registered?
ivl := adt.NewStringAffineInterval(string(wa.key), string(wa.end))
if iv := wg.ranges.Find(ivl); iv != nil {
iv.Val.(watcherSet).add(wa)
return
}
// not registered, put in interval tree
ws := make(watcherSet)
ws.add(wa)
wg.ranges.Insert(ivl, ws)
}
// contains returns true if the given key has a watcher in the group.
func (wg *watcherGroup) contains(key string) bool {
_, ok := wg.keyWatchers[key]
return ok || wg.ranges.Contains(adt.NewStringAffinePoint(key))
}
// size gives the number of unique watchers in the group.
func (wg *watcherGroup) size() int { return len(wg.watchers) }
// delete removes a watcher from the group.
func (wg *watcherGroup) delete(wa *watcher) bool {
if _, ok := wg.watchers[wa]; !ok {
return false
}
wg.watchers.delete(wa)
if wa.end == nil {
wg.keyWatchers.delete(wa)
return true
}
ivl := adt.NewStringAffineInterval(string(wa.key), string(wa.end))
iv := wg.ranges.Find(ivl)
if iv == nil {
return false
}
ws := iv.Val.(watcherSet)
delete(ws, wa)
if len(ws) == 0 {
// remove the interval since it no longer has any watchers
if ok := wg.ranges.Delete(ivl); !ok {
panic("could not remove watcher from interval tree")
}
}
return true
}
func (wg *watcherGroup) scanMinRev(curRev int64, compactRev int64) int64 {
minRev := int64(math.MaxInt64)
for w := range wg.watchers {
if w.cur > curRev {
panic("watcher current revision should not exceed current revision")
}
if w.cur < compactRev {
select {
case w.ch <- WatchResponse{WatchID: w.id, CompactRevision: compactRev}:
wg.delete(w)
default:
// retry next time
}
continue
}
if minRev > w.cur {
minRev = w.cur
}
}
return minRev
}
// watcherSetByKey gets the set of watchers that receive events on the given key.
func (wg *watcherGroup) watcherSetByKey(key string) watcherSet {
wkeys := wg.keyWatchers[key]
wranges := wg.ranges.Stab(adt.NewStringAffinePoint(key))
// zero-copy cases
switch {
case len(wranges) == 0 && len(wkeys) == 0:
return nil
case len(wranges) == 0:
// no need to merge ranges or copy; reuse single-key set
return wkeys
case len(wranges) == 1 && len(wkeys) == 0:
return wranges[0].Val.(watcherSet)
}
// copy case
ret := make(watcherSet)
ret.union(wg.keyWatchers[key])
for _, item := range wranges {
ret.union(item.Val.(watcherSet))
}
return ret
}

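A small, hypothetical package-internal snippet showing how the group above resolves watchers for a key: single-key watchers live in keyWatchers, ranged watchers in the interval tree, and watcherSetByKey merges whichever cover the key. The function name and the concrete keys are assumptions.

// watcherGroupSketch is a hypothetical, package-internal (storage) snippet.
func watcherGroupSketch() watcherSet {
	wg := newWatcherGroup()

	wKey := &watcher{key: []byte("foo")}                   // single-key watcher
	wRange := &watcher{key: []byte("f"), end: []byte("g")} // range watcher over ["f", "g")
	wg.add(wKey)
	wg.add(wRange)

	// returns both watchers: "foo" is watched directly and also falls inside ["f", "g")
	return wg.watcherSetByKey("foo")
}
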

@ -35,7 +35,7 @@ func TestWatcherWatchID(t *testing.T) {
idm := make(map[WatchID]struct{})
for i := 0; i < 10; i++ {
id := w.Watch([]byte("foo"), false, 0)
id := w.Watch([]byte("foo"), nil, 0)
if _, ok := idm[id]; ok {
t.Errorf("#%d: id %d exists", i, id)
}
@ -57,7 +57,7 @@ func TestWatcherWatchID(t *testing.T) {
// unsynced watchers
for i := 10; i < 20; i++ {
id := w.Watch([]byte("foo2"), false, 1)
id := w.Watch([]byte("foo2"), nil, 1)
if _, ok := idm[id]; ok {
t.Errorf("#%d: id %d exists", i, id)
}
@ -86,12 +86,11 @@ func TestWatcherWatchPrefix(t *testing.T) {
idm := make(map[WatchID]struct{})
prefixMatch := true
val := []byte("bar")
keyWatch, keyPut := []byte("foo"), []byte("foobar")
keyWatch, keyEnd, keyPut := []byte("foo"), []byte("fop"), []byte("foobar")
for i := 0; i < 10; i++ {
id := w.Watch(keyWatch, prefixMatch, 0)
id := w.Watch(keyWatch, keyEnd, 0)
if _, ok := idm[id]; ok {
t.Errorf("#%d: unexpected duplicated id %x", i, id)
}
@ -118,12 +117,12 @@ func TestWatcherWatchPrefix(t *testing.T) {
}
}
keyWatch1, keyPut1 := []byte("foo1"), []byte("foo1bar")
keyWatch1, keyEnd1, keyPut1 := []byte("foo1"), []byte("foo2"), []byte("foo1bar")
s.Put(keyPut1, val, lease.NoLease)
// unsynced watchers
for i := 10; i < 15; i++ {
id := w.Watch(keyWatch1, prefixMatch, 1)
id := w.Watch(keyWatch1, keyEnd1, 1)
if _, ok := idm[id]; ok {
t.Errorf("#%d: id %d exists", i, id)
}
@ -159,7 +158,7 @@ func TestWatchStreamCancelWatcherByID(t *testing.T) {
w := s.NewWatchStream()
defer w.Close()
id := w.Watch([]byte("foo"), false, 0)
id := w.Watch([]byte("foo"), nil, 0)
tests := []struct {
cancelID WatchID