diff --git a/contrib/raftexample/raft.go b/contrib/raftexample/raft.go index 901ddb0ff..b79145441 100644 --- a/contrib/raftexample/raft.go +++ b/contrib/raftexample/raft.go @@ -185,7 +185,7 @@ func (rc *raftNode) publishEntries(ents []raftpb.Entry) (<-chan struct{}, bool) var applyDoneC chan struct{} if len(data) > 0 { - applyDoneC := make(chan struct{}, 1) + applyDoneC = make(chan struct{}, 1) select { case rc.commitC <- &commit{data, applyDoneC}: case <-rc.stopc: diff --git a/contrib/raftexample/raftexample_test.go b/contrib/raftexample/raftexample_test.go index 9c11e7d61..c49d8d25c 100644 --- a/contrib/raftexample/raftexample_test.go +++ b/contrib/raftexample/raftexample_test.go @@ -27,12 +27,21 @@ import ( "go.etcd.io/etcd/raft/v3/raftpb" ) +func getSnapshotFn() (func() ([]byte, error), <-chan struct{}) { + snapshotTriggeredC := make(chan struct{}) + return func() ([]byte, error) { + snapshotTriggeredC <- struct{}{} + return nil, nil + }, snapshotTriggeredC +} + type cluster struct { - peers []string - commitC []<-chan *commit - errorC []<-chan error - proposeC []chan string - confChangeC []chan raftpb.ConfChange + peers []string + commitC []<-chan *commit + errorC []<-chan error + proposeC []chan string + confChangeC []chan raftpb.ConfChange + snapshotTriggeredC []<-chan struct{} } // newCluster creates a cluster of n nodes @@ -43,11 +52,12 @@ func newCluster(n int) *cluster { } clus := &cluster{ - peers: peers, - commitC: make([]<-chan *commit, len(peers)), - errorC: make([]<-chan error, len(peers)), - proposeC: make([]chan string, len(peers)), - confChangeC: make([]chan raftpb.ConfChange, len(peers)), + peers: peers, + commitC: make([]<-chan *commit, len(peers)), + errorC: make([]<-chan error, len(peers)), + proposeC: make([]chan string, len(peers)), + confChangeC: make([]chan raftpb.ConfChange, len(peers)), + snapshotTriggeredC: make([]<-chan struct{}, len(peers)), } for i := range clus.peers { @@ -55,7 +65,9 @@ func newCluster(n int) *cluster { os.RemoveAll(fmt.Sprintf("raftexample-%d-snap", i+1)) clus.proposeC[i] = make(chan string, 1) clus.confChangeC[i] = make(chan raftpb.ConfChange, 1) - clus.commitC[i], clus.errorC[i], _ = newRaftNode(i+1, clus.peers, false, nil, clus.proposeC[i], clus.confChangeC[i]) + fn, snapshotTriggeredC := getSnapshotFn() + clus.snapshotTriggeredC[i] = snapshotTriggeredC + clus.commitC[i], clus.errorC[i], _ = newRaftNode(i+1, clus.peers, false, fn, clus.proposeC[i], clus.confChangeC[i]) } return clus @@ -64,10 +76,12 @@ func newCluster(n int) *cluster { // Close closes all cluster nodes and returns an error if any failed. func (clus *cluster) Close() (err error) { for i := range clus.peers { + go func(i int) { + for range clus.commitC[i] { + // drain pending commits + } + }(i) close(clus.proposeC[i]) - for range clus.commitC[i] { - // drain pending commits - } // wait for channel to close if erri := <-clus.errorC[i]; erri != nil { err = erri @@ -241,3 +255,31 @@ func TestAddNewNode(t *testing.T) { t.Fatalf("Commit failed") } } + +func TestSnapshot(t *testing.T) { + prevDefaultSnapshotCount := defaultSnapshotCount + prevSnapshotCatchUpEntriesN := snapshotCatchUpEntriesN + defaultSnapshotCount = 4 + snapshotCatchUpEntriesN = 4 + defer func() { + defaultSnapshotCount = prevDefaultSnapshotCount + snapshotCatchUpEntriesN = prevSnapshotCatchUpEntriesN + }() + + clus := newCluster(3) + defer clus.closeNoErrors(t) + + go func() { + clus.proposeC[0] <- "foo" + }() + + c := <-clus.commitC[0] + + select { + case <-clus.snapshotTriggeredC[0]: + t.Fatalf("snapshot triggered before applying done") + default: + } + close(c.applyDoneC) + <-clus.snapshotTriggeredC[0] +}