diff --git a/raft/node.go b/raft/node.go index f0ef4d260..ed9b01a39 100644 --- a/raft/node.go +++ b/raft/node.go @@ -231,11 +231,20 @@ func (n *node) Tick() { } func (n *node) Campaign(ctx context.Context) error { - return n.Step(ctx, pb.Message{Type: msgHup}) + return n.step(ctx, pb.Message{Type: msgHup}) } func (n *node) Propose(ctx context.Context, data []byte) error { - return n.Step(ctx, pb.Message{Type: msgProp, Entries: []pb.Entry{{Data: data}}}) + return n.step(ctx, pb.Message{Type: msgProp, Entries: []pb.Entry{{Data: data}}}) +} + +func (n *node) Step(ctx context.Context, m pb.Message) error { + // ignore unexpected local messages receiving over network + if m.Type == msgHup || m.Type == msgBeat { + // TODO: return an error? + return nil + } + return n.step(ctx, m) } func (n *node) ProposeConfChange(ctx context.Context, cc pb.ConfChange) error { @@ -248,7 +257,7 @@ func (n *node) ProposeConfChange(ctx context.Context, cc pb.ConfChange) error { // Step advances the state machine using msgs. The ctx.Err() will be returned, // if any. -func (n *node) Step(ctx context.Context, m pb.Message) error { +func (n *node) step(ctx context.Context, m pb.Message) error { ch := n.recvc if m.Type == msgProp { ch = n.propc diff --git a/raft/node_test.go b/raft/node_test.go index a3f2470ce..782653aab 100644 --- a/raft/node_test.go +++ b/raft/node_test.go @@ -18,7 +18,8 @@ func TestNodeStep(t *testing.T) { propc: make(chan raftpb.Message, 1), recvc: make(chan raftpb.Message, 1), } - n.Step(context.TODO(), raftpb.Message{Type: int64(i)}) + msgt := int64(i) + n.Step(context.TODO(), raftpb.Message{Type: msgt}) // Proposal goes to proc chan. Others go to recvc chan. if int64(i) == msgProp { select { @@ -27,10 +28,18 @@ func TestNodeStep(t *testing.T) { t.Errorf("%d: cannot receive %s on propc chan", i, mtmap[i]) } } else { - select { - case <-n.recvc: - default: - t.Errorf("%d: cannot receive %s on recvc chan", i, mtmap[i]) + if msgt == msgBeat || msgt == msgHup { + select { + case <-n.recvc: + t.Errorf("%d: step should ignore msgHub/msgBeat", i, mtmap[i]) + default: + } + } else { + select { + case <-n.recvc: + default: + t.Errorf("%d: cannot receive %s on recvc chan", i, mtmap[i]) + } } } }