diff --git a/raft/node.go b/raft/node.go index bf7b2ef05..f2711f792 100644 --- a/raft/node.go +++ b/raft/node.go @@ -209,3 +209,12 @@ func (n *Node) UpdateConf(t int64, c *Config) { func (n *Node) UnstableEnts() []Entry { return n.sm.raftLog.unstableEnts() } + +func (n *Node) UnstableState() State { + if n.sm.unstableState == emptyState { + return emptyState + } + s := n.sm.unstableState + n.sm.clearState() + return s +} diff --git a/raft/raft.go b/raft/raft.go index f15553710..0502af4eb 100644 --- a/raft/raft.go +++ b/raft/raft.go @@ -66,6 +66,14 @@ func (st stateType) String() string { return stmap[int64(st)] } +type State struct { + Term int64 + Vote int64 + Commit int64 +} + +var emptyState = State{} + type Message struct { Type messageType ClusterId int64 @@ -151,6 +159,8 @@ type stateMachine struct { pendingConf bool snapshoter Snapshoter + + unstableState State } func newStateMachine(id int64, peers []int64) *stateMachine { @@ -273,9 +283,9 @@ func (sm *stateMachine) nextEnts() (ents []Entry) { } func (sm *stateMachine) reset(term int64) { - sm.term.Set(term) + sm.setTerm(term) sm.lead.Set(none) - sm.vote = none + sm.setVote(none) sm.votes = make(map[int64]bool) for i := range sm.ins { sm.ins[i] = &index{next: sm.raftLog.lastIndex() + 1} @@ -316,7 +326,7 @@ func (sm *stateMachine) becomeCandidate() { panic("invalid transition [leader -> candidate]") } sm.reset(sm.term.Get() + 1) - sm.vote = sm.id + sm.setVote(sm.id) sm.state = stateCandidate } @@ -399,12 +409,12 @@ func (sm *stateMachine) handleSnapshot(m Message) { } func (sm *stateMachine) addNode(id int64) { - sm.ins[id] = &index{next: sm.raftLog.lastIndex() + 1} + sm.addIns(id, 0, sm.raftLog.lastIndex()+1) sm.pendingConf = false } func (sm *stateMachine) removeNode(id int64) { - delete(sm.ins, id) + sm.deleteIns(id) sm.pendingConf = false } @@ -483,7 +493,7 @@ func stepFollower(sm *stateMachine, m Message) bool { sm.handleSnapshot(m) case msgVote: if (sm.vote == none || sm.vote == m.From) && sm.raftLog.isUpToDate(m.Index, m.LogTerm) { - sm.vote = m.From + sm.setVote(m.From) sm.send(Message{To: m.From, Type: msgVoteResp, Index: sm.raftLog.lastIndex()}) } else { sm.send(Message{To: m.From, Type: msgVoteResp, Index: -1}) @@ -515,9 +525,10 @@ func (sm *stateMachine) restore(s Snapshot) { sm.index.Set(sm.raftLog.lastIndex()) sm.ins = make(map[int64]*index) for _, n := range s.Nodes { - sm.ins[n] = &index{next: sm.raftLog.lastIndex() + 1} if n == sm.id { - sm.ins[n].match = sm.raftLog.lastIndex() + sm.addIns(n, sm.raftLog.lastIndex(), sm.raftLog.lastIndex()+1) + } else { + sm.addIns(n, 0, sm.raftLog.lastIndex()+1) } } sm.pendingConf = false @@ -541,3 +552,40 @@ func (sm *stateMachine) nodes() []int64 { } return nodes } + +func (sm *stateMachine) setTerm(term int64) { + sm.term.Set(term) + sm.saveState() +} + +func (sm *stateMachine) setVote(vote int64) { + sm.vote = vote + sm.saveState() +} + +func (sm *stateMachine) addIns(id, match, next int64) { + sm.ins[id] = &index{next: next, match: match} + sm.saveState() +} + +func (sm *stateMachine) deleteIns(id int64) { + delete(sm.ins, id) + sm.saveState() +} + +// saveState saves the state to sm.unstableState +// When there is a term change, vote change or configuration change, raft +// must call saveState. +func (sm *stateMachine) saveState() { + sm.setState(sm.vote, sm.term.Get(), sm.raftLog.committed) +} + +func (sm *stateMachine) clearState() { + sm.setState(0, 0, 0) +} + +func (sm *stateMachine) setState(vote, term, commit int64) { + sm.unstableState.Vote = vote + sm.unstableState.Term = term + sm.unstableState.Commit = commit +} diff --git a/raft/raft_test.go b/raft/raft_test.go index 99ddcc39c..f60947a38 100644 --- a/raft/raft_test.go +++ b/raft/raft_test.go @@ -954,6 +954,41 @@ func TestSlowNodeRestore(t *testing.T) { } } +func TestUnstableState(t *testing.T) { + sm := newStateMachine(0, []int64{0}) + w := State{} + + sm.setVote(1) + w.Vote = 1 + if !reflect.DeepEqual(sm.unstableState, w) { + t.Errorf("unstableState = %v, want %v", sm.unstableState, w) + } + sm.clearState() + + sm.setTerm(1) + w.Term = 1 + if !reflect.DeepEqual(sm.unstableState, w) { + t.Errorf("unstableState = %v, want %v", sm.unstableState, w) + } + sm.clearState() + + sm.raftLog.committed = 1 + sm.addIns(1, 0, 0) + w.Commit = 1 + if !reflect.DeepEqual(sm.unstableState, w) { + t.Errorf("unstableState = %v, want %v", sm.unstableState, w) + } + sm.clearState() + + sm.raftLog.committed = 2 + sm.deleteIns(1) + w.Commit = 2 + if !reflect.DeepEqual(sm.unstableState, w) { + t.Errorf("unstableState = %v, want %v", sm.unstableState, w) + } + sm.clearState() +} + func ents(terms ...int64) *stateMachine { ents := []Entry{{}} for _, term := range terms {