diff --git a/raft/log.go b/raft/log.go index 89be6f0fc..10211da2d 100644 --- a/raft/log.go +++ b/raft/log.go @@ -168,12 +168,13 @@ func (l *raftLog) compact(i int64) int64 { return int64(len(l.ents)) } -func (l *raftLog) snap(d []byte, index, term int64, nodes []int64) { +func (l *raftLog) snap(d []byte, index, term int64, nodes []int64, removed []int64) { l.snapshot = pb.Snapshot{ - Data: d, - Nodes: nodes, - Index: index, - Term: term, + Data: d, + Nodes: nodes, + Index: index, + Term: term, + RemovedNodes: removed, } } diff --git a/raft/node_test.go b/raft/node_test.go index 6270acf3d..0eb763a6a 100644 --- a/raft/node_test.go +++ b/raft/node_test.go @@ -231,10 +231,11 @@ func TestNodeCompact(t *testing.T) { n.Propose(ctx, []byte("foo")) w := raftpb.Snapshot{ - Term: 1, - Index: 2, // one nop + one proposal - Data: []byte("a snapshot"), - Nodes: []int64{1}, + Term: 1, + Index: 2, // one nop + one proposal + Data: []byte("a snapshot"), + Nodes: []int64{1}, + RemovedNodes: []int64{}, } pkg.ForceGosched() diff --git a/raft/raft.go b/raft/raft.go index 2917385f6..53bd065fa 100644 --- a/raft/raft.go +++ b/raft/raft.go @@ -523,7 +523,10 @@ func (r *raft) compact(index int64, nodes []int64, d []byte) { if index > r.raftLog.applied { panic(fmt.Sprintf("raft: compact index (%d) exceeds applied index (%d)", index, r.raftLog.applied)) } - r.raftLog.snap(d, index, r.raftLog.term(index), nodes) + // We do not get the removed nodes at the given index. + // We get the removed nodes at current index. So a state machine might + // have a newer verison of removed nodes after recovery. It is OK. + r.raftLog.snap(d, index, r.raftLog.term(index), nodes, r.removedNodes()) r.raftLog.compact(index) } @@ -543,6 +546,10 @@ func (r *raft) restore(s pb.Snapshot) bool { r.setProgress(n, 0, r.raftLog.lastIndex()+1) } } + r.removed = make(map[int64]bool) + for _, n := range s.RemovedNodes { + r.removed[n] = true + } return true } @@ -564,6 +571,14 @@ func (r *raft) nodes() []int64 { return nodes } +func (r *raft) removedNodes() []int64 { + removed := make([]int64, 0, len(r.removed)) + for k := range r.removed { + removed = append(removed, k) + } + return removed +} + func (r *raft) setProgress(id, match, next int64) { r.prs[id] = &progress{next: next, match: match} } diff --git a/raft/raft_test.go b/raft/raft_test.go index a9dfbad89..ef1012116 100644 --- a/raft/raft_test.go +++ b/raft/raft_test.go @@ -413,12 +413,13 @@ func TestCompact(t *testing.T) { tests := []struct { compacti int64 nodes []int64 + removed []int64 snapd []byte wpanic bool }{ - {1, []int64{1, 2, 3}, []byte("some data"), false}, - {2, []int64{1, 2, 3}, []byte("some data"), false}, - {4, []int64{1, 2, 3}, []byte("some data"), true}, // compact out of range + {1, []int64{1, 2, 3}, []int64{4, 5}, []byte("some data"), false}, + {2, []int64{1, 2, 3}, []int64{4, 5}, []byte("some data"), false}, + {4, []int64{1, 2, 3}, []int64{4, 5}, []byte("some data"), true}, // compact out of range } for i, tt := range tests { @@ -426,7 +427,7 @@ func TestCompact(t *testing.T) { defer func() { if r := recover(); r != nil { if tt.wpanic != true { - t.Errorf("%d: panic = %v, want %v", i, false, true) + t.Errorf("%d: panic = %v, want %v", i, true, tt.wpanic) } } }() @@ -437,8 +438,14 @@ func TestCompact(t *testing.T) { applied: 2, ents: []pb.Entry{{}, {Term: 1}, {Term: 1}, {Term: 1}}, }, + removed: make(map[int64]bool), + } + for _, r := range tt.removed { + sm.removeNode(r) } sm.compact(tt.compacti, tt.nodes, tt.snapd) + sort.Sort(int64Slice(sm.raftLog.snapshot.Nodes)) + sort.Sort(int64Slice(sm.raftLog.snapshot.RemovedNodes)) if sm.raftLog.offset != tt.compacti { t.Errorf("%d: log.offset = %d, want %d", i, sm.raftLog.offset, tt.compacti) } @@ -448,6 +455,9 @@ func TestCompact(t *testing.T) { if !reflect.DeepEqual(sm.raftLog.snapshot.Data, tt.snapd) { t.Errorf("%d: snap.data = %v, want %v", i, sm.raftLog.snapshot.Data, tt.snapd) } + if !reflect.DeepEqual(sm.raftLog.snapshot.RemovedNodes, tt.removed) { + t.Errorf("%d: snap.removedNodes = %v, want %v", i, sm.raftLog.snapshot.RemovedNodes, tt.removed) + } }() } } @@ -886,9 +896,10 @@ func TestRecvMsgBeat(t *testing.T) { func TestRestore(t *testing.T) { s := pb.Snapshot{ - Index: defaultCompactThreshold + 1, - Term: defaultCompactThreshold + 1, - Nodes: []int64{1, 2, 3}, + Index: defaultCompactThreshold + 1, + Term: defaultCompactThreshold + 1, + Nodes: []int64{1, 2, 3}, + RemovedNodes: []int64{4, 5}, } sm := newRaft(1, []int64{1, 2}, 10, 1) @@ -902,12 +913,15 @@ func TestRestore(t *testing.T) { if sm.raftLog.term(s.Index) != s.Term { t.Errorf("log.lastTerm = %d, want %d", sm.raftLog.term(s.Index), s.Term) } - sg := int64Slice(sm.nodes()) - sw := int64Slice(s.Nodes) - sort.Sort(sg) - sort.Sort(sw) - if !reflect.DeepEqual(sg, sw) { - t.Errorf("sm.Nodes = %+v, want %+v", sg, sw) + sg := sm.nodes() + srn := sm.removedNodes() + sort.Sort(int64Slice(sg)) + sort.Sort(int64Slice(srn)) + if !reflect.DeepEqual(sg, s.Nodes) { + t.Errorf("sm.Nodes = %+v, want %+v", sg, s.Nodes) + } + if !reflect.DeepEqual(s.RemovedNodes, srn) { + t.Errorf("sm.RemovedNodes = %+v, want %+v", s.RemovedNodes, srn) } if !reflect.DeepEqual(sm.raftLog.snapshot, s) { t.Errorf("snapshot = %+v, want %+v", sm.raftLog.snapshot, s) diff --git a/raft/raftpb/raft.pb.go b/raft/raftpb/raft.pb.go index b9f60c239..9b136ccf5 100644 --- a/raft/raftpb/raft.pb.go +++ b/raft/raftpb/raft.pb.go @@ -124,6 +124,7 @@ type Snapshot struct { Nodes []int64 `protobuf:"varint,2,rep,name=nodes" json:"nodes"` Index int64 `protobuf:"varint,3,req,name=index" json:"index"` Term int64 `protobuf:"varint,4,req,name=term" json:"term"` + RemovedNodes []int64 `protobuf:"varint,5,rep,name=removed_nodes" json:"removed_nodes"` XXX_unrecognized []byte `json:"-"` } @@ -430,6 +431,23 @@ func (m *Snapshot) Unmarshal(data []byte) error { break } } + case 5: + if wireType != 0 { + return code_google_com_p_gogoprotobuf_proto.ErrWrongType + } + var v int64 + for shift := uint(0); ; shift += 7 { + if index >= l { + return io.ErrUnexpectedEOF + } + b := data[index] + index++ + v |= (int64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.RemovedNodes = append(m.RemovedNodes, v) default: var sizeOfWire int for { @@ -894,6 +912,11 @@ func (m *Snapshot) Size() (n int) { } n += 1 + sovRaft(uint64(m.Index)) n += 1 + sovRaft(uint64(m.Term)) + if len(m.RemovedNodes) > 0 { + for _, e := range m.RemovedNodes { + n += 1 + sovRaft(uint64(e)) + } + } if m.XXX_unrecognized != nil { n += len(m.XXX_unrecognized) } @@ -1055,6 +1078,19 @@ func (m *Snapshot) MarshalTo(data []byte) (n int, err error) { data[i] = 0x20 i++ i = encodeVarintRaft(data, i, uint64(m.Term)) + if len(m.RemovedNodes) > 0 { + for _, num := range m.RemovedNodes { + data[i] = 0x28 + i++ + for num >= 1<<7 { + data[i] = uint8(uint64(num)&0x7f | 0x80) + num >>= 7 + i++ + } + data[i] = uint8(num) + i++ + } + } if m.XXX_unrecognized != nil { i += copy(data[i:], m.XXX_unrecognized) } diff --git a/raft/raftpb/raft.proto b/raft/raftpb/raft.proto index 4d0a8d1d9..83cf5cff4 100644 --- a/raft/raftpb/raft.proto +++ b/raft/raftpb/raft.proto @@ -25,10 +25,11 @@ message Entry { } message Snapshot { - required bytes data = 1 [(gogoproto.nullable) = false]; - repeated int64 nodes = 2 [(gogoproto.nullable) = false]; - required int64 index = 3 [(gogoproto.nullable) = false]; - required int64 term = 4 [(gogoproto.nullable) = false]; + required bytes data = 1 [(gogoproto.nullable) = false]; + repeated int64 nodes = 2 [(gogoproto.nullable) = false]; + required int64 index = 3 [(gogoproto.nullable) = false]; + required int64 term = 4 [(gogoproto.nullable) = false]; + repeated int64 removed_nodes = 5 [(gogoproto.nullable) = false]; } message Message {