diff --git a/tests/functional/simple_snapshot_test.go b/tests/functional/simple_snapshot_test.go index d03945bcf..dd2103dbe 100644 --- a/tests/functional/simple_snapshot_test.go +++ b/tests/functional/simple_snapshot_test.go @@ -11,7 +11,7 @@ import ( ) // This test creates a single node and then set a value to it to trigger snapshot -func TestSimpleSnapshot(t *testing.T) { +func TestSnapshot(t *testing.T) { procAttr := new(os.ProcAttr) procAttr.Files = []*os.File{nil, os.Stdout, os.Stderr} args := []string{"etcd", "-name=node1", "-data-dir=/tmp/node1", "-snapshot=true", "-snapshot-count=500"} @@ -93,3 +93,57 @@ func TestSimpleSnapshot(t *testing.T) { t.Fatal("wrong name of snapshot :", snapshots[0].Name()) } } + +// TestSnapshotRestart tests etcd restarts with snapshot file +func TestSnapshotRestart(t *testing.T) { + procAttr := new(os.ProcAttr) + procAttr.Files = []*os.File{nil, os.Stdout, os.Stderr} + args := []string{"etcd", "-name=node1", "-data-dir=/tmp/node1", "-snapshot=true", "-snapshot-count=500"} + + process, err := os.StartProcess(EtcdBinPath, append(args, "-f"), procAttr) + if err != nil { + t.Fatal("start process failed:" + err.Error()) + } + + time.Sleep(time.Second) + + c := etcd.NewClient(nil) + + c.SyncCluster() + // issue first 501 commands + for i := 0; i < 501; i++ { + result, err := c.Set("foo", "bar", 100) + node := result.Node + + if err != nil || node.Key != "/foo" || node.Value != "bar" || node.TTL < 95 { + if err != nil { + t.Fatal(err) + } + + t.Fatalf("Set failed with %s %s %v", node.Key, node.Value, node.TTL) + } + } + + // wait for a snapshot interval + time.Sleep(3 * time.Second) + + _, err = ioutil.ReadDir("/tmp/node1/snapshot") + if err != nil { + t.Fatal("list snapshot failed:" + err.Error()) + } + + process.Kill() + + process, err = os.StartProcess(EtcdBinPath, args, procAttr) + if err != nil { + t.Fatal("start process failed:" + err.Error()) + } + defer process.Kill() + + time.Sleep(1 * time.Second) + + _, err = c.Set("foo", 
"bar", 100) + if err != nil { + t.Fatal(err) + } +} diff --git a/third_party/github.com/goraft/raft/http_transporter.go b/third_party/github.com/goraft/raft/http_transporter.go index 1ab06dd38..183254b23 100644 --- a/third_party/github.com/goraft/raft/http_transporter.go +++ b/third_party/github.com/goraft/raft/http_transporter.go @@ -7,6 +7,7 @@ import ( "net/http" "net/url" "path" + "time" ) // Parts from this transporter were heavily influenced by Peter Bougon's @@ -42,7 +43,7 @@ type HTTPMuxer interface { //------------------------------------------------------------------------------ // Creates a new HTTP transporter with the given path prefix. -func NewHTTPTransporter(prefix string) *HTTPTransporter { +func NewHTTPTransporter(prefix string, timeout time.Duration) *HTTPTransporter { t := &HTTPTransporter{ DisableKeepAlives: false, prefix: prefix, @@ -53,6 +54,7 @@ func NewHTTPTransporter(prefix string) *HTTPTransporter { Transport: &http.Transport{DisableKeepAlives: false}, } t.httpClient.Transport = t.Transport + t.Transport.ResponseHeaderTimeout = timeout return t } @@ -120,7 +122,6 @@ func (t *HTTPTransporter) SendAppendEntriesRequest(server Server, peer *Peer, re url := joinPath(peer.ConnectionString, t.AppendEntriesPath()) traceln(server.Name(), "POST", url) - t.Transport.ResponseHeaderTimeout = server.ElectionTimeout() httpResp, err := t.httpClient.Post(url, "application/protobuf", &b) if httpResp == nil || err != nil { traceln("transporter.ae.response.error:", err) diff --git a/third_party/github.com/goraft/raft/http_transporter_test.go b/third_party/github.com/goraft/raft/http_transporter_test.go index d406e8a71..9f68674ea 100644 --- a/third_party/github.com/goraft/raft/http_transporter_test.go +++ b/third_party/github.com/goraft/raft/http_transporter_test.go @@ -11,7 +11,7 @@ import ( // Ensure that we can start several servers and have them communicate. 
func TestHTTPTransporter(t *testing.T) { - transporter := NewHTTPTransporter("/raft") + transporter := NewHTTPTransporter("/raft", testElectionTimeout) transporter.DisableKeepAlives = true servers := []Server{} @@ -91,7 +91,7 @@ func runTestHttpServers(t *testing.T, servers *[]Server, transporter *HTTPTransp func BenchmarkSpeed(b *testing.B) { - transporter := NewHTTPTransporter("/raft") + transporter := NewHTTPTransporter("/raft", testElectionTimeout) transporter.DisableKeepAlives = true servers := []Server{} diff --git a/third_party/github.com/goraft/raft/log.go b/third_party/github.com/goraft/raft/log.go index bd4e4afea..6e87da455 100644 --- a/third_party/github.com/goraft/raft/log.go +++ b/third_party/github.com/goraft/raft/log.go @@ -27,6 +27,7 @@ type Log struct { mutex sync.RWMutex startIndex uint64 // the index before the first entry in the Log entries startTerm uint64 + initialized bool } // The results of the applying a log entry. @@ -147,7 +148,9 @@ func (l *Log) open(path string) error { if os.IsNotExist(err) { l.file, err = os.OpenFile(path, os.O_WRONLY|os.O_CREATE, 0600) debugln("log.open.create ", path) - + if err == nil { + l.initialized = true + } return err } return err @@ -187,6 +190,7 @@ func (l *Log) open(path string) error { readBytes += int64(n) } debugln("open.log.recovery number of log ", len(l.entries)) + l.initialized = true return nil } diff --git a/third_party/github.com/goraft/raft/peer.go b/third_party/github.com/goraft/raft/peer.go index 83ecc683d..28c25c48d 100644 --- a/third_party/github.com/goraft/raft/peer.go +++ b/third_party/github.com/goraft/raft/peer.go @@ -17,10 +17,10 @@ type Peer struct { Name string `json:"name"` ConnectionString string `json:"connectionString"` prevLogIndex uint64 - mutex sync.RWMutex stopChan chan bool heartbeatInterval time.Duration lastActivity time.Time + sync.RWMutex } //------------------------------------------------------------------------------ @@ -56,18 +56,24 @@ func (p *Peer) 
setHeartbeatInterval(duration time.Duration) { // Retrieves the previous log index. func (p *Peer) getPrevLogIndex() uint64 { - p.mutex.RLock() - defer p.mutex.RUnlock() + p.RLock() + defer p.RUnlock() return p.prevLogIndex } // Sets the previous log index. func (p *Peer) setPrevLogIndex(value uint64) { - p.mutex.Lock() - defer p.mutex.Unlock() + p.Lock() + defer p.Unlock() p.prevLogIndex = value } +func (p *Peer) setLastActivity(now time.Time) { + p.Lock() + defer p.Unlock() + p.lastActivity = now +} + //------------------------------------------------------------------------------ // // Methods @@ -93,6 +99,8 @@ func (p *Peer) stopHeartbeat(flush bool) { // LastActivity returns the last time any response was received from the peer. func (p *Peer) LastActivity() time.Time { + p.RLock() + defer p.RUnlock() return p.lastActivity } @@ -103,8 +111,8 @@ func (p *Peer) LastActivity() time.Time { // Clones the state of the peer. The clone is not attached to a server and // the heartbeat timer will not exist. func (p *Peer) clone() *Peer { - p.mutex.Lock() - defer p.mutex.Unlock() + p.Lock() + defer p.Unlock() return &Peer{ Name: p.Name, ConnectionString: p.ConnectionString, @@ -181,9 +189,9 @@ func (p *Peer) sendAppendEntriesRequest(req *AppendEntriesRequest) { } traceln("peer.append.resp: ", p.server.Name(), "<-", p.Name) + p.setLastActivity(time.Now()) // If successful then update the previous log index. 
- p.mutex.Lock() - p.lastActivity = time.Now() + p.Lock() if resp.Success() { if len(req.Entries) > 0 { p.prevLogIndex = req.Entries[len(req.Entries)-1].GetIndex() @@ -229,7 +237,7 @@ func (p *Peer) sendAppendEntriesRequest(req *AppendEntriesRequest) { debugln("peer.append.resp.decrement: ", p.Name, "; idx =", p.prevLogIndex) } } - p.mutex.Unlock() + p.Unlock() // Attach the peer to resp, thus server can know where it comes from resp.peer = p.Name @@ -251,7 +259,8 @@ func (p *Peer) sendSnapshotRequest(req *SnapshotRequest) { // If successful, the peer should have been to snapshot state // Send it the snapshot! - p.lastActivity = time.Now() + p.setLastActivity(time.Now()) + if resp.Success { p.sendSnapshotRecoveryRequest() } else { @@ -272,7 +281,7 @@ func (p *Peer) sendSnapshotRecoveryRequest() { return } - p.lastActivity = time.Now() + p.setLastActivity(time.Now()) if resp.Success { p.prevLogIndex = req.LastIndex } else { @@ -293,7 +302,7 @@ func (p *Peer) sendVoteRequest(req *RequestVoteRequest, c chan *RequestVoteRespo req.peer = p if resp := p.server.Transporter().SendVoteRequest(p.server, p, req); resp != nil { debugln("peer.vote.recv: ", p.server.Name(), "<-", p.Name) - p.lastActivity = time.Now() + p.setLastActivity(time.Now()) resp.peer = p c <- resp } else { diff --git a/third_party/github.com/goraft/raft/server.go b/third_party/github.com/goraft/raft/server.go index 3f9b653a4..5f29010af 100644 --- a/third_party/github.com/goraft/raft/server.go +++ b/third_party/github.com/goraft/raft/server.go @@ -358,8 +358,8 @@ func (s *server) promotable() bool { // Retrieves the number of member servers in the consensus. 
func (s *server) MemberCount() int { - s.mutex.Lock() - defer s.mutex.Unlock() + s.mutex.RLock() + defer s.mutex.RUnlock() return len(s.peers) + 1 } @@ -468,8 +468,10 @@ func (s *server) Init() error { return fmt.Errorf("raft.Server: Server already running[%v]", s.state) } - // server has been initialized or server was stopped after initialized - if s.state == Initialized || !s.log.isEmpty() { + // Server has been initialized or server was stopped after initialized + // If log has been initialized, we know that the server was stopped after + // running. + if s.state == Initialized || s.log.initialized { s.state = Initialized return nil } @@ -501,13 +503,17 @@ func (s *server) Init() error { // Shuts down the server. func (s *server) Stop() { + if s.State() == Stopped { + return + } + stop := make(chan bool) s.stopped <- stop - s.state = Stopped // make sure the server has stopped before we close the log <-stop s.log.close() + s.setState(Stopped) } // Checks if the server is currently running. @@ -527,8 +533,6 @@ func (s *server) updateCurrentTerm(term uint64, leaderName string) { _assert(term > s.currentTerm, "upadteCurrentTerm: update is called when term is not larger than currentTerm") - s.mutex.Lock() - defer s.mutex.Unlock() // Store previous values temporarily. prevTerm := s.currentTerm prevLeader := s.leader @@ -536,21 +540,20 @@ func (s *server) updateCurrentTerm(term uint64, leaderName string) { // set currentTerm = T, convert to follower (§5.1) // stop heartbeats before step-down if s.state == Leader { - s.mutex.Unlock() for _, peer := range s.peers { peer.stopHeartbeat(false) } - s.mutex.Lock() } // update the term and clear vote for if s.state != Follower { - s.mutex.Unlock() s.setState(Follower) - s.mutex.Lock() } + + s.mutex.Lock() s.currentTerm = term s.leader = leaderName s.votedFor = "" + s.mutex.Unlock() // Dispatch change events. 
s.DispatchEvent(newEvent(TermChangeEventType, s.currentTerm, prevTerm)) @@ -580,9 +583,9 @@ func (s *server) updateCurrentTerm(term uint64, leaderName string) { func (s *server) loop() { defer s.debugln("server.loop.end") - for s.state != Stopped { - state := s.State() + state := s.State() + for state != Stopped { s.debugln("server.loop.run ", state) switch state { case Follower: @@ -594,6 +597,7 @@ func (s *server) loop() { case Snapshotting: s.snapshotLoop() } + state = s.State() } } @@ -903,9 +907,9 @@ func (s *server) processAppendEntriesRequest(req *AppendEntriesRequest) (*Append } if req.Term == s.currentTerm { - _assert(s.state != Leader, "leader.elected.at.same.term.%d\n", s.currentTerm) + _assert(s.State() != Leader, "leader.elected.at.same.term.%d\n", s.currentTerm) // change state to follower - s.state = Follower + s.setState(Follower) // discover new leader when candidate // save leader name when follower s.leader = req.LeaderName diff --git a/third_party/github.com/goraft/raft/snapshot_test.go b/third_party/github.com/goraft/raft/snapshot_test.go index d650aa975..5d6eecb43 100644 --- a/third_party/github.com/goraft/raft/snapshot_test.go +++ b/third_party/github.com/goraft/raft/snapshot_test.go @@ -2,6 +2,7 @@ package raft import ( "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -26,11 +27,43 @@ func TestSnapshot(t *testing.T) { // Restart server. s.Stop() - s.Start() - // Recover from snapshot. 
err = s.LoadSnapshot() assert.NoError(t, err) + s.Start() + }) +} + +// Ensure that a new server can recover from previous snapshot with log +func TestSnapshotRecovery(t *testing.T) { + runServerWithMockStateMachine(Leader, func(s Server, m *mock.Mock) { + m.On("Save").Return([]byte("foo"), nil) + m.On("Recovery", []byte("foo")).Return(nil) + + s.Do(&testCommand1{}) + err := s.TakeSnapshot() + assert.NoError(t, err) + assert.Equal(t, s.(*server).snapshot.LastIndex, uint64(2)) + + // Repeat to make sure new snapshot gets created. + s.Do(&testCommand1{}) + + // Stop the old server + s.Stop() + + // create a new server with previous log and snapshot + newS, err := NewServer("1", s.Path(), &testTransporter{}, s.StateMachine(), nil, "") + // Recover from snapshot. + err = newS.LoadSnapshot() + assert.NoError(t, err) + + newS.Start() + defer newS.Stop() + + // wait for it to become leader + time.Sleep(time.Second) + // ensure server load the previous log + assert.Equal(t, len(newS.LogEntries()), 3, "") }) } diff --git a/third_party/github.com/goraft/raft/util.go b/third_party/github.com/goraft/raft/util.go index 5fa2c41a8..44a3efd31 100644 --- a/third_party/github.com/goraft/raft/util.go +++ b/third_party/github.com/goraft/raft/util.go @@ -25,15 +25,16 @@ func writeFileSynced(filename string, data []byte, perm os.FileMode) error { if err != nil { return err } + defer f.Close() // Idempotent n, err := f.Write(data) - if n < len(data) { - f.Close() + if err == nil && n < len(data) { return io.ErrShortWrite + } else if err != nil { + return err } - err = f.Sync() - if err != nil { + if err = f.Sync(); err != nil { return err }