diff --git a/raft/rafttest/network.go b/raft/rafttest/network.go index 0da9e045c..5bd75ed6d 100644 --- a/raft/rafttest/network.go +++ b/raft/rafttest/network.go @@ -1,6 +1,7 @@ package rafttest import ( + "math/rand" "sync" "time" @@ -31,12 +32,18 @@ type network interface { type raftNetwork struct { mu sync.Mutex disconnected map[uint64]bool + dropmap map[conn]float64 recvQueues map[uint64]chan raftpb.Message } +type conn struct { + from, to uint64 +} + func newRaftNetwork(nodes ...uint64) *raftNetwork { pn := &raftNetwork{ recvQueues: make(map[uint64]chan raftpb.Message), + dropmap: make(map[conn]float64), disconnected: make(map[uint64]bool), } @@ -56,11 +63,16 @@ func (rn *raftNetwork) send(m raftpb.Message) { if rn.disconnected[m.To] { to = nil } + drop := rn.dropmap[conn{m.From, m.To}] rn.mu.Unlock() if to == nil { return } + if drop != 0 && rand.Float64() < drop { + return + } + to <- m } @@ -76,14 +88,20 @@ func (rn *raftNetwork) recvFrom(from uint64) chan raftpb.Message { } func (rn *raftNetwork) drop(from, to uint64, rate float64) { - panic("unimplemented") + rn.mu.Lock() + defer rn.mu.Unlock() + rn.dropmap[conn{from, to}] = rate } func (rn *raftNetwork) delay(from, to uint64, d time.Duration, rate float64) { panic("unimplemented") } -func (rn *raftNetwork) heal() {} +func (rn *raftNetwork) heal() { + rn.mu.Lock() + defer rn.mu.Unlock() + rn.dropmap = make(map[conn]float64) +} func (rn *raftNetwork) disconnect(id uint64) { rn.mu.Lock() diff --git a/raft/rafttest/network_test.go b/raft/rafttest/network_test.go new file mode 100644 index 000000000..3718ef283 --- /dev/null +++ b/raft/rafttest/network_test.go @@ -0,0 +1,36 @@ +package rafttest + +import ( + "testing" + + "github.com/coreos/etcd/raft/raftpb" +) + +func TestNetworkDrop(t *testing.T) { + // drop around 10% messages + sent := 1000 + droprate := 0.1 + nt := newRaftNetwork(1, 2) + nt.drop(1, 2, droprate) + for i := 0; i < sent; i++ { + nt.send(raftpb.Message{From: 1, To: 2}) + } + + c := nt.recvFrom(2) + + received := 0 + done := false + for !done { + select { + case <-c: + received++ + default: + done = true + } + } + + drop := sent - received + if drop > int((droprate+0.1)*float64(sent)) || drop < int((droprate-0.1)*float64(sent)) { + t.Errorf("drop = %d, want around %d", drop, droprate*float64(sent)) + } +}