diff --git a/rafthttp/http_test.go b/rafthttp/http_test.go index e276f1110..e7f6e034f 100644 --- a/rafthttp/http_test.go +++ b/rafthttp/http_test.go @@ -359,6 +359,7 @@ type fakePeer struct { snapMsgs []snap.Message peerURLs types.URLs connc chan *outgoingConn + paused bool } func newFakePeer() *fakePeer { @@ -369,9 +370,23 @@ func newFakePeer() *fakePeer { } } -func (pr *fakePeer) send(m raftpb.Message) { pr.msgs = append(pr.msgs, m) } -func (pr *fakePeer) sendSnap(m snap.Message) { pr.snapMsgs = append(pr.snapMsgs, m) } +func (pr *fakePeer) send(m raftpb.Message) { + if pr.paused { + return + } + pr.msgs = append(pr.msgs, m) +} + +func (pr *fakePeer) sendSnap(m snap.Message) { + if pr.paused { + return + } + pr.snapMsgs = append(pr.snapMsgs, m) +} + func (pr *fakePeer) update(urls types.URLs) { pr.peerURLs = urls } func (pr *fakePeer) attachOutgoingConn(conn *outgoingConn) { pr.connc <- conn } func (pr *fakePeer) activeSince() time.Time { return time.Time{} } func (pr *fakePeer) stop() {} +func (pr *fakePeer) Pause() { pr.paused = true } +func (pr *fakePeer) Resume() { pr.paused = false } diff --git a/rafthttp/remote.go b/rafthttp/remote.go index f83f4ab82..c62c81823 100644 --- a/rafthttp/remote.go +++ b/rafthttp/remote.go @@ -59,3 +59,11 @@ func (g *remote) send(m raftpb.Message) { func (g *remote) stop() { g.pipeline.stop() } + +func (g *remote) Pause() { + g.stop() +} + +func (g *remote) Resume() { + g.pipeline.start() +} diff --git a/rafthttp/transport.go b/rafthttp/transport.go index a7692d3b9..1f0b46836 100644 --- a/rafthttp/transport.go +++ b/rafthttp/transport.go @@ -206,6 +206,36 @@ func (t *Transport) Stop() { t.remotes = nil } +// CutPeer drops messages to the specified peer. +func (t *Transport) CutPeer(id types.ID) { + t.mu.RLock() + p, pok := t.peers[id] + g, gok := t.remotes[id] + t.mu.RUnlock() + + if pok { + p.(Pausable).Pause() + } + if gok { + g.Pause() + } +} + +// MendPeer recovers the message dropping behavior of the given peer. +func (t *Transport) MendPeer(id types.ID) { + t.mu.RLock() + p, pok := t.peers[id] + g, gok := t.remotes[id] + t.mu.RUnlock() + + if pok { + p.(Pausable).Resume() + } + if gok { + g.Resume() + } +} + func (t *Transport) AddRemote(id types.ID, us []string) { t.mu.Lock() defer t.mu.Unlock() diff --git a/rafthttp/transport_test.go b/rafthttp/transport_test.go index 4a173f9d4..c998a44b2 100644 --- a/rafthttp/transport_test.go +++ b/rafthttp/transport_test.go @@ -66,6 +66,37 @@ func TestTransportSend(t *testing.T) { } } +func TestTransportCutMend(t *testing.T) { + ss := &stats.ServerStats{} + ss.Initialize() + peer1 := newFakePeer() + peer2 := newFakePeer() + tr := &Transport{ + ServerStats: ss, + peers: map[types.ID]Peer{types.ID(1): peer1, types.ID(2): peer2}, + } + + tr.CutPeer(types.ID(1)) + + wmsgsTo := []raftpb.Message{ + // good message + {Type: raftpb.MsgProp, To: 1}, + {Type: raftpb.MsgApp, To: 1}, + } + + tr.Send(wmsgsTo) + if len(peer1.msgs) > 0 { + t.Fatalf("msgs expected to be ignored, got %+v", peer1.msgs) + } + + tr.MendPeer(types.ID(1)) + + tr.Send(wmsgsTo) + if !reflect.DeepEqual(peer1.msgs, wmsgsTo) { + t.Errorf("msgs to peer 1 = %+v, want %+v", peer1.msgs, wmsgsTo) + } +} + func TestTransportAdd(t *testing.T) { ls := stats.NewLeaderStats("") tr := &Transport{