diff --git a/raft/rafttest/network.go b/raft/rafttest/network.go index 006711ab3..d305a44b1 100644 --- a/raft/rafttest/network.go +++ b/raft/rafttest/network.go @@ -1,6 +1,7 @@ package rafttest import ( + "sync" "time" "github.com/coreos/etcd/raft/raftpb" @@ -14,15 +15,21 @@ type network interface { // delay message for (0, d] randomly at given rate (1.0 delay all messages) // do we need rate here? delay(from, to uint64, d time.Duration, rate float64) + + disconnect(id uint64) + connect(id uint64) } type raftNetwork struct { - recvQueues map[uint64]chan raftpb.Message + mu sync.Mutex + disconnected map[uint64]bool + recvQueues map[uint64]chan raftpb.Message } func newRaftNetwork(nodes ...uint64) *raftNetwork { pn := &raftNetwork{ - recvQueues: make(map[uint64]chan raftpb.Message, 0), + recvQueues: make(map[uint64]chan raftpb.Message), + disconnected: make(map[uint64]bool), } for _, n := range nodes { @@ -36,18 +43,27 @@ func (rn *raftNetwork) nodeNetwork(id uint64) *nodeNetwork { } func (rn *raftNetwork) send(m raftpb.Message) { + rn.mu.Lock() to := rn.recvQueues[m.To] + if rn.disconnected[m.To] { + to = nil + } + rn.mu.Unlock() + if to == nil { - panic("sent to nil") + return } to <- m } func (rn *raftNetwork) recvFrom(from uint64) chan raftpb.Message { + rn.mu.Lock() fromc := rn.recvQueues[from] - if fromc == nil { - panic("recv from nil") + if rn.disconnected[from] { + fromc = nil } + rn.mu.Unlock() + return fromc } @@ -59,6 +75,18 @@ func (rn *raftNetwork) delay(from, to uint64, d time.Duration, rate float64) { panic("unimplemented") } +func (rn *raftNetwork) disconnect(id uint64) { + rn.mu.Lock() + defer rn.mu.Unlock() + rn.disconnected[id] = true +} + +func (rn *raftNetwork) connect(id uint64) { + rn.mu.Lock() + defer rn.mu.Unlock() + rn.disconnected[id] = false +} + type nodeNetwork struct { id uint64 *raftNetwork diff --git a/raft/rafttest/node.go b/raft/rafttest/node.go index 56cd7411f..fdceb92f5 100644 --- a/raft/rafttest/node.go +++ b/raft/rafttest/node.go @@ -11,6 +11,7 @@ import ( type node struct { raft.Node + id uint64 paused bool nt network stopc chan struct{} @@ -25,12 +26,18 @@ func startNode(id uint64, peers []raft.Peer, nt network) *node { rn := raft.StartNode(id, peers, 10, 1, st) n := &node{ Node: rn, + id: id, storage: st, nt: nt, - stopc: make(chan struct{}), } + n.start() + return n +} +func (n *node) start() { + n.stopc = make(chan struct{}) ticker := time.Tick(5 * time.Millisecond) + go func() { for { select { @@ -39,32 +46,46 @@ func startNode(id uint64, peers []raft.Peer, nt network) *node { case rd := <-n.Ready(): if !raft.IsEmptyHardState(rd.HardState) { n.state = rd.HardState + n.storage.SetHardState(n.state) } n.storage.Append(rd.Entries) go func() { for _, m := range rd.Messages { - nt.send(m) + n.nt.send(m) } }() n.Advance() case m := <-n.nt.recv(): n.Step(context.TODO(), m) case <-n.stopc: - log.Printf("raft.%d: stop", id) + n.Stop() + log.Printf("raft.%d: stop", n.id) + n.Node = nil + close(n.stopc) return } } }() - return n } -func (n *node) stop() { close(n.stopc) } - -// restart restarts the node with the given delay. -// All in memory state of node is reset to initialized state. +// stop stops the node. stop a stopped node might panic. +// All in memory state of node is discarded. // All stable MUST be unchanged. -func (n *node) restart(delay time.Duration) { - panic("unimplemented") +func (n *node) stop() { + n.nt.disconnect(n.id) + n.stopc <- struct{}{} + // wait for the shutdown + <-n.stopc +} + +// restart restarts the node. restart a started node +// blocks and might affect the future stop operation. +func (n *node) restart() { + // wait for the shutdown + <-n.stopc + n.Node = raft.RestartNode(n.id, 10, 1, n.storage, 0) + n.start() + n.nt.connect(n.id) } // pause pauses the node. diff --git a/raft/rafttest/node_test.go b/raft/rafttest/node_test.go index a38865016..04e70ccff 100644 --- a/raft/rafttest/node_test.go +++ b/raft/rafttest/node_test.go @@ -27,8 +27,47 @@ func TestBasicProgress(t *testing.T) { time.Sleep(100 * time.Millisecond) for _, n := range nodes { n.stop() - if n.state.Commit < 1000 { - t.Errorf("commit = %d, want > 1000", n.state.Commit) + if n.state.Commit != 1006 { + t.Errorf("commit = %d, want = 1006", n.state.Commit) + } + } +} + +func TestRestart(t *testing.T) { + peers := []raft.Peer{{1, nil}, {2, nil}, {3, nil}, {4, nil}, {5, nil}} + nt := newRaftNetwork(1, 2, 3, 4, 5) + + nodes := make([]*node, 0) + + for i := 1; i <= 5; i++ { + n := startNode(uint64(i), peers, nt.nodeNetwork(uint64(i))) + nodes = append(nodes, n) + } + + time.Sleep(50 * time.Millisecond) + for i := 0; i < 300; i++ { + nodes[0].Propose(context.TODO(), []byte("somedata")) + } + nodes[1].stop() + for i := 0; i < 300; i++ { + nodes[0].Propose(context.TODO(), []byte("somedata")) + } + nodes[2].stop() + for i := 0; i < 300; i++ { + nodes[0].Propose(context.TODO(), []byte("somedata")) + } + nodes[2].restart() + for i := 0; i < 300; i++ { + nodes[0].Propose(context.TODO(), []byte("somedata")) + } + nodes[1].restart() + + // give some time for nodes to catch up with the raft leader + time.Sleep(300 * time.Millisecond) + for _, n := range nodes { + n.stop() + if n.state.Commit != 1206 { + t.Errorf("commit = %d, want = 1206", n.state.Commit) } } }