diff --git a/etcd/etcd_test.go b/etcd/etcd_test.go index 71864b35b..ff5b512b1 100644 --- a/etcd/etcd_test.go +++ b/etcd/etcd_test.go @@ -25,6 +25,7 @@ import ( "net/http/httptest" "net/url" "os" + "reflect" "strings" "testing" "time" @@ -386,6 +387,57 @@ func TestTakingSnapshot(t *testing.T) { } } +func TestRestoreSnapshotFromLeader(t *testing.T) { + es, hs := buildCluster(1, false) + // let leader do snapshot + for i := 0; i < defaultCompact; i++ { + es[0].p.Set(fmt.Sprint("/foo", i), false, fmt.Sprint("bar", i), store.Permanent) + } + + // create one to join the cluster + c := config.New() + c.Peers = []string{hs[0].URL} + e, h := initTestServer(c, 1, false) + go e.Run() + waitMode(participantMode, e) + + // check new proposal could be submitted + if _, err := es[0].p.Set("/foo", false, "bar", store.Permanent); err != nil { + t.Fatal(err) + } + + // check store is recovered + for i := 0; i < defaultCompact; i++ { + ev, err := e.p.Get(fmt.Sprint("/foo", i), false, false) + if err != nil { + t.Errorf("get err = %v", err) + continue + } + w := fmt.Sprint("bar", i) + if g := *ev.Node.Value; g != w { + t.Errorf("value = %v, want %v", g, w) + } + } + + // check new proposal could be committed in the new machine + wch, err := e.p.Watch("/foo", false, false, defaultCompact) + if err != nil { + t.Errorf("watch err = %v", err) + } + <-wch.EventChan + + g := e.p.node.Nodes() + w := es[0].p.node.Nodes() + if !reflect.DeepEqual(g, w) { + t.Errorf("nodes = %v, want %v", g, w) + } + + e.Stop() + es[0].Stop() + h.Close() + hs[0].Close() +} + func buildCluster(number int, tls bool) ([]*Server, []*httptest.Server) { bootstrapper := 0 es := make([]*Server, number) diff --git a/etcd/participant.go b/etcd/participant.go index ce9c933f2..f0b30a771 100644 --- a/etcd/participant.go +++ b/etcd/participant.go @@ -205,6 +205,12 @@ func (p *participant) run() int64 { log.Printf("id=%x participant.stop\n", p.id) return stopMode } + if s := node.UnstableSnapshot(); !s.IsEmpty() { + if err := p.Recovery(s.Data); err != nil { + panic(err) + } + log.Printf("id=%x recovered index=%d\n", p.id, s.Index) + } p.apply(node.Next()) ents := node.UnstableEnts() p.save(ents, node.UnstableState())