From 894e678ad6e95e4705286d333639ae260e5a10cb Mon Sep 17 00:00:00 2001 From: Xiang Li Date: Tue, 21 Oct 2014 10:41:32 -0700 Subject: [PATCH] etcdserver: checking clusterID --- etcdserver/cluster_store.go | 20 +++++++++++++++++--- etcdserver/cluster_store_test.go | 9 +++++---- etcdserver/etcdhttp/http.go | 13 +++++++++++-- etcdserver/etcdhttp/http_test.go | 27 +++++++++++++++++++++++++-- etcdserver/server.go | 2 +- 5 files changed, 59 insertions(+), 12 deletions(-) diff --git a/etcdserver/cluster_store.go b/etcdserver/cluster_store.go index a17b99962..d7666c038 100644 --- a/etcdserver/cluster_store.go +++ b/etcdserver/cluster_store.go @@ -22,6 +22,7 @@ import ( "fmt" "log" "net/http" + "strconv" "time" etcdErr "github.com/coreos/etcd/error" @@ -46,6 +47,9 @@ type ClusterStore interface { type clusterStore struct { Store store.Store + // TODO: write the id into the actual store? + // TODO: save the id as string? + id uint64 } // Add puts a new Member into the store. @@ -72,6 +76,7 @@ func (s *clusterStore) Add(m Member) { // lock here. func (s *clusterStore) Get() Cluster { c := NewCluster() + c.id = s.id e, err := s.Store.Get(membersKVPrefix, true, true) if err != nil { if v, ok := err.(*etcdErr.Error); ok && v.ErrorCode == etcdErr.EcodeKeyNotFound { @@ -141,6 +146,7 @@ func Sender(t *http.Transport, cls ClusterStore, ss *stats.ServerStats, ls *stat // ClusterStore, retrying up to 3 times for each message. The given // ServerStats and LeaderStats are updated appropriately func send(c *http.Client, cls ClusterStore, m raftpb.Message, ss *stats.ServerStats, ls *stats.LeaderStats) { + cid := cls.Get().ID() // TODO (xiangli): reasonable retry logic for i := 0; i < 3; i++ { u := cls.Get().Pick(m.To) @@ -167,7 +173,7 @@ func send(c *http.Client, cls ClusterStore, m raftpb.Message, ss *stats.ServerSt fs := ls.Follower(to) start := time.Now() - sent := httpPost(c, u, data) + sent := httpPost(c, u, cid, data) end := time.Now() if sent { fs.Succ(end.Sub(start)) @@ -180,12 +186,20 @@ func send(c *http.Client, cls ClusterStore, m raftpb.Message, ss *stats.ServerSt // httpPost POSTs a data payload to a url using the given client. Returns true // if the POST succeeds, false on any failure. -func httpPost(c *http.Client, url string, data []byte) bool { - resp, err := c.Post(url, "application/protobuf", bytes.NewBuffer(data)) +func httpPost(c *http.Client, url string, cid uint64, data []byte) bool { + req, err := http.NewRequest("POST", url, bytes.NewBuffer(data)) if err != nil { // TODO: log the error? return false } + req.Header.Set("Content-Type", "application/protobuf") + req.Header.Set("X-Etcd-Cluster-ID", strconv.FormatUint(cid, 16)) + resp, err := c.Do(req) + if err != nil { + // TODO: log the error? + return false + } + resp.Body.Close() if resp.StatusCode != http.StatusNoContent { // TODO: log the error? diff --git a/etcdserver/cluster_store_test.go b/etcdserver/cluster_store_test.go index 3200f9ec2..f73424286 100644 --- a/etcdserver/cluster_store_test.go +++ b/etcdserver/cluster_store_test.go @@ -92,14 +92,15 @@ func TestClusterStoreGet(t *testing.T) { }, } for i, tt := range tests { - cs := &clusterStore{Store: newGetAllStore()} - for _, m := range tt.mems { - cs.Add(m) - } c := NewCluster() if err := c.AddSlice(tt.mems); err != nil { t.Fatal(err) } + c.GenID(nil) + cs := &clusterStore{Store: newGetAllStore(), id: c.id} + for _, m := range tt.mems { + cs.Add(m) + } if g := cs.Get(); !reflect.DeepEqual(&g, c) { t.Errorf("#%d: mems = %v, want %v", i, &g, c) } diff --git a/etcdserver/etcdhttp/http.go b/etcdserver/etcdhttp/http.go index 44c08f4d5..5e028aebe 100644 --- a/etcdserver/etcdhttp/http.go +++ b/etcdserver/etcdhttp/http.go @@ -80,8 +80,9 @@ func NewClientHandler(server *etcdserver.EtcdServer) http.Handler { // NewPeerHandler generates an http.Handler to handle etcd peer (raft) requests. func NewPeerHandler(server *etcdserver.EtcdServer) http.Handler { sh := &serverHandler{ - server: server, - stats: server, + server: server, + stats: server, + clusterStore: server.ClusterStore, } mux := http.NewServeMux() mux.HandleFunc(raftPrefix, sh.serveRaft) @@ -215,6 +216,14 @@ func (h serverHandler) serveRaft(w http.ResponseWriter, r *http.Request) { return } + gcid := r.Header.Get("X-Etcd-Cluster-ID") + wcid := strconv.FormatUint(h.clusterStore.Get().ID(), 16) + if gcid != wcid { + log.Printf("etcdhttp: request ignored: clusterID mismatch got %s want %x", gcid, wcid) + http.Error(w, "clusterID mismatch", http.StatusPreconditionFailed) + return + } + b, err := ioutil.ReadAll(r.Body) if err != nil { log.Println("etcdhttp: error reading raft message:", err) diff --git a/etcdserver/etcdhttp/http_test.go b/etcdserver/etcdhttp/http_test.go index 2c6ada670..287742a87 100644 --- a/etcdserver/etcdhttp/http_test.go +++ b/etcdserver/etcdhttp/http_test.go @@ -862,6 +862,7 @@ func TestServeRaft(t *testing.T) { method string body io.Reader serverErr error + clusterID string wcode int }{ @@ -875,6 +876,7 @@ func TestServeRaft(t *testing.T) { ), ), nil, + "0", http.StatusMethodNotAllowed, }, { @@ -887,6 +889,7 @@ func TestServeRaft(t *testing.T) { ), ), nil, + "0", http.StatusMethodNotAllowed, }, { @@ -899,6 +902,7 @@ func TestServeRaft(t *testing.T) { ), ), nil, + "0", http.StatusMethodNotAllowed, }, { @@ -906,6 +910,7 @@ func TestServeRaft(t *testing.T) { "POST", &errReader{}, nil, + "0", http.StatusBadRequest, }, { @@ -913,6 +918,7 @@ func TestServeRaft(t *testing.T) { "POST", strings.NewReader("malformed garbage"), nil, + "0", http.StatusBadRequest, }, { @@ -925,6 +931,7 @@ func TestServeRaft(t *testing.T) { ), ), errors.New("some error"), + "0", http.StatusInternalServerError, }, { @@ -937,6 +944,20 @@ func TestServeRaft(t *testing.T) { ), ), nil, + "1", + http.StatusPreconditionFailed, + }, + { + // good request + "POST", + bytes.NewReader( + mustMarshalMsg( + t, + raftpb.Message{}, + ), + ), + nil, + "0", http.StatusNoContent, }, } @@ -945,9 +966,11 @@ func TestServeRaft(t *testing.T) { if err != nil { t.Fatalf("#%d: could not create request: %#v", i, err) } + req.Header.Set("X-Etcd-Cluster-ID", tt.clusterID) h := &serverHandler{ - timeout: time.Hour, - server: &errServer{tt.serverErr}, + timeout: time.Hour, + server: &errServer{tt.serverErr}, + clusterStore: &fakeCluster{}, } rw := httptest.NewRecorder() h.serveRaft(rw, req) diff --git a/etcdserver/server.go b/etcdserver/server.go index 69bc2cc8b..e81e92de6 100644 --- a/etcdserver/server.go +++ b/etcdserver/server.go @@ -204,7 +204,7 @@ func NewServer(cfg *ServerConfig) *EtcdServer { id, cid, n, w = restartNode(cfg, index, snapshot) } - cls := &clusterStore{Store: st} + cls := &clusterStore{Store: st, id: cid} sstats := &stats.ServerStats{ Name: cfg.Name,