diff --git a/etcdserver/server.go b/etcdserver/server.go index 80f205efe..0506cc532 100644 --- a/etcdserver/server.go +++ b/etcdserver/server.go @@ -345,7 +345,16 @@ func NewServer(cfg *ServerConfig) (*EtcdServer, error) { } // TODO: move transport initialization near the definition of remote - tr := rafthttp.NewTransporter(cfg.Transport, id, cl.ID(), srv, srv.errorc, sstats, lstats) + tr := &rafthttp.Transport{ + RoundTripper: cfg.Transport, + ID: id, + ClusterID: cl.ID(), + Raft: srv, + ServerStats: sstats, + LeaderStats: lstats, + ErrorC: srv.errorc, + } + tr.Start() // add all remotes into transport for _, m := range remotes { if m.ID != id { diff --git a/etcdserver/server_test.go b/etcdserver/server_test.go index 14623d2ad..b51a15e6a 100644 --- a/etcdserver/server_test.go +++ b/etcdserver/server_test.go @@ -1468,6 +1468,7 @@ func (n *readyNode) Ready() <-chan raft.Ready { return n.readyc } type nopTransporter struct{} +func (s *nopTransporter) Start() {} func (s *nopTransporter) Handler() http.Handler { return nil } func (s *nopTransporter) Send(m []raftpb.Message) {} func (s *nopTransporter) AddRemote(id types.ID, us []string) {} diff --git a/rafthttp/functional_test.go b/rafthttp/functional_test.go index 8e2556908..b665857c0 100644 --- a/rafthttp/functional_test.go +++ b/rafthttp/functional_test.go @@ -30,14 +30,30 @@ import ( func TestSendMessage(t *testing.T) { // member 1 - tr := NewTransporter(&http.Transport{}, types.ID(1), types.ID(1), &fakeRaft{}, nil, newServerStats(), stats.NewLeaderStats("1")) + tr := &Transport{ + RoundTripper: &http.Transport{}, + ID: types.ID(1), + ClusterID: types.ID(1), + Raft: &fakeRaft{}, + ServerStats: newServerStats(), + LeaderStats: stats.NewLeaderStats("1"), + } + tr.Start() srv := httptest.NewServer(tr.Handler()) defer srv.Close() // member 2 recvc := make(chan raftpb.Message, 1) p := &fakeRaft{recvc: recvc} - tr2 := NewTransporter(&http.Transport{}, types.ID(2), types.ID(1), p, nil, newServerStats(), stats.NewLeaderStats("2")) + tr2 := &Transport{ + RoundTripper: &http.Transport{}, + ID: types.ID(2), + ClusterID: types.ID(1), + Raft: p, + ServerStats: newServerStats(), + LeaderStats: stats.NewLeaderStats("2"), + } + tr2.Start() srv2 := httptest.NewServer(tr2.Handler()) defer srv2.Close() @@ -45,7 +61,7 @@ func TestSendMessage(t *testing.T) { defer tr.Stop() tr2.AddPeer(types.ID(1), []string{srv.URL}) defer tr2.Stop() - if !waitStreamWorking(tr.(*transport).Get(types.ID(2)).(*peer)) { + if !waitStreamWorking(tr.Get(types.ID(2)).(*peer)) { t.Fatalf("stream from 1 to 2 is not in work as expected") } @@ -75,14 +91,30 @@ func TestSendMessage(t *testing.T) { // remote in a limited time when all underlying connections are broken. func TestSendMessageWhenStreamIsBroken(t *testing.T) { // member 1 - tr := NewTransporter(&http.Transport{}, types.ID(1), types.ID(1), &fakeRaft{}, nil, newServerStats(), stats.NewLeaderStats("1")) + tr := &Transport{ + RoundTripper: &http.Transport{}, + ID: types.ID(1), + ClusterID: types.ID(1), + Raft: &fakeRaft{}, + ServerStats: newServerStats(), + LeaderStats: stats.NewLeaderStats("1"), + } + tr.Start() srv := httptest.NewServer(tr.Handler()) defer srv.Close() // member 2 recvc := make(chan raftpb.Message, 1) p := &fakeRaft{recvc: recvc} - tr2 := NewTransporter(&http.Transport{}, types.ID(2), types.ID(1), p, nil, newServerStats(), stats.NewLeaderStats("2")) + tr2 := &Transport{ + RoundTripper: &http.Transport{}, + ID: types.ID(2), + ClusterID: types.ID(1), + Raft: p, + ServerStats: newServerStats(), + LeaderStats: stats.NewLeaderStats("2"), + } + tr2.Start() srv2 := httptest.NewServer(tr2.Handler()) defer srv2.Close() @@ -90,7 +122,7 @@ func TestSendMessageWhenStreamIsBroken(t *testing.T) { defer tr.Stop() tr2.AddPeer(types.ID(1), []string{srv.URL}) defer tr2.Stop() - if !waitStreamWorking(tr.(*transport).Get(types.ID(2)).(*peer)) { + if !waitStreamWorking(tr.Get(types.ID(2)).(*peer)) { t.Fatalf("stream from 1 to 2 is not in work as expected") } diff --git a/rafthttp/transport.go b/rafthttp/transport.go index 32ac5f764..653611cb3 100644 --- a/rafthttp/transport.go +++ b/rafthttp/transport.go @@ -38,6 +38,9 @@ type Raft interface { } type Transporter interface { + // Start starts the given Transporter. + // Start MUST be called before calling other functions in the interface. + Start() // Handler returns the HTTP handler of the transporter. // A transporter HTTP handler handles the HTTP requests // from remote peers. @@ -78,13 +81,26 @@ type Transporter interface { Stop() } -type transport struct { - roundTripper http.RoundTripper - id types.ID - clusterID types.ID - raft Raft - serverStats *stats.ServerStats - leaderStats *stats.LeaderStats +// Transport implements Transporter interface. It provides the functionality +// to send raft messages to peers, and receive raft messages from peers. +// User should call Handler method to get a handler to serve requests +// received from peerURLs. +// User needs to call Start before calling other functions, and call +// Stop when the Transport is no longer used. +type Transport struct { + RoundTripper http.RoundTripper // roundTripper to send requests + ID types.ID // local member ID + ClusterID types.ID // raft cluster ID for request validation + Raft Raft // raft state machine, to which the Transport forwards received messages and reports status + ServerStats *stats.ServerStats // used to record general transportation statistics + // used to record transportation statistics with followers when + // performing as leader in raft protocol + LeaderStats *stats.LeaderStats + // error channel used to report detected critical error, e.g., + // the member has been permanently removed from the cluster + // When an error is received from ErrorC, user should stop raft state + // machine and thus stop the Transport. + ErrorC chan error mu sync.RWMutex // protect the term, remote and peer map term uint64 // the latest term that has been observed @@ -92,28 +108,17 @@ type transport struct { peers map[types.ID]Peer // peers map prober probing.Prober - errorc chan error } -func NewTransporter(rt http.RoundTripper, id, cid types.ID, r Raft, errorc chan error, ss *stats.ServerStats, ls *stats.LeaderStats) Transporter { - return &transport{ - roundTripper: rt, - id: id, - clusterID: cid, - raft: r, - serverStats: ss, - leaderStats: ls, - remotes: make(map[types.ID]*remote), - peers: make(map[types.ID]Peer), - - prober: probing.NewProber(rt), - errorc: errorc, - } +func (t *Transport) Start() { + t.remotes = make(map[types.ID]*remote) + t.peers = make(map[types.ID]Peer) + t.prober = probing.NewProber(t.RoundTripper) } -func (t *transport) Handler() http.Handler { - pipelineHandler := NewHandler(t.raft, t.clusterID) - streamHandler := newStreamHandler(t, t.raft, t.id, t.clusterID) +func (t *Transport) Handler() http.Handler { + pipelineHandler := NewHandler(t.Raft, t.ClusterID) + streamHandler := newStreamHandler(t, t.Raft, t.ID, t.ClusterID) mux := http.NewServeMux() mux.Handle(RaftPrefix, pipelineHandler) mux.Handle(RaftStreamPrefix+"/", streamHandler) @@ -121,13 +126,13 @@ func (t *transport) Handler() http.Handler { return mux } -func (t *transport) Get(id types.ID) Peer { +func (t *Transport) Get(id types.ID) Peer { t.mu.RLock() defer t.mu.RUnlock() return t.peers[id] } -func (t *transport) maybeUpdatePeersTerm(term uint64) { +func (t *Transport) maybeUpdatePeersTerm(term uint64) { t.mu.Lock() defer t.mu.Unlock() if t.term >= term { @@ -139,7 +144,7 @@ func (t *transport) maybeUpdatePeersTerm(term uint64) { } } -func (t *transport) Send(msgs []raftpb.Message) { +func (t *Transport) Send(msgs []raftpb.Message) { for _, m := range msgs { // intentionally dropped message if m.To == 0 { @@ -154,7 +159,7 @@ func (t *transport) Send(msgs []raftpb.Message) { p, ok := t.peers[to] if ok { if m.Type == raftpb.MsgApp { - t.serverStats.SendAppendReq(m.Size()) + t.ServerStats.SendAppendReq(m.Size()) } p.Send(m) continue @@ -170,7 +175,7 @@ func (t *transport) Send(msgs []raftpb.Message) { } } -func (t *transport) Stop() { +func (t *Transport) Stop() { for _, r := range t.remotes { r.Stop() } @@ -178,12 +183,12 @@ func (t *transport) Stop() { p.Stop() } t.prober.RemoveAll() - if tr, ok := t.roundTripper.(*http.Transport); ok { + if tr, ok := t.RoundTripper.(*http.Transport); ok { tr.CloseIdleConnections() } } -func (t *transport) AddRemote(id types.ID, us []string) { +func (t *Transport) AddRemote(id types.ID, us []string) { t.mu.Lock() defer t.mu.Unlock() if _, ok := t.remotes[id]; ok { @@ -193,10 +198,10 @@ func (t *transport) AddRemote(id types.ID, us []string) { if err != nil { plog.Panicf("newURLs %+v should never fail: %+v", us, err) } - t.remotes[id] = startRemote(t.roundTripper, urls, t.id, id, t.clusterID, t.raft, t.errorc) + t.remotes[id] = startRemote(t.RoundTripper, urls, t.ID, id, t.ClusterID, t.Raft, t.ErrorC) } -func (t *transport) AddPeer(id types.ID, us []string) { +func (t *Transport) AddPeer(id types.ID, us []string) { t.mu.Lock() defer t.mu.Unlock() if _, ok := t.peers[id]; ok { @@ -206,18 +211,18 @@ func (t *transport) AddPeer(id types.ID, us []string) { if err != nil { plog.Panicf("newURLs %+v should never fail: %+v", us, err) } - fs := t.leaderStats.Follower(id.String()) - t.peers[id] = startPeer(t.roundTripper, urls, t.id, id, t.clusterID, t.raft, fs, t.errorc, t.term) + fs := t.LeaderStats.Follower(id.String()) + t.peers[id] = startPeer(t.RoundTripper, urls, t.ID, id, t.ClusterID, t.Raft, fs, t.ErrorC, t.term) addPeerToProber(t.prober, id.String(), us) } -func (t *transport) RemovePeer(id types.ID) { +func (t *Transport) RemovePeer(id types.ID) { t.mu.Lock() defer t.mu.Unlock() t.removePeer(id) } -func (t *transport) RemoveAllPeers() { +func (t *Transport) RemoveAllPeers() { t.mu.Lock() defer t.mu.Unlock() for id := range t.peers { @@ -226,18 +231,18 @@ func (t *transport) RemoveAllPeers() { } // the caller of this function must have the peers mutex. -func (t *transport) removePeer(id types.ID) { +func (t *Transport) removePeer(id types.ID) { if peer, ok := t.peers[id]; ok { peer.Stop() } else { plog.Panicf("unexpected removal of unknown peer '%d'", id) } delete(t.peers, id) - delete(t.leaderStats.Followers, id.String()) + delete(t.LeaderStats.Followers, id.String()) t.prober.Remove(id.String()) } -func (t *transport) UpdatePeer(id types.ID, us []string) { +func (t *Transport) UpdatePeer(id types.ID, us []string) { t.mu.Lock() defer t.mu.Unlock() // TODO: return error or just panic? @@ -254,7 +259,7 @@ func (t *transport) UpdatePeer(id types.ID, us []string) { addPeerToProber(t.prober, id.String(), us) } -func (t *transport) ActiveSince(id types.ID) time.Time { +func (t *Transport) ActiveSince(id types.ID) time.Time { t.mu.Lock() defer t.mu.Unlock() if p, ok := t.peers[id]; ok { @@ -269,13 +274,13 @@ type Pausable interface { } // for testing -func (t *transport) Pause() { +func (t *Transport) Pause() { for _, p := range t.peers { p.(Pausable).Pause() } } -func (t *transport) Resume() { +func (t *Transport) Resume() { for _, p := range t.peers { p.(Pausable).Resume() } diff --git a/rafthttp/transport_bench_test.go b/rafthttp/transport_bench_test.go index 7c247499e..1b9bbe9ed 100644 --- a/rafthttp/transport_bench_test.go +++ b/rafthttp/transport_bench_test.go @@ -30,13 +30,29 @@ import ( func BenchmarkSendingMsgApp(b *testing.B) { // member 1 - tr := NewTransporter(&http.Transport{}, types.ID(1), types.ID(1), &fakeRaft{}, nil, newServerStats(), stats.NewLeaderStats("1")) + tr := &Transport{ + RoundTripper: &http.Transport{}, + ID: types.ID(1), + ClusterID: types.ID(1), + Raft: &fakeRaft{}, + ServerStats: newServerStats(), + LeaderStats: stats.NewLeaderStats("1"), + } + tr.Start() srv := httptest.NewServer(tr.Handler()) defer srv.Close() // member 2 r := &countRaft{} - tr2 := NewTransporter(&http.Transport{}, types.ID(2), types.ID(1), r, nil, newServerStats(), stats.NewLeaderStats("2")) + tr2 := &Transport{ + RoundTripper: &http.Transport{}, + ID: types.ID(2), + ClusterID: types.ID(1), + Raft: r, + ServerStats: newServerStats(), + LeaderStats: stats.NewLeaderStats("2"), + } + tr2.Start() srv2 := httptest.NewServer(tr2.Handler()) defer srv2.Close() @@ -44,7 +60,7 @@ func BenchmarkSendingMsgApp(b *testing.B) { defer tr.Stop() tr2.AddPeer(types.ID(1), []string{srv.URL}) defer tr2.Stop() - if !waitStreamWorking(tr.(*transport).Get(types.ID(2)).(*peer)) { + if !waitStreamWorking(tr.Get(types.ID(2)).(*peer)) { b.Fatalf("stream from 1 to 2 is not in work as expected") } diff --git a/rafthttp/transport_test.go b/rafthttp/transport_test.go index 4a8dc4cf7..9013100c8 100644 --- a/rafthttp/transport_test.go +++ b/rafthttp/transport_test.go @@ -34,8 +34,8 @@ func TestTransportSend(t *testing.T) { ss.Initialize() peer1 := newFakePeer() peer2 := newFakePeer() - tr := &transport{ - serverStats: ss, + tr := &Transport{ + ServerStats: ss, peers: map[types.ID]Peer{types.ID(1): peer1, types.ID(2): peer2}, } wmsgsIgnored := []raftpb.Message{ @@ -69,9 +69,9 @@ func TestTransportSend(t *testing.T) { func TestTransportAdd(t *testing.T) { ls := stats.NewLeaderStats("") term := uint64(10) - tr := &transport{ - roundTripper: &roundTripperRecorder{}, - leaderStats: ls, + tr := &Transport{ + RoundTripper: &roundTripperRecorder{}, + LeaderStats: ls, term: term, peers: make(map[types.ID]Peer), prober: probing.NewProber(nil), @@ -102,9 +102,9 @@ func TestTransportAdd(t *testing.T) { } func TestTransportRemove(t *testing.T) { - tr := &transport{ - roundTripper: &roundTripperRecorder{}, - leaderStats: stats.NewLeaderStats(""), + tr := &Transport{ + RoundTripper: &roundTripperRecorder{}, + LeaderStats: stats.NewLeaderStats(""), peers: make(map[types.ID]Peer), prober: probing.NewProber(nil), } @@ -119,7 +119,7 @@ func TestTransportRemove(t *testing.T) { func TestTransportUpdate(t *testing.T) { peer := newFakePeer() - tr := &transport{ + tr := &Transport{ peers: map[types.ID]Peer{types.ID(1): peer}, prober: probing.NewProber(nil), } @@ -133,12 +133,12 @@ func TestTransportUpdate(t *testing.T) { func TestTransportErrorc(t *testing.T) { errorc := make(chan error, 1) - tr := &transport{ - roundTripper: newRespRoundTripper(http.StatusForbidden, nil), - leaderStats: stats.NewLeaderStats(""), + tr := &Transport{ + RoundTripper: newRespRoundTripper(http.StatusForbidden, nil), + LeaderStats: stats.NewLeaderStats(""), + ErrorC: errorc, peers: make(map[types.ID]Peer), prober: probing.NewProber(nil), - errorc: errorc, } tr.AddPeer(1, []string{"http://localhost:2380"}) defer tr.Stop()