diff --git a/etcdserver/server_test.go b/etcdserver/server_test.go index e7e7016dc..5e73206a8 100644 --- a/etcdserver/server_test.go +++ b/etcdserver/server_test.go @@ -1305,6 +1305,8 @@ func (n *nodeRecorder) Stop() { func (n *nodeRecorder) ReportUnreachable(id uint64) {} +func (n *nodeRecorder) ReportSnapshot(id uint64, status raft.SnapshotStatus) {} + func (n *nodeRecorder) Compact(index uint64, nodes []uint64, d []byte) { n.Record(testutil.Action{Name: "Compact"}) } diff --git a/raft/node.go b/raft/node.go index 5c60b5aa3..00e25bdce 100644 --- a/raft/node.go +++ b/raft/node.go @@ -22,6 +22,13 @@ import ( pb "github.com/coreos/etcd/raft/raftpb" ) +type SnapshotStatus int + +const ( + SnapshotFinish SnapshotStatus = 1 + SnapshotFailure SnapshotStatus = 2 +) + var ( emptyState = pb.HardState{} @@ -68,6 +75,8 @@ type Ready struct { // Messages specifies outbound messages to be sent AFTER Entries are // committed to stable storage. + // If it contains a MsgSnap message, the application MUST report back to raft + // when the snapshot has been received or has failed by calling ReportSnapshot. Messages []pb.Message } @@ -121,6 +130,8 @@ type Node interface { Status() Status // Report reports the given node is not reachable for the last send. ReportUnreachable(id uint64) + // ReportSnapshot reports the status of the sent snapshot. 
+ ReportSnapshot(id uint64, status SnapshotStatus) // Stop performs any necessary termination of the Node Stop() } @@ -427,6 +438,15 @@ func (n *node) ReportUnreachable(id uint64) { } } +func (n *node) ReportSnapshot(id uint64, status SnapshotStatus) { + rej := status == SnapshotFailure + + select { + case n.recvc <- pb.Message{Type: pb.MsgSnapStatus, From: id, Reject: rej}: + case <-n.done: + } +} + func newReady(r *raft, prevSoftSt *SoftState, prevHardSt pb.HardState) Ready { rd := Ready{ Entries: r.raftLog.unstableEntries(), diff --git a/raft/node_test.go b/raft/node_test.go index 4c78aba4a..cbd537f9b 100644 --- a/raft/node_test.go +++ b/raft/node_test.go @@ -42,7 +42,7 @@ func TestNodeStep(t *testing.T) { t.Errorf("%d: cannot receive %s on propc chan", msgt, msgn) } } else { - if msgt == raftpb.MsgBeat || msgt == raftpb.MsgHup { + if msgt == raftpb.MsgBeat || msgt == raftpb.MsgHup || msgt == raftpb.MsgUnreachable || msgt == raftpb.MsgSnapStatus { select { case <-n.recvc: t.Errorf("%d: step should ignore %s", msgt, msgn) diff --git a/raft/raft.go b/raft/raft.go index f1c179561..585f267d3 100644 --- a/raft/raft.go +++ b/raft/raft.go @@ -60,6 +60,15 @@ type Progress struct { // Unreachable will be unset if raft starts to receive message (msgAppResp, // msgHeartbeatResp) from the remote peer of the Progress. Unreachable bool + // If there is a pending snapshot, PendingSnapshot will be set to the + // index of the snapshot. If PendingSnapshot is set, the replication process of + // this Progress will be paused. raft will not resend a snapshot until the pending one + // is reported to have failed. + // + // PendingSnapshot is set when raft sends out a snapshot to this Progress. + // PendingSnapshot is unset when the snapshot is reported to have finished or failed, + // or raft updates an equal or higher Match for this Progress. 
+ PendingSnapshot uint64 } func (pr *Progress) update(n uint64) { @@ -114,6 +123,33 @@ func (pr *Progress) reachable() { pr.Unreachable = false } func (pr *Progress) unreachable() { pr.Unreachable = true } func (pr *Progress) shouldWait() bool { return (pr.Unreachable || pr.Match == 0) && pr.Wait > 0 } +func (pr *Progress) hasPendingSnapshot() bool { return pr.PendingSnapshot != 0 } +func (pr *Progress) setPendingSnapshot(i uint64) { pr.PendingSnapshot = i } + +// snapshotFinish unsets the pending snapshot and optimistically increases Next to +// the index of PendingSnapshot + 1. The next replication message is expected +// to be msgApp. +func (pr *Progress) snapshotFinish() { + pr.Next = pr.PendingSnapshot + 1 + pr.PendingSnapshot = 0 +} + +// snapshotFail unsets the pending snapshot. The next replication message is expected +// to be another msgSnap. +func (pr *Progress) snapshotFail() { + pr.PendingSnapshot = 0 +} + +// maybeSnapshotAbort unsets PendingSnapshot if Match is equal to or higher than +// PendingSnapshot, and reports whether it did so. +func (pr *Progress) maybeSnapshotAbort() bool { + if pr.hasPendingSnapshot() && pr.Match >= pr.PendingSnapshot { + pr.PendingSnapshot = 0 + return true + } + return false +} + func (pr *Progress) String() string { return fmt.Sprintf("next = %d, match = %d, wait = %v", pr.Next, pr.Match, pr.Wait) } @@ -227,7 +263,7 @@ func (r *raft) send(m pb.Message) { // sendAppend sends RRPC, with entries to the given peer. 
func (r *raft) sendAppend(to uint64) { pr := r.prs[to] - if pr.shouldWait() { + if pr.shouldWait() || pr.hasPendingSnapshot() { return } m := pb.Message{} @@ -251,7 +287,8 @@ func (r *raft) sendAppend(to uint64) { sindex, sterm := snapshot.Metadata.Index, snapshot.Metadata.Term log.Printf("raft: %x [firstindex: %d, commit: %d] sent snapshot[index: %d, term: %d] to %x [%s]", r.id, r.raftLog.firstIndex(), r.Commit, sindex, sterm, to, pr) - pr.waitSet(r.electionTimeout) + pr.setPendingSnapshot(sindex) + log.Printf("raft: %x paused sending replication messages to %x [%s]", r.id, to, pr) } else { m.Type = pb.MsgApp m.Index = pr.Next - 1 @@ -509,6 +546,9 @@ func stepLeader(r *raft, m pb.Message) { } else { oldWait := pr.shouldWait() pr.update(m.Index) + if r.prs[m.From].maybeSnapshotAbort() { + log.Printf("raft: %x snapshot aborted, resumed sending replication messages to %x [%s]", r.id, m.From, pr) + } if r.maybeCommit() { r.bcastAppend() } else if oldWait { @@ -526,6 +566,20 @@ func stepLeader(r *raft, m pb.Message) { log.Printf("raft: %x [logterm: %d, index: %d, vote: %x] rejected vote from %x [logterm: %d, index: %d] at term %d", r.id, r.raftLog.lastTerm(), r.raftLog.lastIndex(), r.Vote, m.From, m.LogTerm, m.Index, r.Term) r.send(pb.Message{To: m.From, Type: pb.MsgVoteResp, Reject: true}) + case pb.MsgSnapStatus: + if !pr.hasPendingSnapshot() { + return + } + if m.Reject { + pr.snapshotFail() + log.Printf("raft: %x snapshot failed, resumed sending replication messages to %x [%s]", r.id, m.From, pr) + } else { + pr.snapshotFinish() + log.Printf("raft: %x snapshot succeeded resumed sending replication messages to %x [%s]", r.id, m.From, pr) + // wait for the msgAppResp from the remote node before sending + // out the next msgApp + pr.waitSet(r.electionTimeout) + } case pb.MsgUnreachable: r.prs[m.From].unreachable() } diff --git a/raft/raft_snap_test.go b/raft/raft_snap_test.go new file mode 100644 index 000000000..62c4dbb53 --- /dev/null +++ b/raft/raft_snap_test.go @@ 
-0,0 +1,128 @@ +// Copyright 2015 CoreOS, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package raft + +import ( + "testing" + + pb "github.com/coreos/etcd/raft/raftpb" +) + +var ( + testingSnap = pb.Snapshot{ + Metadata: pb.SnapshotMetadata{ + Index: 11, // magic number + Term: 11, // magic number + ConfState: pb.ConfState{Nodes: []uint64{1, 2}}, + }, + } +) + +func TestSendingSnapshotSetPendingSnapshot(t *testing.T) { + storage := NewMemoryStorage() + sm := newRaft(1, []uint64{1}, 10, 1, storage, 0) + sm.restore(testingSnap) + + sm.becomeCandidate() + sm.becomeLeader() + + // force set the next of node 2, so that + // node 2 needs a snapshot + sm.prs[2].Next = sm.raftLog.firstIndex() + + sm.Step(pb.Message{From: 2, To: 1, Type: pb.MsgAppResp, Index: sm.prs[2].Next - 1, Reject: true}) + if sm.prs[2].PendingSnapshot != 11 { + t.Fatalf("PendingSnapshot = %d, want 11", sm.prs[2].PendingSnapshot) + } +} + +func TestPendingSnapshotPauseReplication(t *testing.T) { + storage := NewMemoryStorage() + sm := newRaft(1, []uint64{1, 2}, 10, 1, storage, 0) + sm.restore(testingSnap) + + sm.becomeCandidate() + sm.becomeLeader() + + sm.prs[2].setPendingSnapshot(11) + + sm.Step(pb.Message{From: 1, To: 1, Type: pb.MsgProp, Entries: []pb.Entry{{Data: []byte("somedata")}}}) + msgs := sm.readMessages() + if len(msgs) != 0 { + t.Fatalf("len(msgs) = %d, want 0", len(msgs)) + } +} + +func TestSnapshotFailure(t *testing.T) { + storage := NewMemoryStorage() + sm := 
newRaft(1, []uint64{1, 2}, 10, 1, storage, 0) + sm.restore(testingSnap) + + sm.becomeCandidate() + sm.becomeLeader() + + sm.prs[2].Next = 1 + sm.prs[2].setPendingSnapshot(11) + + sm.Step(pb.Message{From: 2, To: 1, Type: pb.MsgSnapStatus, Reject: true}) + if sm.prs[2].PendingSnapshot != 0 { + t.Fatalf("PendingSnapshot = %d, want 0", sm.prs[2].PendingSnapshot) + } + if sm.prs[2].Next != 1 { + t.Fatalf("Next = %d, want 1", sm.prs[2].Next) + } +} + +func TestSnapshotSucceed(t *testing.T) { + storage := NewMemoryStorage() + sm := newRaft(1, []uint64{1, 2}, 10, 1, storage, 0) + sm.restore(testingSnap) + + sm.becomeCandidate() + sm.becomeLeader() + + sm.prs[2].Next = 1 + sm.prs[2].setPendingSnapshot(11) + + sm.Step(pb.Message{From: 2, To: 1, Type: pb.MsgSnapStatus, Reject: false}) + if sm.prs[2].PendingSnapshot != 0 { + t.Fatalf("PendingSnapshot = %d, want 0", sm.prs[2].PendingSnapshot) + } + if sm.prs[2].Next != 12 { + t.Fatalf("Next = %d, want 12", sm.prs[2].Next) + } +} + +func TestSnapshotAbort(t *testing.T) { + storage := NewMemoryStorage() + sm := newRaft(1, []uint64{1, 2}, 10, 1, storage, 0) + sm.restore(testingSnap) + + sm.becomeCandidate() + sm.becomeLeader() + + sm.prs[2].Next = 1 + sm.prs[2].setPendingSnapshot(11) + + // A successful msgAppResp that has a higher/equal index than the + // pending snapshot should abort the pending snapshot. 
+ sm.Step(pb.Message{From: 2, To: 1, Type: pb.MsgAppResp, Index: 11}) + if sm.prs[2].PendingSnapshot != 0 { + t.Fatalf("PendingSnapshot = %d, want 0", sm.prs[2].PendingSnapshot) + } + if sm.prs[2].Next != 12 { + t.Fatalf("Next = %d, want 12", sm.prs[2].Next) + } +} diff --git a/raft/raftpb/raft.pb.go b/raft/raftpb/raft.pb.go index c8056572a..9865a07bf 100644 --- a/raft/raftpb/raft.pb.go +++ b/raft/raftpb/raft.pb.go @@ -80,6 +80,7 @@ const ( MsgHeartbeat MessageType = 8 MsgHeartbeatResp MessageType = 9 MsgUnreachable MessageType = 10 + MsgSnapStatus MessageType = 11 ) var MessageType_name = map[int32]string{ @@ -94,6 +95,7 @@ var MessageType_name = map[int32]string{ 8: "MsgHeartbeat", 9: "MsgHeartbeatResp", 10: "MsgUnreachable", + 11: "MsgSnapStatus", } var MessageType_value = map[string]int32{ "MsgHup": 0, @@ -107,6 +109,7 @@ var MessageType_value = map[string]int32{ "MsgHeartbeat": 8, "MsgHeartbeatResp": 9, "MsgUnreachable": 10, + "MsgSnapStatus": 11, } func (x MessageType) Enum() *MessageType { diff --git a/raft/raftpb/raft.proto b/raft/raftpb/raft.proto index 579546ed0..eadc45f85 100644 --- a/raft/raftpb/raft.proto +++ b/raft/raftpb/raft.proto @@ -32,17 +32,18 @@ message Snapshot { } enum MessageType { - MsgHup = 0; - MsgBeat = 1; - MsgProp = 2; - MsgApp = 3; - MsgAppResp = 4; - MsgVote = 5; - MsgVoteResp = 6; - MsgSnap = 7; - MsgHeartbeat = 8; - MsgHeartbeatResp = 9; - MsgUnreachable = 10; + MsgHup = 0; + MsgBeat = 1; + MsgProp = 2; + MsgApp = 3; + MsgAppResp = 4; + MsgVote = 5; + MsgVoteResp = 6; + MsgSnap = 7; + MsgHeartbeat = 8; + MsgHeartbeatResp = 9; + MsgUnreachable = 10; + MsgSnapStatus = 11; } message Message { diff --git a/raft/util.go b/raft/util.go index 2dc4a9182..6e512fd4d 100644 --- a/raft/util.go +++ b/raft/util.go @@ -46,7 +46,9 @@ func max(a, b uint64) uint64 { return b } -func IsLocalMsg(m pb.Message) bool { return m.Type == pb.MsgHup || m.Type == pb.MsgBeat } +func IsLocalMsg(m pb.Message) bool { + return m.Type == pb.MsgHup || m.Type == 
pb.MsgBeat || m.Type == pb.MsgUnreachable || m.Type == pb.MsgSnapStatus +} func IsResponseMsg(m pb.Message) bool { return m.Type == pb.MsgAppResp || m.Type == pb.MsgVoteResp || m.Type == pb.MsgHeartbeatResp || m.Type == pb.MsgUnreachable