diff --git a/raft/raft.go b/raft/raft.go index a0cc173c3..0169c1bb6 100644 --- a/raft/raft.go +++ b/raft/raft.go @@ -827,6 +827,7 @@ func stepFollower(r *raft, m pb.Message) { r.handleHeartbeat(m) case pb.MsgSnap: r.electionElapsed = 0 + r.lead = m.From r.handleSnapshot(m) case pb.MsgVote: if (r.Vote == None || r.Vote == m.From) && r.raftLog.isUpToDate(m.Index, m.LogTerm) { diff --git a/raft/raft_test.go b/raft/raft_test.go index 8318f6f78..824839f96 100644 --- a/raft/raft_test.go +++ b/raft/raft_test.go @@ -1901,6 +1901,10 @@ func TestRestoreFromSnapMsg(t *testing.T) { sm := newTestRaft(2, []uint64{1, 2}, 10, 1, NewMemoryStorage()) sm.Step(m) + if sm.lead != uint64(1) { + t.Errorf("sm.lead = %d, want 1", sm.lead) + } + // TODO(bdarnell): what should this test? }