From fe7cfe4d3ded5812bbcd5fe972646f106b2db201 Mon Sep 17 00:00:00 2001 From: Anthony Romano Date: Thu, 4 Feb 2016 22:05:59 -0800 Subject: [PATCH] rafthttp: plumb local peer URLs through transport --- etcdserver/server.go | 1 + rafthttp/http.go | 31 +++++++++++++++++++--------- rafthttp/http_test.go | 3 +-- rafthttp/peer.go | 16 ++++----------- rafthttp/pipeline.go | 11 +++++----- rafthttp/pipeline_test.go | 24 ++++++++++++++-------- rafthttp/remote.go | 4 +--- rafthttp/snapshot_sender.go | 10 +++++----- rafthttp/stream.go | 40 ++++++++++++++++--------------------- rafthttp/stream_test.go | 40 +++++++++++++++++-------------------- rafthttp/transport.go | 16 +++++++++------ rafthttp/transport_test.go | 2 +- rafthttp/util.go | 17 +++++++++++++++- 13 files changed, 117 insertions(+), 98 deletions(-) diff --git a/etcdserver/server.go b/etcdserver/server.go index ab737c167..6c1ccb6dc 100644 --- a/etcdserver/server.go +++ b/etcdserver/server.go @@ -374,6 +374,7 @@ func NewServer(cfg *ServerConfig) (*EtcdServer, error) { TLSInfo: cfg.PeerTLSInfo, DialTimeout: cfg.peerDialTimeout(), ID: id, + URLs: cfg.PeerURLs, ClusterID: cl.ID(), Raft: srv, Snapshotter: ss, diff --git a/rafthttp/http.go b/rafthttp/http.go index 2493d4a3f..35c265274 100644 --- a/rafthttp/http.go +++ b/rafthttp/http.go @@ -59,6 +59,7 @@ type writerToResponse interface { } type pipelineHandler struct { + tr Transporter r Raft cid types.ID } @@ -68,8 +69,9 @@ type pipelineHandler struct { // // The handler reads out the raft message from request body, // and forwards it to the given raft state machine for processing. -func newPipelineHandler(r Raft, cid types.ID) http.Handler { +func newPipelineHandler(tr Transporter, r Raft, cid types.ID) http.Handler { return &pipelineHandler{ + tr: tr, r: r, cid: cid, } @@ -89,6 +91,12 @@ func (h *pipelineHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + if from, err := types.IDFromString(r.Header.Get("X-Server-From")); err != nil { + if urls := r.Header.Get("X-PeerURLs"); urls != "" { + h.tr.AddRemote(from, strings.Split(urls, ",")) + } + } + // Limit the data size that could be read from the request body, which ensures that read from // connection will not time out accidentally due to possible blocking in underlying implementation. limitedr := pioutil.NewLimitedBufferReader(r.Body, connReadLimitByte) @@ -114,19 +122,22 @@ func (h *pipelineHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } return } + // Write StatusNoContent header after the message has been processed by // raft, which facilitates the client to report MsgSnap status. w.WriteHeader(http.StatusNoContent) } type snapshotHandler struct { + tr Transporter r Raft snapshotter *snap.Snapshotter cid types.ID } -func newSnapshotHandler(r Raft, snapshotter *snap.Snapshotter, cid types.ID) http.Handler { +func newSnapshotHandler(tr Transporter, r Raft, snapshotter *snap.Snapshotter, cid types.ID) http.Handler { return &snapshotHandler{ + tr: tr, r: r, snapshotter: snapshotter, cid: cid, @@ -156,6 +167,12 @@ func (h *snapshotHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + if from, err := types.IDFromString(r.Header.Get("X-Server-From")); err != nil { + if urls := r.Header.Get("X-PeerURLs"); urls != "" { + h.tr.AddRemote(from, strings.Split(urls, ",")) + } + } + dec := &messageDecoder{r: r.Body} m, err := dec.decode() if err != nil { @@ -256,19 +273,15 @@ func (h *streamHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } p := h.peerGetter.Get(from) - if p == nil { - if urls := r.Header.Get("X-Server-Peers"); urls != "" { - h.tr.AddPeer(from, strings.Split(urls, ",")) - } - p = h.peerGetter.Get(from) - } - if p == nil { // This may happen in following cases: // 1. user starts a remote peer that belongs to a different cluster // with the same cluster ID. // 2. local etcd falls behind of the cluster, and cannot recognize // the members that joined after its current progress. + if urls := r.Header.Get("X-PeerURLs"); urls != "" { + h.tr.AddRemote(from, strings.Split(urls, ",")) + } plog.Errorf("failed to find member %s in cluster %s", from, h.cid) http.Error(w, "error sender not found", http.StatusNotFound) return diff --git a/rafthttp/http_test.go b/rafthttp/http_test.go index 5d293afe2..262ab8afb 100644 --- a/rafthttp/http_test.go +++ b/rafthttp/http_test.go @@ -151,7 +151,7 @@ func TestServeRaftPrefix(t *testing.T) { req.Header.Set("X-Etcd-Cluster-ID", tt.clusterID) req.Header.Set("X-Server-Version", version.Version) rw := httptest.NewRecorder() - h := newPipelineHandler(tt.p, types.ID(0)) + h := newPipelineHandler(NewNopTransporter(), tt.p, types.ID(0)) h.ServeHTTP(rw, req) if rw.Code != tt.wcode { t.Errorf("#%d: got code=%d, want %d", i, rw.Code, tt.wcode) @@ -364,4 +364,3 @@ func (pr *fakePeer) update(urls types.URLs) { pr.peerURLs = urls func (pr *fakePeer) attachOutgoingConn(conn *outgoingConn) { pr.connc <- conn } func (pr *fakePeer) activeSince() time.Time { return time.Time{} } func (pr *fakePeer) stop() {} -func (pr *fakePeer) urls() types.URLs { return pr.peerURLs } diff --git a/rafthttp/peer.go b/rafthttp/peer.go index 0213594c3..f292d45de 100644 --- a/rafthttp/peer.go +++ b/rafthttp/peer.go @@ -65,9 +65,6 @@ type Peer interface { // update updates the urls of remote peer. update(urls types.URLs) - // urls retrieves the urls of the remote peer - urls() types.URLs - // attachOutgoingConn attaches the outgoing connection to the peer for // stream usage. After the call, the ownership of the outgoing // connection hands over to the peer. The peer will close the connection @@ -124,7 +121,6 @@ type peer struct { func startPeer(transport *Transport, urls types.URLs, local, to, cid types.ID, r Raft, fs *stats.FollowerStats, errorc chan error, v3demo bool) *peer { status := newPeerStatus(to) picker := newURLPicker(urls) - pipelineRt := transport.pipelineRt p := &peer{ id: to, r: r, @@ -133,8 +129,8 @@ func startPeer(transport *Transport, urls types.URLs, local, to, cid types.ID, r picker: picker, msgAppV2Writer: startStreamWriter(to, status, fs, r), writer: startStreamWriter(to, status, fs, r), - pipeline: newPipeline(pipelineRt, picker, local, to, cid, status, fs, r, errorc), - snapSender: newSnapshotSender(pipelineRt, picker, local, to, cid, status, r, errorc), + pipeline: newPipeline(transport, picker, local, to, cid, status, fs, r, errorc), + snapSender: newSnapshotSender(transport, picker, local, to, cid, status, r, errorc), sendc: make(chan raftpb.Message), recvc: make(chan raftpb.Message, recvBufSize), propc: make(chan raftpb.Message, maxPendingProposals), @@ -161,8 +157,8 @@ func startPeer(transport *Transport, urls types.URLs, local, to, cid types.ID, r } }() - p.msgAppV2Reader = startStreamReader(p, transport.streamRt, picker, streamTypeMsgAppV2, local, to, cid, status, p.recvc, p.propc, errorc) - reader := startStreamReader(p, transport.streamRt, picker, streamTypeMessage, local, to, cid, status, p.recvc, p.propc, errorc) + p.msgAppV2Reader = startStreamReader(transport, picker, streamTypeMsgAppV2, local, to, cid, status, p.recvc, p.propc, errorc) + reader := startStreamReader(transport, picker, streamTypeMessage, local, to, cid, status, p.recvc, p.propc, errorc) go func() { var paused bool for { @@ -229,10 +225,6 @@ func (p *peer) update(urls types.URLs) { } } -func (p *peer) urls() types.URLs { - return p.picker.urls -} - func (p *peer) attachOutgoingConn(conn *outgoingConn) { var ok bool switch conn.t { diff --git a/rafthttp/pipeline.go b/rafthttp/pipeline.go index 9eab7a6e5..b47dd5614 100644 --- a/rafthttp/pipeline.go +++ b/rafthttp/pipeline.go @@ -18,7 +18,6 @@ import ( "bytes" "errors" "io/ioutil" - "net/http" "sync" "time" @@ -45,7 +44,7 @@ type pipeline struct { from, to types.ID cid types.ID - tr http.RoundTripper + tr *Transport picker *urlPicker status *peerStatus fs *stats.FollowerStats @@ -58,7 +57,7 @@ type pipeline struct { stopc chan struct{} } -func newPipeline(tr http.RoundTripper, picker *urlPicker, from, to, cid types.ID, status *peerStatus, fs *stats.FollowerStats, r Raft, errorc chan error) *pipeline { +func newPipeline(tr *Transport, picker *urlPicker, from, to, cid types.ID, status *peerStatus, fs *stats.FollowerStats, r Raft, errorc chan error) *pipeline { p := &pipeline{ from: from, to: to, @@ -126,10 +125,10 @@ func (p *pipeline) handle() { // error on any failure. func (p *pipeline) post(data []byte) (err error) { u := p.picker.pick() - req := createPostRequest(u, RaftPrefix, bytes.NewBuffer(data), "application/protobuf", p.from, p.cid) + req := createPostRequest(u, RaftPrefix, bytes.NewBuffer(data), "application/protobuf", p.tr.URLs, p.from, p.cid) done := make(chan struct{}, 1) - cancel := httputil.RequestCanceler(p.tr, req) + cancel := httputil.RequestCanceler(p.tr.pipelineRt, req) go func() { select { case <-done: @@ -139,7 +138,7 @@ func (p *pipeline) post(data []byte) (err error) { } }() - resp, err := p.tr.RoundTrip(req) + resp, err := p.tr.pipelineRt.RoundTrip(req) done <- struct{}{} if err != nil { p.picker.unreachable(u) diff --git a/rafthttp/pipeline_test.go b/rafthttp/pipeline_test.go index 8bf909592..69f49db04 100644 --- a/rafthttp/pipeline_test.go +++ b/rafthttp/pipeline_test.go @@ -37,7 +37,8 @@ func TestPipelineSend(t *testing.T) { tr := &roundTripperRecorder{} picker := mustNewURLPicker(t, []string{"http://localhost:2380"}) fs := &stats.FollowerStats{} - p := newPipeline(tr, picker, types.ID(2), types.ID(1), types.ID(1), newPeerStatus(types.ID(1)), fs, &fakeRaft{}, nil) + tp := &Transport{pipelineRt: tr} + p := newPipeline(tp, picker, types.ID(2), types.ID(1), types.ID(1), newPeerStatus(types.ID(1)), fs, &fakeRaft{}, nil) p.msgc <- raftpb.Message{Type: raftpb.MsgApp} testutil.WaitSchedule() @@ -59,7 +60,8 @@ func TestPipelineKeepSendingWhenPostError(t *testing.T) { tr := &respRoundTripper{err: fmt.Errorf("roundtrip error")} picker := mustNewURLPicker(t, []string{"http://localhost:2380"}) fs := &stats.FollowerStats{} - p := newPipeline(tr, picker, types.ID(2), types.ID(1), types.ID(1), newPeerStatus(types.ID(1)), fs, &fakeRaft{}, nil) + tp := &Transport{pipelineRt: tr} + p := newPipeline(tp, picker, types.ID(2), types.ID(1), types.ID(1), newPeerStatus(types.ID(1)), fs, &fakeRaft{}, nil) for i := 0; i < 50; i++ { p.msgc <- raftpb.Message{Type: raftpb.MsgApp} @@ -79,7 +81,8 @@ func TestPipelineExceedMaximumServing(t *testing.T) { tr := newRoundTripperBlocker() picker := mustNewURLPicker(t, []string{"http://localhost:2380"}) fs := &stats.FollowerStats{} - p := newPipeline(tr, picker, types.ID(2), types.ID(1), types.ID(1), newPeerStatus(types.ID(1)), fs, &fakeRaft{}, nil) + tp := &Transport{pipelineRt: tr} + p := newPipeline(tp, picker, types.ID(2), types.ID(1), types.ID(1), newPeerStatus(types.ID(1)), fs, &fakeRaft{}, nil) // keep the sender busy and make the buffer full // nothing can go out as we block the sender @@ -119,7 +122,8 @@ func TestPipelineExceedMaximumServing(t *testing.T) { func TestPipelineSendFailed(t *testing.T) { picker := mustNewURLPicker(t, []string{"http://localhost:2380"}) fs := &stats.FollowerStats{} - p := newPipeline(newRespRoundTripper(0, errors.New("blah")), picker, types.ID(2), types.ID(1), types.ID(1), newPeerStatus(types.ID(1)), fs, &fakeRaft{}, nil) + tp := &Transport{pipelineRt: newRespRoundTripper(0, errors.New("blah"))} + p := newPipeline(tp, picker, types.ID(2), types.ID(1), types.ID(1), newPeerStatus(types.ID(1)), fs, &fakeRaft{}, nil) p.msgc <- raftpb.Message{Type: raftpb.MsgApp} testutil.WaitSchedule() @@ -135,7 +139,8 @@ func TestPipelineSendFailed(t *testing.T) { func TestPipelinePost(t *testing.T) { tr := &roundTripperRecorder{} picker := mustNewURLPicker(t, []string{"http://localhost:2380"}) - p := newPipeline(tr, picker, types.ID(2), types.ID(1), types.ID(1), newPeerStatus(types.ID(1)), nil, &fakeRaft{}, nil) + tp := &Transport{pipelineRt: tr} + p := newPipeline(tp, picker, types.ID(2), types.ID(1), types.ID(1), newPeerStatus(types.ID(1)), nil, &fakeRaft{}, nil) if err := p.post([]byte("some data")); err != nil { t.Fatalf("unexpected post error: %v", err) } @@ -182,7 +187,8 @@ func TestPipelinePostBad(t *testing.T) { } for i, tt := range tests { picker := mustNewURLPicker(t, []string{tt.u}) - p := newPipeline(newRespRoundTripper(tt.code, tt.err), picker, types.ID(2), types.ID(1), types.ID(1), newPeerStatus(types.ID(1)), nil, &fakeRaft{}, make(chan error)) + tp := &Transport{pipelineRt: newRespRoundTripper(tt.code, tt.err)} + p := newPipeline(tp, picker, types.ID(2), types.ID(1), types.ID(1), newPeerStatus(types.ID(1)), nil, &fakeRaft{}, make(chan error)) err := p.post([]byte("some data")) p.stop() @@ -203,7 +209,8 @@ func TestPipelinePostErrorc(t *testing.T) { for i, tt := range tests { picker := mustNewURLPicker(t, []string{tt.u}) errorc := make(chan error, 1) - p := newPipeline(newRespRoundTripper(tt.code, tt.err), picker, types.ID(2), types.ID(1), types.ID(1), newPeerStatus(types.ID(1)), nil, &fakeRaft{}, errorc) + tp := &Transport{pipelineRt: newRespRoundTripper(tt.code, tt.err)} + p := newPipeline(tp, picker, types.ID(2), types.ID(1), types.ID(1), newPeerStatus(types.ID(1)), nil, &fakeRaft{}, errorc) p.post([]byte("some data")) p.stop() select { @@ -216,7 +223,8 @@ func TestPipelinePostErrorc(t *testing.T) { func TestStopBlockedPipeline(t *testing.T) { picker := mustNewURLPicker(t, []string{"http://localhost:2380"}) - p := newPipeline(newRoundTripperBlocker(), picker, types.ID(2), types.ID(1), types.ID(1), newPeerStatus(types.ID(1)), nil, &fakeRaft{}, nil) + tp := &Transport{pipelineRt: newRoundTripperBlocker()} + p := newPipeline(tp, picker, types.ID(2), types.ID(1), types.ID(1), newPeerStatus(types.ID(1)), nil, &fakeRaft{}, nil) // send many messages that most of them will be blocked in buffer for i := 0; i < connPerPipeline*10; i++ { p.msgc <- raftpb.Message{} diff --git a/rafthttp/remote.go b/rafthttp/remote.go index 425e7aebd..a41bf2ac3 100644 --- a/rafthttp/remote.go +++ b/rafthttp/remote.go @@ -15,8 +15,6 @@ package rafthttp import ( - "net/http" - "github.com/coreos/etcd/pkg/types" "github.com/coreos/etcd/raft/raftpb" ) @@ -27,7 +25,7 @@ type remote struct { pipeline *pipeline } -func startRemote(tr http.RoundTripper, urls types.URLs, local, to, cid types.ID, r Raft, errorc chan error) *remote { +func startRemote(tr *Transport, urls types.URLs, local, to, cid types.ID, r Raft, errorc chan error) *remote { picker := newURLPicker(urls) status := newPeerStatus(to) return &remote{ diff --git a/rafthttp/snapshot_sender.go b/rafthttp/snapshot_sender.go index de27fd970..12e62d7c5 100644 --- a/rafthttp/snapshot_sender.go +++ b/rafthttp/snapshot_sender.go @@ -37,7 +37,7 @@ type snapshotSender struct { from, to types.ID cid types.ID - tr http.RoundTripper + tr *Transport picker *urlPicker status *peerStatus r Raft @@ -46,7 +46,7 @@ type snapshotSender struct { stopc chan struct{} } -func newSnapshotSender(tr http.RoundTripper, picker *urlPicker, from, to, cid types.ID, status *peerStatus, r Raft, errorc chan error) *snapshotSender { +func newSnapshotSender(tr *Transport, picker *urlPicker, from, to, cid types.ID, status *peerStatus, r Raft, errorc chan error) *snapshotSender { return &snapshotSender{ from: from, to: to, @@ -71,7 +71,7 @@ func (s *snapshotSender) send(merged snap.Message) { defer body.Close() u := s.picker.pick() - req := createPostRequest(u, RaftSnapshotPrefix, body, "application/octet-stream", s.from, s.cid) + req := createPostRequest(u, RaftSnapshotPrefix, body, "application/octet-stream", s.tr.URLs, s.from, s.cid) plog.Infof("start to send database snapshot [index: %d, to %s]...", m.Snapshot.Metadata.Index, types.ID(m.To)) @@ -105,7 +105,7 @@ func (s *snapshotSender) send(merged snap.Message) { // post posts the given request. // It returns nil when request is sent out and processed successfully. func (s *snapshotSender) post(req *http.Request) (err error) { - cancel := httputil.RequestCanceler(s.tr, req) + cancel := httputil.RequestCanceler(s.tr.pipelineRt, req) type responseAndError struct { resp *http.Response @@ -115,7 +115,7 @@ func (s *snapshotSender) post(req *http.Request) (err error) { result := make(chan responseAndError, 1) go func() { - resp, err := s.tr.RoundTrip(req) + resp, err := s.tr.pipelineRt.RoundTrip(req) if err != nil { result <- responseAndError{resp, nil, err} return diff --git a/rafthttp/stream.go b/rafthttp/stream.go index 870bcdfc0..ecf494130 100644 --- a/rafthttp/stream.go +++ b/rafthttp/stream.go @@ -226,8 +226,7 @@ func (cw *streamWriter) stop() { // streamReader is a long-running go-routine that dials to the remote stream // endpoint and reads messages from the response body returned. type streamReader struct { - localPeer Peer - tr http.RoundTripper + tr *Transport picker *urlPicker t streamType local, remote types.ID @@ -244,21 +243,20 @@ type streamReader struct { done chan struct{} } -func startStreamReader(p Peer, tr http.RoundTripper, picker *urlPicker, t streamType, local, remote, cid types.ID, status *peerStatus, recvc chan<- raftpb.Message, propc chan<- raftpb.Message, errorc chan<- error) *streamReader { +func startStreamReader(tr *Transport, picker *urlPicker, t streamType, local, remote, cid types.ID, status *peerStatus, recvc chan<- raftpb.Message, propc chan<- raftpb.Message, errorc chan<- error) *streamReader { r := &streamReader{ - localPeer: p, - tr: tr, - picker: picker, - t: t, - local: local, - remote: remote, - cid: cid, - status: status, - recvc: recvc, - propc: propc, - errorc: errorc, - stopc: make(chan struct{}), - done: make(chan struct{}), + tr: tr, + picker: picker, + t: t, + local: local, + remote: remote, + cid: cid, + status: status, + recvc: recvc, + propc: propc, + errorc: errorc, + stopc: make(chan struct{}), + done: make(chan struct{}), } go r.run() return r @@ -374,11 +372,7 @@ func (cr *streamReader) dial(t streamType) (io.ReadCloser, error) { req.Header.Set("X-Etcd-Cluster-ID", cr.cid.String()) req.Header.Set("X-Raft-To", cr.remote.String()) - var peerURLs []string - for _, url := range cr.localPeer.urls() { - peerURLs = append(peerURLs, url.String()) - } - req.Header.Set("X-Server-Peers", strings.Join(peerURLs, ",")) + setPeerURLsHeader(req, cr.tr.URLs) cr.mu.Lock() select { @@ -387,10 +381,10 @@ func (cr *streamReader) dial(t streamType) (io.ReadCloser, error) { return nil, fmt.Errorf("stream reader is stopped") default: } - cr.cancel = httputil.RequestCanceler(cr.tr, req) + cr.cancel = httputil.RequestCanceler(cr.tr.streamRt, req) cr.mu.Unlock() - resp, err := cr.tr.RoundTrip(req) + resp, err := cr.tr.streamRt.RoundTrip(req) if err != nil { cr.picker.unreachable(u) return nil, err diff --git a/rafthttp/stream_test.go b/rafthttp/stream_test.go index 4aece9987..75572ccde 100644 --- a/rafthttp/stream_test.go +++ b/rafthttp/stream_test.go @@ -116,12 +116,11 @@ func TestStreamReaderDialRequest(t *testing.T) { for i, tt := range []streamType{streamTypeMessage, streamTypeMsgAppV2} { tr := &roundTripperRecorder{} sr := &streamReader{ - tr: tr, - localPeer: newFakePeer(), - picker: mustNewURLPicker(t, []string{"http://localhost:2380"}), - local: types.ID(1), - remote: types.ID(2), - cid: types.ID(1), + tr: &Transport{streamRt: tr}, + picker: mustNewURLPicker(t, []string{"http://localhost:2380"}), + local: types.ID(1), + remote: types.ID(2), + cid: types.ID(1), } sr.dial(tt) @@ -167,13 +166,12 @@ func TestStreamReaderDialResult(t *testing.T) { err: tt.err, } sr := &streamReader{ - tr: tr, - localPeer: newFakePeer(), - picker: mustNewURLPicker(t, []string{"http://localhost:2380"}), - local: types.ID(1), - remote: types.ID(2), - cid: types.ID(1), - errorc: make(chan error, 1), + tr: &Transport{streamRt: tr}, + picker: mustNewURLPicker(t, []string{"http://localhost:2380"}), + local: types.ID(1), + remote: types.ID(2), + cid: types.ID(1), + errorc: make(chan error, 1), } _, err := sr.dial(streamTypeMessage) @@ -196,12 +194,11 @@ func TestStreamReaderDialDetectUnsupport(t *testing.T) { header: http.Header{}, } sr := &streamReader{ - tr: tr, - localPeer: newFakePeer(), - picker: mustNewURLPicker(t, []string{"http://localhost:2380"}), - local: types.ID(1), - remote: types.ID(2), - cid: types.ID(1), + tr: &Transport{streamRt: tr}, + picker: mustNewURLPicker(t, []string{"http://localhost:2380"}), + local: types.ID(1), + remote: types.ID(2), + cid: types.ID(1), } _, err := sr.dial(typ) @@ -257,9 +254,8 @@ func TestStream(t *testing.T) { h.sw = sw picker := mustNewURLPicker(t, []string{srv.URL}) - tr := &http.Transport{} - peer := newFakePeer() - sr := startStreamReader(peer, tr, picker, tt.t, types.ID(1), types.ID(2), types.ID(1), newPeerStatus(types.ID(1)), recvc, propc, nil) + tr := &Transport{streamRt: &http.Transport{}} + sr := startStreamReader(tr, picker, tt.t, types.ID(1), types.ID(2), types.ID(1), newPeerStatus(types.ID(1)), recvc, propc, nil) defer sr.stop() // wait for stream to work var writec chan<- raftpb.Message diff --git a/rafthttp/transport.go b/rafthttp/transport.go index 289e5ba93..cd43d7e32 100644 --- a/rafthttp/transport.go +++ b/rafthttp/transport.go @@ -97,9 +97,10 @@ type Transport struct { DialTimeout time.Duration // maximum duration before timing out dial of the request TLSInfo transport.TLSInfo // TLS information used when creating connection - ID types.ID // local member ID - ClusterID types.ID // raft cluster ID for request validation - Raft Raft // raft state machine, to which the Transport forwards received messages and reports status + ID types.ID // local member ID + URLs types.URLs // local peer URLs + ClusterID types.ID // raft cluster ID for request validation + Raft Raft // raft state machine, to which the Transport forwards received messages and reports status Snapshotter *snap.Snapshotter ServerStats *stats.ServerStats // used to record general transportation statistics // used to record transportation statistics with followers when @@ -139,9 +140,9 @@ func (t *Transport) Start() error { } func (t *Transport) Handler() http.Handler { - pipelineHandler := newPipelineHandler(t.Raft, t.ClusterID) + pipelineHandler := newPipelineHandler(t, t.Raft, t.ClusterID) streamHandler := newStreamHandler(t, t, t.Raft, t.ID, t.ClusterID) - snapHandler := newSnapshotHandler(t.Raft, t.Snapshotter, t.ClusterID) + snapHandler := newSnapshotHandler(t, t.Raft, t.Snapshotter, t.ClusterID) mux := http.NewServeMux() mux.Handle(RaftPrefix, pipelineHandler) mux.Handle(RaftStreamPrefix+"/", streamHandler) @@ -205,6 +206,9 @@ func (t *Transport) Stop() { func (t *Transport) AddRemote(id types.ID, us []string) { t.mu.Lock() defer t.mu.Unlock() + if _, ok := t.peers[id]; ok { + return + } if _, ok := t.remotes[id]; ok { return } @@ -212,7 +216,7 @@ func (t *Transport) AddRemote(id types.ID, us []string) { if err != nil { plog.Panicf("newURLs %+v should never fail: %+v", us, err) } - t.remotes[id] = startRemote(t.pipelineRt, urls, t.ID, id, t.ClusterID, t.Raft, t.ErrorC) + t.remotes[id] = startRemote(t, urls, t.ID, id, t.ClusterID, t.Raft, t.ErrorC) } func (t *Transport) AddPeer(id types.ID, us []string) { diff --git a/rafthttp/transport_test.go b/rafthttp/transport_test.go index ca465f83f..3493af360 100644 --- a/rafthttp/transport_test.go +++ b/rafthttp/transport_test.go @@ -121,7 +121,7 @@ func TestTransportUpdate(t *testing.T) { tr.UpdatePeer(types.ID(1), []string{u}) wurls := types.URLs(testutil.MustNewURLs(t, []string{"http://localhost:2380"})) if !reflect.DeepEqual(peer.peerURLs, wurls) { - t.Errorf("urls = %+v, want %+v", peer.urls, wurls) + t.Errorf("urls = %+v, want %+v", peer.peerURLs, wurls) } } diff --git a/rafthttp/util.go b/rafthttp/util.go index d90fd2d6c..5a18a03c1 100644 --- a/rafthttp/util.go +++ b/rafthttp/util.go @@ -86,7 +86,7 @@ func readEntryFrom(r io.Reader, ent *raftpb.Entry) error { } // createPostRequest creates a HTTP POST request that sends raft message. -func createPostRequest(u url.URL, path string, body io.Reader, ct string, from, cid types.ID) *http.Request { +func createPostRequest(u url.URL, path string, body io.Reader, ct string, urls types.URLs, from, cid types.ID) *http.Request { uu := u uu.Path = path req, err := http.NewRequest("POST", uu.String(), body) @@ -98,6 +98,8 @@ func createPostRequest(u url.URL, path string, body io.Reader, ct string, from, req.Header.Set("X-Server-Version", version.Version) req.Header.Set("X-Min-Cluster-Version", version.MinClusterVersion) req.Header.Set("X-Etcd-Cluster-ID", cid.String()) + setPeerURLsHeader(req, urls) + return req } @@ -187,3 +189,16 @@ func checkVersionCompability(name string, server, minCluster *semver.Version) er } return nil } + +// setPeerURLsHeader reports local urls for peer discovery +func setPeerURLsHeader(req *http.Request, urls types.URLs) { + if urls == nil { + // often not set in unit tests + return + } + var peerURLs []string + for _, url := range urls { + peerURLs = append(peerURLs, url.String()) + } + req.Header.Set("X-PeerURLs", strings.Join(peerURLs, ",")) +}