raft: public progress struct in raft

This commit is contained in:
Xiang Li 2015-01-20 10:26:22 -08:00
parent b34936b097
commit 003b97a60f
3 changed files with 84 additions and 83 deletions

View File

@ -52,61 +52,61 @@ func (st StateType) String() string {
return stmap[uint64(st)] return stmap[uint64(st)]
} }
type progress struct { type Progress struct {
match, next uint64 Match, Next uint64
wait int Wait int
} }
func (pr *progress) update(n uint64) { func (pr *Progress) update(n uint64) {
pr.waitReset() pr.waitReset()
if pr.match < n { if pr.Match < n {
pr.match = n pr.Match = n
} }
if pr.next < n+1 { if pr.Next < n+1 {
pr.next = n + 1 pr.Next = n + 1
} }
} }
func (pr *progress) optimisticUpdate(n uint64) { pr.next = n + 1 } func (pr *Progress) optimisticUpdate(n uint64) { pr.Next = n + 1 }
// maybeDecrTo returns false if the given rejected index comes from an out of order message. // maybeDecrTo returns false if the given rejected index comes from an out of order message.
// Otherwise it decreases the progress next index (to match+1 if already matched, else to min(rejected, last+1), but never below 1) and returns true. // Otherwise it decreases the progress next index (to match+1 if already matched, else to min(rejected, last+1), but never below 1) and returns true.
func (pr *progress) maybeDecrTo(rejected, last uint64) bool { func (pr *Progress) maybeDecrTo(rejected, last uint64) bool {
pr.waitReset() pr.waitReset()
if pr.match != 0 { if pr.Match != 0 {
// the rejection must be stale if the progress has matched and "rejected" // the rejection must be stale if the progress has matched and "rejected"
// is smaller than or equal to "match". // is smaller than or equal to "match".
if rejected <= pr.match { if rejected <= pr.Match {
return false return false
} }
// directly decrease next to match + 1 // directly decrease next to match + 1
pr.next = pr.match + 1 pr.Next = pr.Match + 1
return true return true
} }
// the rejection must be stale if "rejected" does not match next - 1 // the rejection must be stale if "rejected" does not match next - 1
if pr.next-1 != rejected { if pr.Next-1 != rejected {
return false return false
} }
if pr.next = min(rejected, last+1); pr.next < 1 { if pr.Next = min(rejected, last+1); pr.Next < 1 {
pr.next = 1 pr.Next = 1
} }
return true return true
} }
func (pr *progress) waitDecr(i int) { func (pr *Progress) waitDecr(i int) {
pr.wait -= i pr.Wait -= i
if pr.wait < 0 { if pr.Wait < 0 {
pr.wait = 0 pr.Wait = 0
} }
} }
func (pr *progress) waitSet(w int) { pr.wait = w } func (pr *Progress) waitSet(w int) { pr.Wait = w }
func (pr *progress) waitReset() { pr.wait = 0 } func (pr *Progress) waitReset() { pr.Wait = 0 }
func (pr *progress) shouldWait() bool { return pr.match == 0 && pr.wait > 0 } func (pr *Progress) shouldWait() bool { return pr.Match == 0 && pr.Wait > 0 }
func (pr *progress) String() string { func (pr *Progress) String() string {
return fmt.Sprintf("next = %d, match = %d, wait = %v", pr.next, pr.match, pr.wait) return fmt.Sprintf("next = %d, match = %d, wait = %v", pr.Next, pr.Match, pr.Wait)
} }
type raft struct { type raft struct {
@ -117,7 +117,7 @@ type raft struct {
// the log // the log
raftLog *raftLog raftLog *raftLog
prs map[uint64]*progress prs map[uint64]*Progress
state StateType state StateType
@ -161,13 +161,13 @@ func newRaft(id uint64, peers []uint64, election, heartbeat int, storage Storage
id: id, id: id,
lead: None, lead: None,
raftLog: raftlog, raftLog: raftlog,
prs: make(map[uint64]*progress), prs: make(map[uint64]*Progress),
electionTimeout: election, electionTimeout: election,
heartbeatTimeout: heartbeat, heartbeatTimeout: heartbeat,
} }
r.rand = rand.New(rand.NewSource(int64(id))) r.rand = rand.New(rand.NewSource(int64(id)))
for _, p := range peers { for _, p := range peers {
r.prs[p] = &progress{next: 1} r.prs[p] = &Progress{Next: 1}
} }
if !isHardStateEqual(hs, emptyState) { if !isHardStateEqual(hs, emptyState) {
r.loadState(hs) r.loadState(hs)
@ -220,7 +220,7 @@ func (r *raft) sendAppend(to uint64) {
} }
m := pb.Message{} m := pb.Message{}
m.To = to m.To = to
if r.needSnapshot(pr.next) { if r.needSnapshot(pr.Next) {
m.Type = pb.MsgSnap m.Type = pb.MsgSnap
snapshot, err := r.raftLog.snapshot() snapshot, err := r.raftLog.snapshot()
if err != nil { if err != nil {
@ -236,15 +236,15 @@ func (r *raft) sendAppend(to uint64) {
pr.waitSet(r.electionTimeout) pr.waitSet(r.electionTimeout)
} else { } else {
m.Type = pb.MsgApp m.Type = pb.MsgApp
m.Index = pr.next - 1 m.Index = pr.Next - 1
m.LogTerm = r.raftLog.term(pr.next - 1) m.LogTerm = r.raftLog.term(pr.Next - 1)
m.Entries = r.raftLog.entries(pr.next) m.Entries = r.raftLog.entries(pr.Next)
m.Commit = r.raftLog.committed m.Commit = r.raftLog.committed
// optimistically increase the next if the follower // optimistically increase the next if the follower
// has been matched. // has been matched.
if n := len(m.Entries); pr.match != 0 && n != 0 { if n := len(m.Entries); pr.Match != 0 && n != 0 {
pr.optimisticUpdate(m.Entries[n-1].Index) pr.optimisticUpdate(m.Entries[n-1].Index)
} else if pr.match == 0 { } else if pr.Match == 0 {
// TODO (xiangli): better way to find out if the follower is in good path or not // TODO (xiangli): better way to find out if the follower is in good path or not
// a follower might be in bad path even if match != 0, since we optimistically // a follower might be in bad path even if match != 0, since we optimistically
// increase the next. // increase the next.
@ -262,7 +262,7 @@ func (r *raft) sendHeartbeat(to uint64) {
// or it might not have all the committed entries. // or it might not have all the committed entries.
// The leader MUST NOT forward the follower's commit to // The leader MUST NOT forward the follower's commit to
// an unmatched index. // an unmatched index.
commit := min(r.prs[to].match, r.raftLog.committed) commit := min(r.prs[to].Match, r.raftLog.committed)
m := pb.Message{ m := pb.Message{
To: to, To: to,
Type: pb.MsgHeartbeat, Type: pb.MsgHeartbeat,
@ -297,7 +297,7 @@ func (r *raft) maybeCommit() bool {
// TODO(bmizerany): optimize.. Currently naive // TODO(bmizerany): optimize.. Currently naive
mis := make(uint64Slice, 0, len(r.prs)) mis := make(uint64Slice, 0, len(r.prs))
for i := range r.prs { for i := range r.prs {
mis = append(mis, r.prs[i].match) mis = append(mis, r.prs[i].Match)
} }
sort.Sort(sort.Reverse(mis)) sort.Sort(sort.Reverse(mis))
mci := mis[r.q()-1] mci := mis[r.q()-1]
@ -311,9 +311,9 @@ func (r *raft) reset(term uint64) {
r.elapsed = 0 r.elapsed = 0
r.votes = make(map[uint64]bool) r.votes = make(map[uint64]bool)
for i := range r.prs { for i := range r.prs {
r.prs[i] = &progress{next: r.raftLog.lastIndex() + 1} r.prs[i] = &Progress{Next: r.raftLog.lastIndex() + 1}
if i == r.id { if i == r.id {
r.prs[i].match = r.raftLog.lastIndex() r.prs[i].Match = r.raftLog.lastIndex()
} }
} }
r.pendingConf = false r.pendingConf = false
@ -495,7 +495,7 @@ func stepLeader(r *raft, m pb.Message) {
} }
} }
case pb.MsgHeartbeatResp: case pb.MsgHeartbeatResp:
if r.prs[m.From].match < r.raftLog.lastIndex() { if r.prs[m.From].Match < r.raftLog.lastIndex() {
r.sendAppend(m.From) r.sendAppend(m.From)
} }
case pb.MsgVote: case pb.MsgVote:
@ -616,7 +616,7 @@ func (r *raft) restore(s pb.Snapshot) bool {
r.id, r.Commit, r.raftLog.lastIndex(), r.raftLog.lastTerm(), s.Metadata.Index, s.Metadata.Term) r.id, r.Commit, r.raftLog.lastIndex(), r.raftLog.lastTerm(), s.Metadata.Index, s.Metadata.Term)
r.raftLog.restore(s) r.raftLog.restore(s)
r.prs = make(map[uint64]*progress) r.prs = make(map[uint64]*Progress)
for _, n := range s.Metadata.ConfState.Nodes { for _, n := range s.Metadata.ConfState.Nodes {
match, next := uint64(0), uint64(r.raftLog.lastIndex())+1 match, next := uint64(0), uint64(r.raftLog.lastIndex())+1
if n == r.id { if n == r.id {
@ -660,7 +660,7 @@ func (r *raft) removeNode(id uint64) {
func (r *raft) resetPendingConf() { r.pendingConf = false } func (r *raft) resetPendingConf() { r.pendingConf = false }
func (r *raft) setProgress(id, match, next uint64) { func (r *raft) setProgress(id, match, next uint64) {
r.prs[id] = &progress{next: next, match: match} r.prs[id] = &Progress{Next: next, Match: match}
} }
func (r *raft) delProgress(id uint64) { func (r *raft) delProgress(id uint64) {

View File

@ -64,16 +64,16 @@ func TestProgressUpdate(t *testing.T) {
{prevM + 2, prevM + 2, prevN + 1}, // increase match, next {prevM + 2, prevM + 2, prevN + 1}, // increase match, next
} }
for i, tt := range tests { for i, tt := range tests {
p := &progress{ p := &Progress{
match: prevM, Match: prevM,
next: prevN, Next: prevN,
} }
p.update(tt.update) p.update(tt.update)
if p.match != tt.wm { if p.Match != tt.wm {
t.Errorf("#%d: match= %d, want %d", i, p.match, tt.wm) t.Errorf("#%d: match= %d, want %d", i, p.Match, tt.wm)
} }
if p.next != tt.wn { if p.Next != tt.wn {
t.Errorf("#%d: next= %d, want %d", i, p.next, tt.wn) t.Errorf("#%d: next= %d, want %d", i, p.Next, tt.wn)
} }
} }
} }
@ -136,17 +136,17 @@ func TestProgressMaybeDecr(t *testing.T) {
}, },
} }
for i, tt := range tests { for i, tt := range tests {
p := &progress{ p := &Progress{
match: tt.m, Match: tt.m,
next: tt.n, Next: tt.n,
} }
if g := p.maybeDecrTo(tt.rejected, tt.last); g != tt.w { if g := p.maybeDecrTo(tt.rejected, tt.last); g != tt.w {
t.Errorf("#%d: maybeDecrTo= %t, want %t", i, g, tt.w) t.Errorf("#%d: maybeDecrTo= %t, want %t", i, g, tt.w)
} }
if gm := p.match; gm != tt.m { if gm := p.Match; gm != tt.m {
t.Errorf("#%d: match= %d, want %d", i, gm, tt.m) t.Errorf("#%d: match= %d, want %d", i, gm, tt.m)
} }
if gn := p.next; gn != tt.wn { if gn := p.Next; gn != tt.wn {
t.Errorf("#%d: next= %d, want %d", i, gn, tt.wn) t.Errorf("#%d: next= %d, want %d", i, gn, tt.wn)
} }
} }
@ -166,9 +166,9 @@ func TestProgressShouldWait(t *testing.T) {
{0, 0, false}, {0, 0, false},
} }
for i, tt := range tests { for i, tt := range tests {
p := &progress{ p := &Progress{
match: tt.m, Match: tt.m,
wait: tt.wait, Wait: tt.wait,
} }
if g := p.shouldWait(); g != tt.w { if g := p.shouldWait(); g != tt.w {
t.Errorf("#%d: shouldwait = %t, want %t", i, g, tt.w) t.Errorf("#%d: shouldwait = %t, want %t", i, g, tt.w)
@ -179,17 +179,17 @@ func TestProgressShouldWait(t *testing.T) {
// TestProgressWaitReset ensures that progress.update and progress.maybeDecrTo // TestProgressWaitReset ensures that progress.update and progress.maybeDecrTo
// will reset progress.wait. // will reset progress.wait.
func TestProgressWaitReset(t *testing.T) { func TestProgressWaitReset(t *testing.T) {
p := &progress{ p := &Progress{
wait: 1, Wait: 1,
} }
p.maybeDecrTo(1, 1) p.maybeDecrTo(1, 1)
if p.wait != 0 { if p.Wait != 0 {
t.Errorf("wait= %d, want 0", p.wait) t.Errorf("wait= %d, want 0", p.Wait)
} }
p.wait = 1 p.Wait = 1
p.update(2) p.update(2)
if p.wait != 0 { if p.Wait != 0 {
t.Errorf("wait= %d, want 0", p.wait) t.Errorf("wait= %d, want 0", p.Wait)
} }
} }
@ -198,11 +198,11 @@ func TestProgressDecr(t *testing.T) {
r := newRaft(1, []uint64{1, 2}, 5, 1, NewMemoryStorage()) r := newRaft(1, []uint64{1, 2}, 5, 1, NewMemoryStorage())
r.becomeCandidate() r.becomeCandidate()
r.becomeLeader() r.becomeLeader()
r.prs[2].wait = r.heartbeatTimeout * 2 r.prs[2].Wait = r.heartbeatTimeout * 2
r.Step(pb.Message{From: 1, To: 1, Type: pb.MsgBeat}) r.Step(pb.Message{From: 1, To: 1, Type: pb.MsgBeat})
if r.prs[2].wait != r.heartbeatTimeout*(2-1) { if r.prs[2].Wait != r.heartbeatTimeout*(2-1) {
t.Errorf("wait = %d, want %d", r.prs[2].wait, r.heartbeatTimeout*(2-1)) t.Errorf("wait = %d, want %d", r.prs[2].Wait, r.heartbeatTimeout*(2-1))
} }
} }
@ -1073,11 +1073,11 @@ func TestLeaderAppResp(t *testing.T) {
sm.Step(pb.Message{From: 2, Type: pb.MsgAppResp, Index: tt.index, Term: sm.Term, Reject: tt.reject, RejectHint: tt.index}) sm.Step(pb.Message{From: 2, Type: pb.MsgAppResp, Index: tt.index, Term: sm.Term, Reject: tt.reject, RejectHint: tt.index})
p := sm.prs[2] p := sm.prs[2]
if p.match != tt.wmatch { if p.Match != tt.wmatch {
t.Errorf("#%d match = %d, want %d", i, p.match, tt.wmatch) t.Errorf("#%d match = %d, want %d", i, p.Match, tt.wmatch)
} }
if p.next != tt.wnext { if p.Next != tt.wnext {
t.Errorf("#%d next = %d, want %d", i, p.next, tt.wnext) t.Errorf("#%d next = %d, want %d", i, p.Next, tt.wnext)
} }
msgs := sm.readMessages() msgs := sm.readMessages()
@ -1119,9 +1119,9 @@ func TestBcastBeat(t *testing.T) {
sm.appendEntry(pb.Entry{Index: uint64(i) + 1}) sm.appendEntry(pb.Entry{Index: uint64(i) + 1})
} }
// slow follower // slow follower
sm.prs[2].match, sm.prs[2].next = 5, 6 sm.prs[2].Match, sm.prs[2].Next = 5, 6
// normal follower // normal follower
sm.prs[3].match, sm.prs[3].next = sm.raftLog.lastIndex(), sm.raftLog.lastIndex()+1 sm.prs[3].Match, sm.prs[3].Next = sm.raftLog.lastIndex(), sm.raftLog.lastIndex()+1
sm.Step(pb.Message{Type: pb.MsgBeat}) sm.Step(pb.Message{Type: pb.MsgBeat})
msgs := sm.readMessages() msgs := sm.readMessages()
@ -1129,8 +1129,8 @@ func TestBcastBeat(t *testing.T) {
t.Fatalf("len(msgs) = %v, want 2", len(msgs)) t.Fatalf("len(msgs) = %v, want 2", len(msgs))
} }
wantCommitMap := map[uint64]uint64{ wantCommitMap := map[uint64]uint64{
2: min(sm.raftLog.committed, sm.prs[2].match), 2: min(sm.raftLog.committed, sm.prs[2].Match),
3: min(sm.raftLog.committed, sm.prs[3].match), 3: min(sm.raftLog.committed, sm.prs[3].Match),
} }
for i, m := range msgs { for i, m := range msgs {
if m.Type != pb.MsgHeartbeat { if m.Type != pb.MsgHeartbeat {
@ -1216,12 +1216,12 @@ func TestLeaderIncreaseNext(t *testing.T) {
sm.raftLog.append(previousEnts...) sm.raftLog.append(previousEnts...)
sm.becomeCandidate() sm.becomeCandidate()
sm.becomeLeader() sm.becomeLeader()
sm.prs[2].match, sm.prs[2].next = tt.match, tt.next sm.prs[2].Match, sm.prs[2].Next = tt.match, tt.next
sm.Step(pb.Message{From: 1, To: 1, Type: pb.MsgProp, Entries: []pb.Entry{{Data: []byte("somedata")}}}) sm.Step(pb.Message{From: 1, To: 1, Type: pb.MsgProp, Entries: []pb.Entry{{Data: []byte("somedata")}}})
p := sm.prs[2] p := sm.prs[2]
if p.next != tt.wnext { if p.Next != tt.wnext {
t.Errorf("#%d next = %d, want %d", i, p.next, tt.wnext) t.Errorf("#%d next = %d, want %d", i, p.Next, tt.wnext)
} }
} }
} }
@ -1310,9 +1310,9 @@ func TestProvideSnap(t *testing.T) {
// force set the next of node 2, so that // force set the next of node 2, so that
// node 2 needs a snapshot // node 2 needs a snapshot
sm.prs[2].next = sm.raftLog.firstIndex() sm.prs[2].Next = sm.raftLog.firstIndex()
sm.Step(pb.Message{From: 2, To: 1, Type: pb.MsgAppResp, Index: sm.prs[2].next - 1, Reject: true}) sm.Step(pb.Message{From: 2, To: 1, Type: pb.MsgAppResp, Index: sm.prs[2].Next - 1, Reject: true})
msgs := sm.readMessages() msgs := sm.readMessages()
if len(msgs) != 1 { if len(msgs) != 1 {
t.Fatalf("len(msgs) = %d, want 1", len(msgs)) t.Fatalf("len(msgs) = %d, want 1", len(msgs))
@ -1547,9 +1547,9 @@ func newNetwork(peers ...Interface) *network {
npeers[id] = sm npeers[id] = sm
case *raft: case *raft:
v.id = id v.id = id
v.prs = make(map[uint64]*progress) v.prs = make(map[uint64]*Progress)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
v.prs[peerAddrs[i]] = &progress{} v.prs[peerAddrs[i]] = &Progress{}
} }
v.reset(0) v.reset(0)
npeers[id] = v npeers[id] = v

View File

@ -27,9 +27,10 @@ type Status struct {
SoftState SoftState
Applied uint64 Applied uint64
Progress map[uint64]progress Progress map[uint64]Progress
} }
// getStatus gets a copy of the current raft status.
func getStatus(r *raft) Status { func getStatus(r *raft) Status {
s := Status{ID: r.id} s := Status{ID: r.id}
s.HardState = r.HardState s.HardState = r.HardState
@ -38,7 +39,7 @@ func getStatus(r *raft) Status {
s.Applied = r.raftLog.applied s.Applied = r.raftLog.applied
if s.RaftState == StateLeader { if s.RaftState == StateLeader {
s.Progress = make(map[uint64]progress) s.Progress = make(map[uint64]Progress)
for id, p := range r.prs { for id, p := range r.prs {
s.Progress[id] = *p s.Progress[id] = *p
} }