diff --git a/etcdserver/etcdhttp/http.go b/etcdserver/etcdhttp/http.go index 4a67f270d..209cb5ec8 100644 --- a/etcdserver/etcdhttp/http.go +++ b/etcdserver/etcdhttp/http.go @@ -68,6 +68,10 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func (h Handler) serveKeys(ctx context.Context, w http.ResponseWriter, r *http.Request) { + if !allowMethod(w, r.Method, "GET", "PUT", "POST", "DELETE") { + return + } + rr, err := parseRequest(r, genID()) if err != nil { writeError(w, err) @@ -103,8 +107,7 @@ func (h Handler) serveKeys(ctx context.Context, w http.ResponseWriter, r *http.R // serveMachines responds address list in the format '0.0.0.0, 1.1.1.1'. // TODO: rethink the format of machine list because it is not json format. func (h Handler) serveMachines(w http.ResponseWriter, r *http.Request) { - if r.Method != "GET" && r.Method != "HEAD" { - allow(w, "GET", "HEAD") + if !allowMethod(w, r.Method, "GET", "HEAD") { return } endpoints := h.Peers.Endpoints() @@ -112,6 +115,9 @@ func (h Handler) serveMachines(w http.ResponseWriter, r *http.Request) { } func (h Handler) serveRaft(ctx context.Context, w http.ResponseWriter, r *http.Request) { + if !allowMethod(w, r.Method, "POST") { + return + } b, err := ioutil.ReadAll(r.Body) if err != nil { log.Println("etcdhttp: error reading raft message:", err) @@ -317,8 +323,16 @@ func waitForEvent(ctx context.Context, w http.ResponseWriter, wa store.Watcher) } } -// allow writes response for the case that Method Not Allowed -func allow(w http.ResponseWriter, m ...string) { - w.Header().Set("Allow", strings.Join(m, ",")) +// allowMethod verifies that the given method is one of the allowed methods, +// and if not, it writes an error to w. A boolean is returned indicating +// whether or not the method is allowed. +func allowMethod(w http.ResponseWriter, m string, ms ...string) bool { + for _, meth := range ms { + if m == meth { + return true + } + } + w.Header().Set("Allow", strings.Join(ms, ",")) http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) + return false } diff --git a/etcdserver/etcdhttp/http_test.go b/etcdserver/etcdhttp/http_test.go index 6eeceb8e2..19cafccc8 100644 --- a/etcdserver/etcdhttp/http_test.go +++ b/etcdserver/etcdhttp/http_test.go @@ -680,3 +680,69 @@ func TestPeersEndpoints(t *testing.T) { } } } + +func TestAllowMethod(t *testing.T) { + tests := []struct { + m string + ms []string + w bool + wh string + }{ + // Accepted methods + { + m: "GET", + ms: []string{"GET", "POST", "PUT"}, + w: true, + }, + { + m: "POST", + ms: []string{"POST"}, + w: true, + }, + // Made-up methods no good + { + m: "FAKE", + ms: []string{"GET", "POST", "PUT"}, + w: false, + wh: "GET,POST,PUT", + }, + // Empty methods no good + { + m: "", + ms: []string{"GET", "POST"}, + w: false, + wh: "GET,POST", + }, + // Empty accepted methods no good + { + m: "GET", + ms: []string{""}, + w: false, + wh: "", + }, + // No methods accepted + { + m: "GET", + ms: []string{}, + w: false, + wh: "", + }, + } + + for i, tt := range tests { + rw := httptest.NewRecorder() + g := allowMethod(rw, tt.m, tt.ms...) + if g != tt.w { + t.Errorf("#%d: got allowMethod()=%t, want %t", i, g, tt.w) + } + if !tt.w { + if rw.Code != http.StatusMethodNotAllowed { + t.Errorf("#%d: code=%d, want %d", i, rw.Code, http.StatusMethodNotAllowed) + } + gh := rw.Header().Get("Allow") + if gh != tt.wh { + t.Errorf("#%d: Allow header=%q, want %q", i, gh, tt.wh) + } + } + } +}