diff --git a/raft.go b/raft.go index d7592e918..e185bebff 100644 --- a/raft.go +++ b/raft.go @@ -34,7 +34,7 @@ func (mt messageType) String() string { var errNoLeader = errors.New("no leader") const ( - stateFollower = iota + stateFollower stateType = iota stateCandidate stateLeader ) @@ -244,6 +244,7 @@ func (sm *stateMachine) step(m Message) { sm.term++ sm.reset() sm.state = stateCandidate + sm.vote = sm.addr sm.poll(sm.addr, true) for i := 0; i < sm.k; i++ { if i == sm.addr { diff --git a/raft_test.go b/raft_test.go index aac248d16..971ccf58e 100644 --- a/raft_test.go +++ b/raft_test.go @@ -10,7 +10,7 @@ var defaultLog = []Entry{{}} func TestLeaderElection(t *testing.T) { tests := []struct { - network + *network state stateType }{ {newNetwork(nil, nil, nil), stateLeader}, @@ -22,7 +22,7 @@ func TestLeaderElection(t *testing.T) { for i, tt := range tests { tt.step(Message{To: 0, Type: msgHup}) - sm := tt.network[0].(*stateMachine) + sm := tt.network.ss[0].(*stateMachine) if sm.state != tt.state { t.Errorf("#%d: state = %s, want %s", i, sm.state, tt.state) } @@ -32,12 +32,44 @@ func TestLeaderElection(t *testing.T) { } } +func TestDualingCandidates(t *testing.T) { + a := &stateMachine{ + log: []Entry{{}}, + next: nopStepper, // field next is nil (partitioned) + } + c := &stateMachine{ + log: []Entry{{}}, + next: nopStepper, // field next is nil (partitioned) + } + tt := newNetwork(a, nil, c) + tt.tee = stepperFunc(func(m Message) { + t.Logf("m = %+v", m) + }) + tt.step(Message{To: 0, Type: msgHup}) + tt.step(Message{To: 2, Type: msgHup}) + + t.Log("healing") + tt.heal() + tt.step(Message{To: 2, Type: msgHup}) + if c.state != stateLeader { + t.Errorf("state = %s, want %s", c.state, stateLeader) + } + if g := c.term; g != 2 { + t.Errorf("term = %d, want %d", g, 2) + } + if g := diffLogs(tt.logs(defaultLog)); g != nil { + for _, diff := range g { + t.Errorf("bag log:\n%s", diff) + } + } +} + func TestProposal(t *testing.T) { data := []byte("somedata") successLog := []Entry{{}, {Term: 1, Data: data}} tests := []struct { - network + *network log []Entry willpanic bool }{ @@ -73,7 +105,7 @@ func TestProposal(t *testing.T) { t.Errorf("#%d: bag log:\n%s", i, diff) } } - sm := tt.network[0].(*stateMachine) + sm := tt.network.ss[0].(*stateMachine) if g := sm.term; g != 1 { t.Errorf("#%d: term = %d, want %d", i, g, 1) } @@ -85,7 +117,7 @@ func TestProposalByProxy(t *testing.T) { successLog := []Entry{{}, {Term: 1, Data: data}} tests := []struct { - network + *network log []Entry }{ {newNetwork(nil, nil, nil), successLog}, @@ -93,59 +125,71 @@ func TestProposalByProxy(t *testing.T) { } for i, tt := range tests { - step := stepperFunc(func(m Message) { + tt.tee = stepperFunc(func(m Message) { t.Logf("#%d: m = %+v", i, m) - tt.step(m) }) // promote 0 the leader - step(Message{To: 0, Type: msgHup}) + tt.step(Message{To: 0, Type: msgHup}) // propose via follower - step(Message{To: 1, Type: msgProp, Data: []byte("somedata")}) + tt.step(Message{To: 1, Type: msgProp, Data: []byte("somedata")}) if g := diffLogs(tt.logs(tt.log)); g != nil { for _, diff := range g { t.Errorf("#%d: bag log:\n%s", i, diff) } } - sm := tt.network[0].(*stateMachine) + sm := tt.network.ss[0].(*stateMachine) if g := sm.term; g != 1 { t.Errorf("#%d: term = %d, want %d", i, g, 1) } } } -type network []stepper +type network struct { + tee stepper + ss []stepper +} // newNetwork initializes a network from nodes. A nil node will be replaced // with a new *stateMachine. A *stateMachine will get its k, addr, and next // fields set. -func newNetwork(nodes ...stepper) network { - nt := network(nodes) +func newNetwork(nodes ...stepper) *network { + nt := &network{ss: nodes} for i, n := range nodes { switch v := n.(type) { case nil: - nt[i] = newStateMachine(len(nodes), i, &nt) + nt.ss[i] = newStateMachine(len(nodes), i, nt) case *stateMachine: v.k = len(nodes) v.addr = i - v.next = &nt } } return nt } func (nt network) step(m Message) { - nt[m.To].step(m) + if nt.tee != nil { + nt.tee.step(m) + } + nt.ss[m.To].step(m) +} + +func (nt network) heal() { + for _, s := range nt.ss { + if sm, ok := s.(*stateMachine); ok { + sm.next = nt + } + } } // logs returns all logs in nt prepended with want. If a node is not a // *stateMachine, its log will be nil. func (nt network) logs(want []Entry) [][]Entry { - ls := make([][]Entry, len(nt)+1) + ls := make([][]Entry, len(nt.ss)+1) ls[0] = want - for i, node := range nt { + for i, node := range nt.ss { if sm, ok := node.(*stateMachine); ok { ls[i] = sm.log }