mirror of
https://github.com/etcd-io/etcd.git
synced 2024-09-27 06:25:44 +00:00
Merge pull request #1069 from jonboulle/methods
etcdhttp: check method for every endpoint, add tests
This commit is contained in:
commit
35ae488120
@ -68,6 +68,10 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h Handler) serveKeys(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
func (h Handler) serveKeys(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||||
|
if !allowMethod(w, r.Method, "GET", "PUT", "POST", "DELETE") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
rr, err := parseRequest(r, genID())
|
rr, err := parseRequest(r, genID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
writeError(w, err)
|
writeError(w, err)
|
||||||
@ -103,8 +107,7 @@ func (h Handler) serveKeys(ctx context.Context, w http.ResponseWriter, r *http.R
|
|||||||
// serveMachines responds address list in the format '0.0.0.0, 1.1.1.1'.
|
// serveMachines responds address list in the format '0.0.0.0, 1.1.1.1'.
|
||||||
// TODO: rethink the format of machine list because it is not json format.
|
// TODO: rethink the format of machine list because it is not json format.
|
||||||
func (h Handler) serveMachines(w http.ResponseWriter, r *http.Request) {
|
func (h Handler) serveMachines(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method != "GET" && r.Method != "HEAD" {
|
if !allowMethod(w, r.Method, "GET", "HEAD") {
|
||||||
allow(w, "GET", "HEAD")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
endpoints := h.Peers.Endpoints()
|
endpoints := h.Peers.Endpoints()
|
||||||
@ -112,6 +115,9 @@ func (h Handler) serveMachines(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h Handler) serveRaft(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
func (h Handler) serveRaft(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||||
|
if !allowMethod(w, r.Method, "POST") {
|
||||||
|
return
|
||||||
|
}
|
||||||
b, err := ioutil.ReadAll(r.Body)
|
b, err := ioutil.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("etcdhttp: error reading raft message:", err)
|
log.Println("etcdhttp: error reading raft message:", err)
|
||||||
@ -317,8 +323,16 @@ func waitForEvent(ctx context.Context, w http.ResponseWriter, wa store.Watcher)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// allow writes response for the case that Method Not Allowed
|
// allowMethod verifies that the given method is one of the allowed methods,
|
||||||
func allow(w http.ResponseWriter, m ...string) {
|
// and if not, it writes an error to w. A boolean is returned indicating
|
||||||
w.Header().Set("Allow", strings.Join(m, ","))
|
// whether or not the method is allowed.
|
||||||
|
func allowMethod(w http.ResponseWriter, m string, ms ...string) bool {
|
||||||
|
for _, meth := range ms {
|
||||||
|
if m == meth {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.Header().Set("Allow", strings.Join(ms, ","))
|
||||||
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
@ -680,3 +680,69 @@ func TestPeersEndpoints(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAllowMethod(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
m string
|
||||||
|
ms []string
|
||||||
|
w bool
|
||||||
|
wh string
|
||||||
|
}{
|
||||||
|
// Accepted methods
|
||||||
|
{
|
||||||
|
m: "GET",
|
||||||
|
ms: []string{"GET", "POST", "PUT"},
|
||||||
|
w: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
m: "POST",
|
||||||
|
ms: []string{"POST"},
|
||||||
|
w: true,
|
||||||
|
},
|
||||||
|
// Made-up methods no good
|
||||||
|
{
|
||||||
|
m: "FAKE",
|
||||||
|
ms: []string{"GET", "POST", "PUT"},
|
||||||
|
w: false,
|
||||||
|
wh: "GET,POST,PUT",
|
||||||
|
},
|
||||||
|
// Empty methods no good
|
||||||
|
{
|
||||||
|
m: "",
|
||||||
|
ms: []string{"GET", "POST"},
|
||||||
|
w: false,
|
||||||
|
wh: "GET,POST",
|
||||||
|
},
|
||||||
|
// Empty accepted methods no good
|
||||||
|
{
|
||||||
|
m: "GET",
|
||||||
|
ms: []string{""},
|
||||||
|
w: false,
|
||||||
|
wh: "",
|
||||||
|
},
|
||||||
|
// No methods accepted
|
||||||
|
{
|
||||||
|
m: "GET",
|
||||||
|
ms: []string{},
|
||||||
|
w: false,
|
||||||
|
wh: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
rw := httptest.NewRecorder()
|
||||||
|
g := allowMethod(rw, tt.m, tt.ms...)
|
||||||
|
if g != tt.w {
|
||||||
|
t.Errorf("#%d: got allowMethod()=%t, want %t", i, g, tt.w)
|
||||||
|
}
|
||||||
|
if !tt.w {
|
||||||
|
if rw.Code != http.StatusMethodNotAllowed {
|
||||||
|
t.Errorf("#%d: code=%d, want %d", i, rw.Code, http.StatusMethodNotAllowed)
|
||||||
|
}
|
||||||
|
gh := rw.Header().Get("Allow")
|
||||||
|
if gh != tt.wh {
|
||||||
|
t.Errorf("#%d: Allow header=%q, want %q", i, gh, tt.wh)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user