From 5651272ec86fdfd8e9adf3c80bc538390a4fe010 Mon Sep 17 00:00:00 2001 From: Xiang Li Date: Wed, 2 Jul 2014 12:49:58 -0700 Subject: [PATCH] raft: handle snapshot message --- raft/raft.go | 38 ++++++++++++++++++--- raft/raft_test.go | 87 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 120 insertions(+), 5 deletions(-) diff --git a/raft/raft.go b/raft/raft.go index b35fcb6cf..0699c4dc2 100644 --- a/raft/raft.go +++ b/raft/raft.go @@ -17,6 +17,7 @@ const ( msgAppResp msgVote msgVoteResp + msgSnap ) var mtmap = [...]string{ @@ -27,6 +28,7 @@ var mtmap = [...]string{ msgAppResp: "msgAppResp", msgVote: "msgVote", msgVoteResp: "msgVoteResp", + msgSnap: "msgSnap", } func (mt messageType) String() string { @@ -69,6 +71,7 @@ type Message struct { PrevTerm int Entries []Entry Commit int + Snapshot Snapshot } type index struct { @@ -151,12 +154,17 @@ func (sm *stateMachine) send(m Message) { func (sm *stateMachine) sendAppend(to int) { in := sm.ins[to] m := Message{} - m.Type = msgApp m.To = to m.Index = in.next - 1 - m.LogTerm = sm.log.term(in.next - 1) - m.Entries = sm.log.entries(in.next) - m.Commit = sm.log.committed + if sm.needSnapshot(m.Index) { + m.Type = msgSnap + m.Snapshot = sm.snapshoter.GetSnap() + } else { + m.Type = msgApp + m.LogTerm = sm.log.term(in.next - 1) + m.Entries = sm.log.entries(in.next) + m.Commit = sm.log.committed + } sm.send(m) } @@ -244,7 +252,7 @@ func (sm *stateMachine) becomeLeader() { sm.lead = sm.id sm.state = stateLeader - for _, e := range sm.log.ents[sm.log.committed:] { + for _, e := range sm.log.entries(sm.log.committed + 1) { if e.isConfig() { sm.pendingConf = true } @@ -298,6 +306,11 @@ func (sm *stateMachine) handleAppendEntries(m Message) { } } +func (sm *stateMachine) handleSnapshot(m Message) { + sm.restore(m.Snapshot) + sm.send(Message{To: m.From, Type: msgAppResp, Index: sm.log.lastIndex()}) +} + func (sm *stateMachine) addNode(id int) { sm.ins[id] = &index{next: sm.log.lastIndex() + 1} sm.pendingConf = false @@ -350,6 +363,9 @@ func stepCandidate(sm *stateMachine, m Message) bool { case msgApp: sm.becomeFollower(sm.term, m.From) sm.handleAppendEntries(m) + case msgSnap: + sm.becomeFollower(m.Term, m.From) + sm.handleSnapshot(m) case msgVote: sm.send(Message{To: m.From, Type: msgVoteResp, Index: -1}) case msgVoteResp: @@ -375,6 +391,8 @@ func stepFollower(sm *stateMachine, m Message) bool { sm.send(m) case msgApp: sm.handleAppendEntries(m) + case msgSnap: + sm.handleSnapshot(m) case msgVote: if (sm.vote == none || sm.vote == m.From) && sm.log.isUpToDate(m.Index, m.LogTerm) { sm.vote = m.From @@ -417,6 +435,16 @@ func (sm *stateMachine) restore(s Snapshot) { sm.snapshoter.Restore(s) } +func (sm *stateMachine) needSnapshot(i int) bool { + if i < sm.log.offset { + if sm.snapshoter == nil { + panic("need snapshot but snapshoter is nil") + } + return true + } + return false +} + func (sm *stateMachine) nodes() []int { nodes := make([]int, 0, len(sm.ins)) for k := range sm.ins { diff --git a/raft/raft_test.go b/raft/raft_test.go index dcc3d7715..0dadb091f 100644 --- a/raft/raft_test.go +++ b/raft/raft_test.go @@ -802,6 +802,92 @@ func TestRestore(t *testing.T) { } } +func TestProvideSnap(t *testing.T) { + s := Snapshot{ + Index: defaultCompactThreshold + 1, + Term: defaultCompactThreshold + 1, + Nodes: []int{0, 1}, + } + sm := newStateMachine(0, []int{0}) + sm.setSnapshoter(new(logSnapshoter)) + // restore the statemachin from a snapshot + // so it has a compacted log and a snapshot + sm.restore(s) + + sm.becomeCandidate() + sm.becomeLeader() + + sm.Step(Message{Type: msgBeat}) + msgs := sm.Msgs() + if len(msgs) != 1 { + t.Errorf("len(msgs) = %d, want 1", len(msgs)) + } + m := msgs[0] + if m.Type != msgApp { + t.Errorf("m.Type = %v, want %v", m.Type, msgApp) + } + + // force set the next of node 1, so that + // node 1 needs a snapshot + sm.ins[1].next = sm.log.offset + + sm.Step(Message{Type: msgBeat}) + msgs = sm.Msgs() + if len(msgs) != 1 { + t.Errorf("len(msgs) = %d, want 1", len(msgs)) + } + m = msgs[0] + if m.Type != msgSnap { + t.Errorf("m.Type = %v, want %v", m.Type, msgSnap) + } +} + +func TestRestoreFromSnapMsg(t *testing.T) { + s := Snapshot{ + Index: defaultCompactThreshold + 1, + Term: defaultCompactThreshold + 1, + Nodes: []int{0, 1}, + } + m := Message{Type: msgSnap, From: 0, Term: 1, Snapshot: s} + + sm := newStateMachine(1, []int{0, 1}) + sm.setSnapshoter(new(logSnapshoter)) + sm.Step(m) + + if !reflect.DeepEqual(sm.snapshoter.GetSnap(), s) { + t.Errorf("snapshot = %+v, want %+v", sm.snapshoter.GetSnap(), s) + } +} + +func TestSlowNodeRestore(t *testing.T) { + nt := newNetwork(nil, nil, nil) + nt.send(Message{To: 0, Type: msgHup}) + + nt.isolate(2) + for j := 0; j < defaultCompactThreshold+1; j++ { + nt.send(Message{To: 0, Type: msgProp, Entries: []Entry{{}}}) + } + lead := nt.peers[0].(*stateMachine) + lead.nextEnts() + if !lead.maybeCompact() { + t.Errorf("compacted = false, want true") + } + + nt.recover() + nt.send(Message{To: 0, Type: msgBeat}) + + follower := nt.peers[2].(*stateMachine) + if !reflect.DeepEqual(follower.snapshoter.GetSnap(), lead.snapshoter.GetSnap()) { + t.Errorf("follower.snap = %+v, want %+v", follower.snapshoter.GetSnap(), lead.snapshoter.GetSnap()) + } + + committed := follower.log.lastIndex() + nt.send(Message{To: 0, Type: msgProp, Entries: []Entry{{}}}) + if follower.log.committed != committed+1 { + t.Errorf("follower.comitted = %d, want %d", follower.log.committed, committed+1) + } +} + func ents(terms ...int) *stateMachine { ents := []Entry{{}} for _, term := range terms { @@ -836,6 +922,7 @@ func newNetwork(peers ...Interface) *network { switch v := p.(type) { case nil: sm := newStateMachine(id, defaultPeerAddrs) + sm.setSnapshoter(new(logSnapshoter)) npeers[id] = sm case *stateMachine: v.id = id