diff --git a/etcdserver/etcdhttp/peer.go b/etcdserver/etcdhttp/peer.go index ec3aa430b..9f13976ca 100644 --- a/etcdserver/etcdhttp/peer.go +++ b/etcdserver/etcdhttp/peer.go @@ -18,14 +18,11 @@ package etcdhttp import ( "encoding/json" - "io/ioutil" "log" "net/http" - "github.com/coreos/etcd/Godeps/_workspace/src/code.google.com/p/go.net/context" "github.com/coreos/etcd/etcdserver" - "github.com/coreos/etcd/pkg/types" - "github.com/coreos/etcd/raft/raftpb" + "github.com/coreos/etcd/rafthttp" ) const ( @@ -35,12 +32,7 @@ const ( // NewPeerHandler generates an http.Handler to handle etcd peer (raft) requests. func NewPeerHandler(server *etcdserver.EtcdServer) http.Handler { - rh := &raftHandler{ - stats: server, - server: server, - clusterInfo: server.Cluster, - } - + rh := rafthttp.NewHandler(server, server.Cluster.ID()) mh := &peerMembersHandler{ clusterInfo: server.Cluster, } @@ -52,55 +44,6 @@ func NewPeerHandler(server *etcdserver.EtcdServer) http.Handler { return mux } -type raftHandler struct { - stats etcdserver.Stats - server etcdserver.Server - clusterInfo etcdserver.ClusterInfo -} - -func (h *raftHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if !allowMethod(w, r.Method, "POST") { - return - } - - wcid := h.clusterInfo.ID().String() - w.Header().Set("X-Etcd-Cluster-ID", wcid) - - gcid := r.Header.Get("X-Etcd-Cluster-ID") - if gcid != wcid { - log.Printf("etcdhttp: request ignored due to cluster ID mismatch got %s want %s", gcid, wcid) - http.Error(w, "clusterID mismatch", http.StatusPreconditionFailed) - return - } - - b, err := ioutil.ReadAll(r.Body) - if err != nil { - log.Println("etcdhttp: error reading raft message:", err) - http.Error(w, "error reading raft message", http.StatusBadRequest) - return - } - var m raftpb.Message - if err := m.Unmarshal(b); err != nil { - log.Println("etcdhttp: error unmarshaling raft message:", err) - http.Error(w, "error unmarshaling raft message", http.StatusBadRequest) - return - } - if err := h.server.Process(context.TODO(), m); err != nil { - switch err { - case etcdserver.ErrRemoved: - log.Printf("etcdhttp: reject message from removed member %s", types.ID(m.From).String()) - http.Error(w, "cannot process message from removed member", http.StatusForbidden) - default: - writeError(w, err) - } - return - } - if m.Type == raftpb.MsgApp { - h.stats.UpdateRecvApp(types.ID(m.From), r.ContentLength) - } - w.WriteHeader(http.StatusNoContent) -} - type peerMembersHandler struct { clusterInfo etcdserver.ClusterInfo } diff --git a/etcdserver/etcdhttp/peer_test.go b/etcdserver/etcdhttp/peer_test.go index 495d9eb4a..29e8b0dc1 100644 --- a/etcdserver/etcdhttp/peer_test.go +++ b/etcdserver/etcdhttp/peer_test.go @@ -17,165 +17,15 @@ package etcdhttp import ( - "bytes" "encoding/json" - "errors" - "io" "net/http" "net/http/httptest" "path" - "strings" "testing" "github.com/coreos/etcd/etcdserver" - "github.com/coreos/etcd/raft/raftpb" ) -func mustMarshalMsg(t *testing.T, m raftpb.Message) []byte { - json, err := m.Marshal() - if err != nil { - t.Fatalf("error marshalling raft Message: %#v", err) - } - return json -} - -// errReader implements io.Reader to facilitate a broken request. -type errReader struct{} - -func (er *errReader) Read(_ []byte) (int, error) { return 0, errors.New("some error") } - -func TestServeRaft(t *testing.T) { - testCases := []struct { - method string - body io.Reader - serverErr error - clusterID string - - wcode int - }{ - { - // bad method - "GET", - bytes.NewReader( - mustMarshalMsg( - t, - raftpb.Message{}, - ), - ), - nil, - "0", - http.StatusMethodNotAllowed, - }, - { - // bad method - "PUT", - bytes.NewReader( - mustMarshalMsg( - t, - raftpb.Message{}, - ), - ), - nil, - "0", - http.StatusMethodNotAllowed, - }, - { - // bad method - "DELETE", - bytes.NewReader( - mustMarshalMsg( - t, - raftpb.Message{}, - ), - ), - nil, - "0", - http.StatusMethodNotAllowed, - }, - { - // bad request body - "POST", - &errReader{}, - nil, - "0", - http.StatusBadRequest, - }, - { - // bad request protobuf - "POST", - strings.NewReader("malformed garbage"), - nil, - "0", - http.StatusBadRequest, - }, - { - // good request, etcdserver.Server internal error - "POST", - bytes.NewReader( - mustMarshalMsg( - t, - raftpb.Message{}, - ), - ), - errors.New("some error"), - "0", - http.StatusInternalServerError, - }, - { - // good request from removed member - "POST", - bytes.NewReader( - mustMarshalMsg( - t, - raftpb.Message{}, - ), - ), - etcdserver.ErrRemoved, - "0", - http.StatusForbidden, - }, - { - // good request - "POST", - bytes.NewReader( - mustMarshalMsg( - t, - raftpb.Message{}, - ), - ), - nil, - "1", - http.StatusPreconditionFailed, - }, - { - // good request - "POST", - bytes.NewReader( - mustMarshalMsg( - t, - raftpb.Message{}, - ), - ), - nil, - "0", - http.StatusNoContent, - }, - } - for i, tt := range testCases { - req, err := http.NewRequest(tt.method, "foo", tt.body) - if err != nil { - t.Fatalf("#%d: could not create request: %#v", i, err) - } - req.Header.Set("X-Etcd-Cluster-ID", tt.clusterID) - rw := httptest.NewRecorder() - h := &raftHandler{stats: nil, server: &errServer{tt.serverErr}, clusterInfo: &fakeCluster{id: 0}} - h.ServeHTTP(rw, req) - if rw.Code != tt.wcode { - t.Errorf("#%d: got code=%d, want %d", i, rw.Code, tt.wcode) - } - } -} - func TestServeMembersFails(t *testing.T) { tests := []struct { method string diff --git a/etcdserver/sendhub.go b/etcdserver/sendhub.go new file mode 100644 index 000000000..ccacdc686 --- /dev/null +++ b/etcdserver/sendhub.go @@ -0,0 +1,131 @@ +/* + Copyright 2014 CoreOS, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package etcdserver + +import ( + "log" + "net/http" + "net/url" + "path" + + "github.com/coreos/etcd/etcdserver/stats" + "github.com/coreos/etcd/pkg/types" + "github.com/coreos/etcd/raft/raftpb" + "github.com/coreos/etcd/rafthttp" +) + +const ( + raftPrefix = "/raft" +) + +type sendHub struct { + tr http.RoundTripper + cl ClusterInfo + ss *stats.ServerStats + ls *stats.LeaderStats + senders map[types.ID]rafthttp.Sender + shouldstop chan struct{} +} + +// newSendHub creates the default send hub used to transport raft messages +// to other members. The returned sendHub will update the given ServerStats and +// LeaderStats appropriately. +func newSendHub(t http.RoundTripper, cl ClusterInfo, ss *stats.ServerStats, ls *stats.LeaderStats) *sendHub { + h := &sendHub{ + tr: t, + cl: cl, + ss: ss, + ls: ls, + senders: make(map[types.ID]rafthttp.Sender), + shouldstop: make(chan struct{}, 1), + } + for _, m := range cl.Members() { + h.Add(m) + } + return h +} + +func (h *sendHub) Send(msgs []raftpb.Message) { + for _, m := range msgs { + to := types.ID(m.To) + s, ok := h.senders[to] + if !ok { + if !h.cl.IsIDRemoved(to) { + log.Printf("etcdserver: send message to unknown receiver %s", to) + } + continue + } + + // TODO: don't block. we should be able to have 1000s + // of messages out at a time. + data, err := m.Marshal() + if err != nil { + log.Println("sender: dropping message:", err) + return // drop bad message + } + if m.Type == raftpb.MsgApp { + h.ss.SendAppendReq(len(data)) + } + + s.Send(data) + } +} + +func (h *sendHub) Stop() { + for _, s := range h.senders { + s.Stop() + } +} + +func (h *sendHub) ShouldStopNotify() <-chan struct{} { + return h.shouldstop +} + +func (h *sendHub) Add(m *Member) { + if _, ok := h.senders[m.ID]; ok { + return + } + // TODO: considering how to switch between all available peer urls + peerURL := m.PickPeerURL() + u, err := url.Parse(peerURL) + if err != nil { + log.Panicf("unexpect peer url %s", peerURL) + } + u.Path = path.Join(u.Path, raftPrefix) + fs := h.ls.Follower(m.ID.String()) + s := rafthttp.NewSender(h.tr, u.String(), h.cl.ID(), fs, h.shouldstop) + h.senders[m.ID] = s +} + +func (h *sendHub) Remove(id types.ID) { + h.senders[id].Stop() + delete(h.senders, id) +} + +func (h *sendHub) Update(m *Member) { + // TODO: return error or just panic? + if _, ok := h.senders[m.ID]; !ok { + return + } + peerURL := m.PickPeerURL() + u, err := url.Parse(peerURL) + if err != nil { + log.Panicf("unexpect peer url %s", peerURL) + } + u.Path = path.Join(u.Path, raftPrefix) + h.senders[m.ID].Update(u.String()) +} diff --git a/etcdserver/sendhub_test.go b/etcdserver/sendhub_test.go new file mode 100644 index 000000000..5d5f6017d --- /dev/null +++ b/etcdserver/sendhub_test.go @@ -0,0 +1,126 @@ +/* + Copyright 2014 CoreOS, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package etcdserver + +import ( + "net/http" + "testing" + "time" + + "github.com/coreos/etcd/etcdserver/stats" + "github.com/coreos/etcd/pkg/testutil" + "github.com/coreos/etcd/pkg/types" +) + +func TestSendHubInitSenders(t *testing.T) { + membs := []*Member{ + newTestMember(1, []string{"http://a"}, "", nil), + newTestMember(2, []string{"http://b"}, "", nil), + newTestMember(3, []string{"http://c"}, "", nil), + } + cl := newTestCluster(membs) + ls := stats.NewLeaderStats("") + h := newSendHub(nil, cl, nil, ls) + + ids := cl.MemberIDs() + if len(h.senders) != len(ids) { + t.Errorf("len(ids) = %d, want %d", len(h.senders), len(ids)) + } + for _, id := range ids { + if _, ok := h.senders[id]; !ok { + t.Errorf("senders[%s] is nil, want exists", id) + } + } +} + +func TestSendHubAdd(t *testing.T) { + cl := newTestCluster(nil) + ls := stats.NewLeaderStats("") + h := newSendHub(nil, cl, nil, ls) + m := newTestMember(1, []string{"http://a"}, "", nil) + h.Add(m) + + if _, ok := ls.Followers["1"]; !ok { + t.Errorf("FollowerStats[1] is nil, want exists") + } + s, ok := h.senders[types.ID(1)] + if !ok { + t.Fatalf("senders[1] is nil, want exists") + } + + h.Add(m) + ns := h.senders[types.ID(1)] + if s != ns { + t.Errorf("sender = %p, want %p", ns, s) + } +} + +func TestSendHubRemove(t *testing.T) { + membs := []*Member{ + newTestMember(1, []string{"http://a"}, "", nil), + } + cl := newTestCluster(membs) + ls := stats.NewLeaderStats("") + h := newSendHub(nil, cl, nil, ls) + h.Remove(types.ID(1)) + + if _, ok := h.senders[types.ID(1)]; ok { + t.Fatalf("senders[1] exists, want removed") + } +} + +func TestSendHubShouldStop(t *testing.T) { + membs := []*Member{ + newTestMember(1, []string{"http://a"}, "", nil), + } + tr := newRespRoundTripper(http.StatusForbidden, nil) + cl := newTestCluster(membs) + ls := stats.NewLeaderStats("") + h := newSendHub(tr, cl, nil, ls) + + shouldstop := h.ShouldStopNotify() + select { + case <-shouldstop: + t.Fatalf("received unexpected shouldstop notification") + case <-time.After(10 * time.Millisecond): + } + h.senders[1].Send([]byte("somedata")) + + testutil.ForceGosched() + select { + case <-shouldstop: + default: + t.Fatalf("cannot receive stop notification") + } +} + +type respRoundTripper struct { + code int + err error +} + +func newRespRoundTripper(code int, err error) *respRoundTripper { + return &respRoundTripper{code: code, err: err} +} +func (t *respRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return &http.Response{StatusCode: t.code, Body: &nopReadCloser{}}, t.err +} + +type nopReadCloser struct{} + +func (n *nopReadCloser) Read(p []byte) (int, error) { return 0, nil } +func (n *nopReadCloser) Close() error { return nil } diff --git a/etcdserver/server.go b/etcdserver/server.go index 9ecf92e4b..a65c70c3f 100644 --- a/etcdserver/server.go +++ b/etcdserver/server.go @@ -33,6 +33,7 @@ import ( "github.com/coreos/etcd/Godeps/_workspace/src/code.google.com/p/go.net/context" "github.com/coreos/etcd/discovery" + "github.com/coreos/etcd/etcdserver/etcdhttp/httptypes" pb "github.com/coreos/etcd/etcdserver/etcdserverpb" "github.com/coreos/etcd/etcdserver/stats" "github.com/coreos/etcd/pkg/pbutil" @@ -61,7 +62,6 @@ const ( var ( ErrUnknownMethod = errors.New("etcdserver: unknown method") ErrStopped = errors.New("etcdserver: server stopped") - ErrRemoved = errors.New("etcdserver: server removed") ErrIDRemoved = errors.New("etcdserver: ID removed") ErrIDExists = errors.New("etcdserver: ID exists") ErrIDNotFound = errors.New("etcdserver: ID not found") @@ -145,8 +145,6 @@ type Stats interface { LeaderStats() []byte // StoreStats returns statistics of the store backing this EtcdServer StoreStats() []byte - // UpdateRecvApp updates the underlying statistics in response to a receiving an Append request - UpdateRecvApp(from types.ID, length int64) } type RaftTimer interface { @@ -320,7 +318,11 @@ func (s *EtcdServer) ID() types.ID { return s.id } func (s *EtcdServer) Process(ctx context.Context, m raftpb.Message) error { if s.Cluster.IsIDRemoved(types.ID(m.From)) { - return ErrRemoved + log.Printf("etcdserver: reject message from removed member %s", types.ID(m.From).String()) + return httptypes.NewHTTPError(http.StatusForbidden, "cannot process message from removed member") + } + if m.Type == raftpb.MsgApp { + s.stats.RecvAppendReq(types.ID(m.From).String(), m.Size()) } return s.node.Step(ctx, m) } @@ -488,10 +490,6 @@ func (s *EtcdServer) LeaderStats() []byte { func (s *EtcdServer) StoreStats() []byte { return s.store.JsonStats() } -func (s *EtcdServer) UpdateRecvApp(from types.ID, length int64) { - s.stats.RecvAppendReq(from.String(), int(length)) -} - func (s *EtcdServer) AddMember(ctx context.Context, memb Member) error { // TODO: move Member to protobuf type b, err := json.Marshal(memb) diff --git a/rafthttp/http.go b/rafthttp/http.go new file mode 100644 index 000000000..87ff9f924 --- /dev/null +++ b/rafthttp/http.go @@ -0,0 +1,90 @@ +/* + Copyright 2014 CoreOS, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package rafthttp + +import ( + "io/ioutil" + "log" + "net/http" + + "github.com/coreos/etcd/pkg/types" + "github.com/coreos/etcd/raft/raftpb" + + "github.com/coreos/etcd/Godeps/_workspace/src/code.google.com/p/go.net/context" +) + +type Processor interface { + Process(ctx context.Context, m raftpb.Message) error +} + +func NewHandler(p Processor, cid types.ID) http.Handler { + return &handler{ + p: p, + cid: cid, + } +} + +type handler struct { + p Processor + cid types.ID +} + +func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + w.Header().Set("Allow", "POST") + http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) + return + } + + wcid := h.cid.String() + w.Header().Set("X-Etcd-Cluster-ID", wcid) + + gcid := r.Header.Get("X-Etcd-Cluster-ID") + if gcid != wcid { + log.Printf("rafthttp: request ignored due to cluster ID mismatch got %s want %s", gcid, wcid) + http.Error(w, "clusterID mismatch", http.StatusPreconditionFailed) + return + } + + b, err := ioutil.ReadAll(r.Body) + if err != nil { + log.Println("rafthttp: error reading raft message:", err) + http.Error(w, "error reading raft message", http.StatusBadRequest) + return + } + var m raftpb.Message + if err := m.Unmarshal(b); err != nil { + log.Println("rafthttp: error unmarshaling raft message:", err) + http.Error(w, "error unmarshaling raft message", http.StatusBadRequest) + return + } + if err := h.p.Process(context.TODO(), m); err != nil { + switch v := err.(type) { + case writerToResponse: + v.WriteTo(w) + default: + log.Printf("rafthttp: error processing raft message: %v", err) + http.Error(w, "error processing raft message", http.StatusInternalServerError) + } + return + } + w.WriteHeader(http.StatusNoContent) +} + +type writerToResponse interface { + WriteTo(w http.ResponseWriter) +} diff --git a/rafthttp/http_test.go b/rafthttp/http_test.go new file mode 100644 index 000000000..1718a9709 --- /dev/null +++ b/rafthttp/http_test.go @@ -0,0 +1,184 @@ +/* + Copyright 2014 CoreOS, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package rafthttp + +import ( + "bytes" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/coreos/etcd/pkg/pbutil" + "github.com/coreos/etcd/pkg/types" + "github.com/coreos/etcd/raft/raftpb" + + "github.com/coreos/etcd/Godeps/_workspace/src/code.google.com/p/go.net/context" +) + +func TestServeRaft(t *testing.T) { + testCases := []struct { + method string + body io.Reader + p Processor + clusterID string + + wcode int + }{ + { + // bad method + "GET", + bytes.NewReader( + pbutil.MustMarshal(&raftpb.Message{}), + ), + &nopProcessor{}, + "0", + http.StatusMethodNotAllowed, + }, + { + // bad method + "PUT", + bytes.NewReader( + pbutil.MustMarshal(&raftpb.Message{}), + ), + &nopProcessor{}, + "0", + http.StatusMethodNotAllowed, + }, + { + // bad method + "DELETE", + bytes.NewReader( + pbutil.MustMarshal(&raftpb.Message{}), + ), + &nopProcessor{}, + "0", + http.StatusMethodNotAllowed, + }, + { + // bad request body + "POST", + &errReader{}, + &nopProcessor{}, + "0", + http.StatusBadRequest, + }, + { + // bad request protobuf + "POST", + strings.NewReader("malformed garbage"), + &nopProcessor{}, + "0", + http.StatusBadRequest, + }, + { + // good request, wrong cluster ID + "POST", + bytes.NewReader( + pbutil.MustMarshal(&raftpb.Message{}), + ), + &nopProcessor{}, + "1", + http.StatusPreconditionFailed, + }, + { + // good request, Processor failure + "POST", + bytes.NewReader( + pbutil.MustMarshal(&raftpb.Message{}), + ), + &errProcessor{ + err: &resWriterToError{code: http.StatusForbidden}, + }, + "0", + http.StatusForbidden, + }, + { + // good request, Processor failure + "POST", + bytes.NewReader( + pbutil.MustMarshal(&raftpb.Message{}), + ), + &errProcessor{ + err: &resWriterToError{code: http.StatusInternalServerError}, + }, + "0", + http.StatusInternalServerError, + }, + { + // good request, Processor failure + "POST", + bytes.NewReader( + pbutil.MustMarshal(&raftpb.Message{}), + ), + &errProcessor{err: errors.New("blah")}, + "0", + http.StatusInternalServerError, + }, + { + // good request + "POST", + bytes.NewReader( + pbutil.MustMarshal(&raftpb.Message{}), + ), + &nopProcessor{}, + "0", + http.StatusNoContent, + }, + } + for i, tt := range testCases { + req, err := http.NewRequest(tt.method, "foo", tt.body) + if err != nil { + t.Fatalf("#%d: could not create request: %#v", i, err) + } + req.Header.Set("X-Etcd-Cluster-ID", tt.clusterID) + rw := httptest.NewRecorder() + h := NewHandler(tt.p, types.ID(0), &nopStats{}) + h.ServeHTTP(rw, req) + if rw.Code != tt.wcode { + t.Errorf("#%d: got code=%d, want %d", i, rw.Code, tt.wcode) + } + } +} + +// errReader implements io.Reader to facilitate a broken request. +type errReader struct{} + +func (er *errReader) Read(_ []byte) (int, error) { return 0, errors.New("some error") } + +type nopProcessor struct{} + +func (p *nopProcessor) Process(ctx context.Context, m raftpb.Message) error { return nil } + +type errProcessor struct { + err error +} + +func (p *errProcessor) Process(ctx context.Context, m raftpb.Message) error { return p.err } + +type nopStats struct{} + +func (s *nopStats) UpdateRecvApp(from types.ID, length int64) {} + +type resWriterToError struct { + code int +} + +func (e *resWriterToError) Error() string { return "" } +func (e *resWriterToError) WriteTo(w http.ResponseWriter) { w.WriteHeader(e.code) } diff --git a/etcdserver/sender.go b/rafthttp/sender.go similarity index 55% rename from etcdserver/sender.go rename to rafthttp/sender.go index 875c48bde..203a6c5aa 100644 --- a/etcdserver/sender.go +++ b/rafthttp/sender.go @@ -14,138 +14,36 @@ limitations under the License. */ -package etcdserver +package rafthttp import ( "bytes" "fmt" "log" "net/http" - "net/url" - "path" "sync" "time" "github.com/coreos/etcd/etcdserver/stats" "github.com/coreos/etcd/pkg/types" - "github.com/coreos/etcd/raft/raftpb" ) const ( - raftPrefix = "/raft" connPerSender = 4 senderBufSize = connPerSender * 4 ) -type sendHub struct { - tr http.RoundTripper - cl ClusterInfo - ss *stats.ServerStats - ls *stats.LeaderStats - senders map[types.ID]*sender - shouldstop chan struct{} +type Sender interface { + Update(u string) + // Send sends the data to the remote node. It is always non-blocking. + // It may be fail to send data if it returns nil error. + Send(data []byte) error + // Stop performs any necessary finalization and terminates the Sender + // elegantly. + Stop() } -// newSendHub creates the default send hub used to transport raft messages -// to other members. The returned sendHub will update the given ServerStats and -// LeaderStats appropriately. -func newSendHub(t http.RoundTripper, cl ClusterInfo, ss *stats.ServerStats, ls *stats.LeaderStats) *sendHub { - h := &sendHub{ - tr: t, - cl: cl, - ss: ss, - ls: ls, - senders: make(map[types.ID]*sender), - shouldstop: make(chan struct{}, 1), - } - for _, m := range cl.Members() { - h.Add(m) - } - return h -} - -func (h *sendHub) Send(msgs []raftpb.Message) { - for _, m := range msgs { - to := types.ID(m.To) - s, ok := h.senders[to] - if !ok { - if !h.cl.IsIDRemoved(to) { - log.Printf("etcdserver: send message to unknown receiver %s", to) - } - continue - } - - // TODO: don't block. we should be able to have 1000s - // of messages out at a time. - data, err := m.Marshal() - if err != nil { - log.Println("sender: dropping message:", err) - return // drop bad message - } - if m.Type == raftpb.MsgApp { - h.ss.SendAppendReq(len(data)) - } - - // TODO (xiangli): reasonable retry logic - s.send(data) - } -} - -func (h *sendHub) Stop() { - for _, s := range h.senders { - s.stop() - } -} - -func (h *sendHub) ShouldStopNotify() <-chan struct{} { - return h.shouldstop -} - -func (h *sendHub) Add(m *Member) { - if _, ok := h.senders[m.ID]; ok { - return - } - // TODO: considering how to switch between all available peer urls - u := fmt.Sprintf("%s%s", m.PickPeerURL(), raftPrefix) - fs := h.ls.Follower(m.ID.String()) - s := newSender(h.tr, u, h.cl.ID(), fs, h.shouldstop) - h.senders[m.ID] = s -} - -func (h *sendHub) Remove(id types.ID) { - h.senders[id].stop() - delete(h.senders, id) -} - -func (h *sendHub) Update(m *Member) { - // TODO: return error or just panic? - if _, ok := h.senders[m.ID]; !ok { - return - } - peerURL := m.PickPeerURL() - u, err := url.Parse(peerURL) - if err != nil { - log.Panicf("unexpect peer url %s", peerURL) - } - u.Path = path.Join(u.Path, raftPrefix) - s := h.senders[m.ID] - s.mu.Lock() - defer s.mu.Unlock() - s.u = u.String() -} - -type sender struct { - tr http.RoundTripper - u string - cid types.ID - fs *stats.FollowerStats - q chan []byte - mu sync.RWMutex - wg sync.WaitGroup - shouldstop chan struct{} -} - -func newSender(tr http.RoundTripper, u string, cid types.ID, fs *stats.FollowerStats, shouldstop chan struct{}) *sender { +func NewSender(tr http.RoundTripper, u string, cid types.ID, fs *stats.FollowerStats, shouldstop chan struct{}) *sender { s := &sender{ tr: tr, u: u, @@ -161,7 +59,25 @@ func newSender(tr http.RoundTripper, u string, cid types.ID, fs *stats.FollowerS return s } -func (s *sender) send(data []byte) error { +type sender struct { + tr http.RoundTripper + u string + cid types.ID + fs *stats.FollowerStats + q chan []byte + mu sync.RWMutex + wg sync.WaitGroup + shouldstop chan struct{} +} + +func (s *sender) Update(u string) { + s.mu.Lock() + defer s.mu.Unlock() + s.u = u +} + +// TODO (xiangli): reasonable retry logic +func (s *sender) Send(data []byte) error { select { case s.q <- data: return nil @@ -171,7 +87,7 @@ func (s *sender) send(data []byte) error { } } -func (s *sender) stop() { +func (s *sender) Stop() { close(s.q) s.wg.Wait() } diff --git a/etcdserver/sender_test.go b/rafthttp/sender_test.go similarity index 65% rename from etcdserver/sender_test.go rename to rafthttp/sender_test.go index e24637093..6e86a4f0c 100644 --- a/etcdserver/sender_test.go +++ b/rafthttp/sender_test.go @@ -14,7 +14,7 @@ limitations under the License. */ -package etcdserver +package rafthttp import ( "errors" @@ -22,109 +22,23 @@ import ( "net/http" "sync" "testing" - "time" "github.com/coreos/etcd/etcdserver/stats" "github.com/coreos/etcd/pkg/testutil" "github.com/coreos/etcd/pkg/types" ) -func TestSendHubInitSenders(t *testing.T) { - membs := []*Member{ - newTestMember(1, []string{"http://a"}, "", nil), - newTestMember(2, []string{"http://b"}, "", nil), - newTestMember(3, []string{"http://c"}, "", nil), - } - cl := newTestCluster(membs) - ls := stats.NewLeaderStats("") - h := newSendHub(nil, cl, nil, ls) - - ids := cl.MemberIDs() - if len(h.senders) != len(ids) { - t.Errorf("len(ids) = %d, want %d", len(h.senders), len(ids)) - } - for _, id := range ids { - if _, ok := h.senders[id]; !ok { - t.Errorf("senders[%s] is nil, want exists", id) - } - } -} - -func TestSendHubAdd(t *testing.T) { - cl := newTestCluster(nil) - ls := stats.NewLeaderStats("") - h := newSendHub(nil, cl, nil, ls) - m := newTestMember(1, []string{"http://a"}, "", nil) - h.Add(m) - - if _, ok := ls.Followers["1"]; !ok { - t.Errorf("FollowerStats[1] is nil, want exists") - } - s, ok := h.senders[types.ID(1)] - if !ok { - t.Fatalf("senders[1] is nil, want exists") - } - if s.u != "http://a/raft" { - t.Errorf("url = %s, want %s", s.u, "http://a/raft") - } - - h.Add(m) - ns := h.senders[types.ID(1)] - if s != ns { - t.Errorf("sender = %p, want %p", ns, s) - } -} - -func TestSendHubRemove(t *testing.T) { - membs := []*Member{ - newTestMember(1, []string{"http://a"}, "", nil), - } - cl := newTestCluster(membs) - ls := stats.NewLeaderStats("") - h := newSendHub(nil, cl, nil, ls) - h.Remove(types.ID(1)) - - if _, ok := h.senders[types.ID(1)]; ok { - t.Fatalf("senders[1] exists, want removed") - } -} - -func TestSendHubShouldStop(t *testing.T) { - membs := []*Member{ - newTestMember(1, []string{"http://a"}, "", nil), - } - tr := newRespRoundTripper(http.StatusForbidden, nil) - cl := newTestCluster(membs) - ls := stats.NewLeaderStats("") - h := newSendHub(tr, cl, nil, ls) - - shouldstop := h.ShouldStopNotify() - select { - case <-shouldstop: - t.Fatalf("received unexpected shouldstop notification") - case <-time.After(10 * time.Millisecond): - } - h.senders[1].send([]byte("somedata")) - - testutil.ForceGosched() - select { - case <-shouldstop: - default: - t.Fatalf("cannot receive stop notification") - } -} - // TestSenderSend tests that send func could post data using roundtripper // and increase success count in stats. func TestSenderSend(t *testing.T) { tr := &roundTripperRecorder{} fs := &stats.FollowerStats{} - s := newSender(tr, "http://10.0.0.1", types.ID(1), fs, nil) + s := NewSender(tr, "http://10.0.0.1", types.ID(1), fs, nil) - if err := s.send([]byte("some data")); err != nil { + if err := s.Send([]byte("some data")); err != nil { t.Fatalf("unexpect send error: %v", err) } - s.stop() + s.Stop() if tr.Request() == nil { t.Errorf("sender fails to post the data") @@ -139,12 +53,12 @@ func TestSenderSend(t *testing.T) { func TestSenderExceedMaximalServing(t *testing.T) { tr := newRoundTripperBlocker() fs := &stats.FollowerStats{} - s := newSender(tr, "http://10.0.0.1", types.ID(1), fs, nil) + s := NewSender(tr, "http://10.0.0.1", types.ID(1), fs, nil) // keep the sender busy and make the buffer full // nothing can go out as we block the sender for i := 0; i < connPerSender+senderBufSize; i++ { - if err := s.send([]byte("some data")); err != nil { + if err := s.Send([]byte("some data")); err != nil { t.Errorf("send err = %v, want nil", err) } // force the sender to grab data @@ -152,7 +66,7 @@ func TestSenderExceedMaximalServing(t *testing.T) { } // try to send a data when we are sure the buffer is full - if err := s.send([]byte("some data")); err == nil { + if err := s.Send([]byte("some data")); err == nil { t.Errorf("unexpect send success") } @@ -161,22 +75,22 @@ func TestSenderExceedMaximalServing(t *testing.T) { testutil.ForceGosched() // It could send new data after previous ones succeed - if err := s.send([]byte("some data")); err != nil { + if err := s.Send([]byte("some data")); err != nil { t.Errorf("send err = %v, want nil", err) } - s.stop() + s.Stop() } // TestSenderSendFailed tests that when send func meets the post error, // it increases fail count in stats. func TestSenderSendFailed(t *testing.T) { fs := &stats.FollowerStats{} - s := newSender(newRespRoundTripper(0, errors.New("blah")), "http://10.0.0.1", types.ID(1), fs, nil) + s := NewSender(newRespRoundTripper(0, errors.New("blah")), "http://10.0.0.1", types.ID(1), fs, nil) - if err := s.send([]byte("some data")); err != nil { - t.Fatalf("unexpect send error: %v", err) + if err := s.Send([]byte("some data")); err != nil { + t.Fatalf("unexpect Send error: %v", err) } - s.stop() + s.Stop() fs.Lock() defer fs.Unlock() @@ -187,11 +101,11 @@ func TestSenderSendFailed(t *testing.T) { func TestSenderPost(t *testing.T) { tr := &roundTripperRecorder{} - s := newSender(tr, "http://10.0.0.1", types.ID(1), nil, nil) + s := NewSender(tr, "http://10.0.0.1", types.ID(1), nil, nil) if err := s.post([]byte("some data")); err != nil { t.Fatalf("unexpect post error: %v", err) } - s.stop() + s.Stop() if g := tr.Request().Method; g != "POST" { t.Errorf("method = %s, want %s", g, "POST") @@ -230,9 +144,9 @@ func TestSenderPostBad(t *testing.T) { } for i, tt := range tests { shouldstop := make(chan struct{}) - s := newSender(newRespRoundTripper(tt.code, tt.err), tt.u, types.ID(1), nil, shouldstop) + s := NewSender(newRespRoundTripper(tt.code, tt.err), tt.u, types.ID(1), nil, shouldstop) err := s.post([]byte("some data")) - s.stop() + s.Stop() if err == nil { t.Errorf("#%d: err = nil, want not nil", i) @@ -251,9 +165,9 @@ func TestSenderPostShouldStop(t *testing.T) { } for i, tt := range tests { shouldstop := make(chan struct{}, 1) - s := newSender(newRespRoundTripper(tt.code, tt.err), tt.u, types.ID(1), nil, shouldstop) + s := NewSender(newRespRoundTripper(tt.code, tt.err), tt.u, types.ID(1), nil, shouldstop) s.post([]byte("some data")) - s.stop() + s.Stop() select { case <-shouldstop: default: