diff --git a/raft/node.go b/raft/node.go index 9ff89c699..ad1aa6509 100644 --- a/raft/node.go +++ b/raft/node.go @@ -13,7 +13,7 @@ type Node struct { func New(k, addr int, next Interface) *Node { n := &Node{ - sm: newStateMachine(k, addr, next), + sm: newStateMachine(k, addr), } return n } diff --git a/raft/raft.go b/raft/raft.go index bb21bee87..77bef9697 100644 --- a/raft/raft.go +++ b/raft/raft.go @@ -108,15 +108,15 @@ type stateMachine struct { votes map[int]bool - next Interface + msgs []Message // the leader addr lead int } -func newStateMachine(k, addr int, next Interface) *stateMachine { +func newStateMachine(k, addr int) *stateMachine { log := make([]Entry, 1, 1024) - sm := &stateMachine{k: k, addr: addr, next: next, log: log} + sm := &stateMachine{k: k, addr: addr, log: log} sm.reset() return sm } @@ -145,6 +145,14 @@ func (sm *stateMachine) append(after int, ents ...Entry) int { return len(sm.log) - 1 } +func (sm *stateMachine) maybeAppend(index, logTerm int, ents ...Entry) bool { + if sm.isLogOk(index, logTerm) { + sm.append(index, ents...) + return true + } + return false +} + func (sm *stateMachine) isLogOk(i, term int) bool { if i > sm.li() { return false @@ -152,11 +160,11 @@ func (sm *stateMachine) isLogOk(i, term int) bool { return sm.log[i].Term == term } -// send persists state to stable storage and then sends m over the network to m.To +// send persists state to stable storage and then sends to its mailbox func (sm *stateMachine) send(m Message) { m.From = sm.addr m.Term = sm.term - sm.next.Step(m) + sm.msgs = append(sm.msgs, m) } // sendAppend sends RRPC, with entries to all peers that are not up-to-date according to sm.mis. @@ -233,14 +241,39 @@ func (sm *stateMachine) becomeFollower(term, lead int) { sm.state = stateFollower } +func (sm *stateMachine) becomeCandidate() { + // TODO(xiangli) remove the panic when the raft implementation is stable + if sm.state == stateLeader { + panic("invalid transition [leader -> candidate]") + } + sm.reset() + sm.term++ + sm.vote = sm.addr + sm.state = stateCandidate + sm.poll(sm.addr, true) +} + +func (sm *stateMachine) becomeLeader() { + // TODO(xiangli) remove the panic when the raft implementation is stable + if sm.state == stateFollower { + panic("invalid transition [follower -> leader]") + } + sm.reset() + sm.lead = sm.addr + sm.state = stateLeader +} + +func (sm *stateMachine) Msgs() []Message { + msgs := sm.msgs + sm.msgs = make([]Message, 0) + + return msgs +} + func (sm *stateMachine) Step(m Message) { switch m.Type { case msgHup: - sm.term++ - sm.reset() - sm.state = stateCandidate - sm.vote = sm.addr - sm.poll(sm.addr, true) + sm.becomeCandidate() for i := 0; i < sm.k; i++ { if i == sm.addr { continue @@ -301,8 +334,7 @@ func (sm *stateMachine) Step(m Message) { gr := sm.poll(m.From, m.Index >= 0) switch sm.q() { case gr: - sm.state = stateLeader - sm.lead = sm.addr + sm.becomeLeader() sm.sendAppend() case len(sm.votes) - gr: sm.becomeFollower(sm.term, none) diff --git a/raft/raft_test.go b/raft/raft_test.go index f6e2dd46d..d0eb529dc 100644 --- a/raft/raft_test.go +++ b/raft/raft_test.go @@ -23,9 +23,9 @@ func TestLeaderElection(t *testing.T) { { newNetwork( nil, - &stateMachine{log: []Entry{{}, {Term: 1}}}, - &stateMachine{log: []Entry{{}, {Term: 2}}}, - &stateMachine{log: []Entry{{}, {Term: 1}, {Term: 3}}}, + &nsm{stateMachine{log: []Entry{{}, {Term: 1}}}, nil}, + &nsm{stateMachine{log: []Entry{{}, {Term: 2}}}, nil}, + &nsm{stateMachine{log: []Entry{{}, {Term: 1}, {Term: 3}}}, nil}, nil, ), stateFollower, @@ -34,10 +34,10 @@ func TestLeaderElection(t *testing.T) { // logs converge { newNetwork( - &stateMachine{log: []Entry{{}, {Term: 1}}}, + &nsm{stateMachine{log: []Entry{{}, {Term: 1}}}, nil}, nil, - &stateMachine{log: []Entry{{}, {Term: 2}}}, - &stateMachine{log: []Entry{{}, {Term: 1}}}, + &nsm{stateMachine{log: []Entry{{}, {Term: 2}}}, nil}, + &nsm{stateMachine{log: []Entry{{}, {Term: 1}}}, nil}, nil, ), stateLeader, @@ -46,7 +46,7 @@ func TestLeaderElection(t *testing.T) { for i, tt := range tests { tt.Step(Message{To: 0, Type: msgHup}) - sm := tt.network.ss[0].(*stateMachine) + sm := tt.network.ss[0].(*nsm) if sm.state != tt.state { t.Errorf("#%d: state = %s, want %s", i, sm.state, tt.state) } @@ -57,8 +57,8 @@ func TestLeaderElection(t *testing.T) { } func TestDualingCandidates(t *testing.T) { - a := &stateMachine{log: defaultLog} - c := &stateMachine{log: defaultLog} + a := &nsm{stateMachine{log: defaultLog}, nil} + c := &nsm{stateMachine{log: defaultLog}, nil} tt := newNetwork(a, nil, c) @@ -82,7 +82,7 @@ func TestDualingCandidates(t *testing.T) { tt.Step(Message{To: 2, Type: msgHup}) tests := []struct { - sm *stateMachine + sm *nsm state stateType term int }{ @@ -106,7 +106,7 @@ func TestDualingCandidates(t *testing.T) { } func TestCandidateConcede(t *testing.T) { - a := &stateMachine{log: defaultLog} + a := &nsm{stateMachine{log: defaultLog}, nil} tt := newNetwork(a, nil, nil) tt.tee = stepperFunc(func(m Message) { @@ -143,7 +143,7 @@ func TestOldMessages(t *testing.T) { tt := newNetwork(nil, nil, nil) // make 0 leader @ term 3 tt.Step(Message{To: 0, Type: msgHup}) - tt.Step(Message{To: 0, Type: msgHup}) + tt.Step(Message{To: 1, Type: msgHup}) tt.Step(Message{To: 0, Type: msgHup}) // pretend we're an old leader trying to make progress tt.Step(Message{To: 0, Type: msgApp, Term: 1, Entries: []Entry{{Term: 1}}}) @@ -204,7 +204,7 @@ func TestProposal(t *testing.T) { t.Errorf("#%d: diff:%s", i, diff) } } - sm := tt.network.ss[0].(*stateMachine) + sm := tt.network.ss[0].(*nsm) if g := sm.term; g != 1 { t.Errorf("#%d: term = %d, want %d", i, g, 1) } @@ -235,7 +235,7 @@ func TestProposalByProxy(t *testing.T) { t.Errorf("#%d: bad entry: %s", i, diff) } } - sm := tt.ss[0].(*stateMachine) + sm := tt.ss[0].(*nsm) if g := sm.term; g != 1 { t.Errorf("#%d: term = %d, want %d", i, g, 1) } @@ -305,7 +305,7 @@ func TestVote(t *testing.T) { for i, tt := range tests { called := false - sm := &stateMachine{log: []Entry{{}, {Term: 2}, {Term: 2}}} + sm := &nsm{stateMachine{log: []Entry{{}, {Term: 2}, {Term: 2}}}, nil} sm.next = stepperFunc(func(m Message) { called = true if m.Index != tt.w { @@ -319,6 +319,46 @@ func TestVote(t *testing.T) { } } +func TestAllServerStepdown(t *testing.T) { + tests := []stateType{stateFollower, stateCandidate, stateLeader} + + want := struct { + state stateType + term int + index int + }{stateFollower, 3, 1} + + tmsgTypes := [...]messageType{msgVote, msgApp} + tterm := 3 + + for i, tt := range tests { + sm := newStateMachine(3, 0) + switch tt { + case stateFollower: + sm.becomeFollower(1, 0) + case stateCandidate: + sm.becomeCandidate() + case stateLeader: + sm.becomeCandidate() + sm.becomeLeader() + } + + for j, msgType := range tmsgTypes { + sm.Step(Message{Type: msgType, Term: tterm, LogTerm: tterm}) + + if sm.state != want.state { + t.Errorf("#%d.%d state = %v , want %v", i, j, sm.state, want.state) + } + if sm.term != want.term { + t.Errorf("#%d.%d term = %v , want %v", i, j, sm.term, want.term) + } + if len(sm.log) != want.index { + t.Errorf("#%d.%d index = %v , want %v", i, j, len(sm.log), want.index) + } + } + } +} + func TestLogDiff(t *testing.T) { a := []Entry{{}, {Term: 1}, {Term: 2}} b := []Entry{{}, {Term: 1}, {Term: 2}} @@ -349,8 +389,8 @@ func newNetwork(nodes ...Interface) *network { for i, n := range nodes { switch v := n.(type) { case nil: - nt.ss[i] = newStateMachine(len(nodes), i, nt) - case *stateMachine: + nt.ss[i] = &nsm{*newStateMachine(len(nodes), i), nt} + case *nsm: v.k = len(nodes) v.addr = i if v.next == nil { @@ -375,7 +415,7 @@ func (nt network) Step(m Message) { func (nt network) logs() [][]Entry { ls := make([][]Entry, len(nt.ss)) for i, node := range nt.ss { - if sm, ok := node.(*stateMachine); ok { + if sm, ok := node.(*nsm); ok { ls[i] = sm.log } } @@ -462,3 +502,16 @@ type stepperFunc func(Message) func (f stepperFunc) Step(m Message) { f(m) } var nopStepper = stepperFunc(func(Message) {}) + +type nsm struct { + stateMachine + next Interface +} + +func (n *nsm) Step(m Message) { + (&n.stateMachine).Step(m) + ms := n.Msgs() + for _, m := range ms { + n.next.Step(m) + } +}