diff --git a/http/query_params.go b/http/query_params.go new file mode 100644 index 000000000..5746d2430 --- /dev/null +++ b/http/query_params.go @@ -0,0 +1,36 @@ +package http + +import ( + "net/http" + "strings" +) + +func NewLowerQueryParamsHandler(hdlr http.Handler) *LowerQueryParamsHandler { + return &LowerQueryParamsHandler{hdlr} +} + +type LowerQueryParamsHandler struct { + Handler http.Handler +} + +func (h *LowerQueryParamsHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + err := req.ParseForm() + if err == nil { + lowerBoolQueryParams(req) + } + h.Handler.ServeHTTP(w, req) +} + +func lowerBoolQueryParams(req *http.Request) { + form := req.Form + for key, vals := range form { + for i, val := range vals { + lowered := strings.ToLower(val) + if lowered == "true" || lowered == "false" { + req.Form[key][i] = lowered + } else { + req.Form[key][i] = val + } + } + } +} diff --git a/http/query_params_test.go b/http/query_params_test.go new file mode 100644 index 000000000..1920ee134 --- /dev/null +++ b/http/query_params_test.go @@ -0,0 +1,46 @@ +package http + +import ( + "net/http" + "testing" +) + +type NilResponseWriter struct{} + +func (w NilResponseWriter) Header() http.Header { + return http.Header{} +} + +func (w NilResponseWriter) Write(data []byte) (int, error) { + return 0, nil +} + +func (w NilResponseWriter) WriteHeader(code int) { + return +} + +type FunctionHandler struct { + f func(*http.Request) +} + +func (h FunctionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + h.f(r) +} + +func TestQueryParamsLowered(t *testing.T) { + assertFunc := func(req *http.Request) { + if len(req.Form["One"]) != 1 || req.Form["One"][0] != "true" { + t.Errorf("Unexpected value for One: %s", req.Form["One"]) + } else if len(req.Form["TWO"]) != 1 || req.Form["TWO"][0] != "false" { + t.Errorf("Unexpected value for TWO") + } else if len(req.Form["three"]) != 2 || req.Form["three"][0] != "true" || req.Form["three"][1] != "false" { + t.Errorf("Unexpected value for three") + } + } + assertHdlr := FunctionHandler{assertFunc} + hdlr := NewLowerQueryParamsHandler(assertHdlr) + respWriter := NilResponseWriter{} + + req, _ := http.NewRequest("GET", "http://example.com?One=TRUE&TWO=False&three=true&three=FALSE", nil) + hdlr.ServeHTTP(respWriter, req) +} diff --git a/server/server.go b/server/server.go index eb6539f54..a660857ed 100644 --- a/server/server.go +++ b/server/server.go @@ -15,6 +15,7 @@ import ( "github.com/coreos/etcd/log" "github.com/coreos/etcd/metrics" "github.com/coreos/etcd/mod" + ehttp "github.com/coreos/etcd/http" uhttp "github.com/coreos/etcd/pkg/http" "github.com/coreos/etcd/server/v1" "github.com/coreos/etcd/server/v2" @@ -107,17 +108,20 @@ func (s *Server) installV1(r *mux.Router) { } func (s *Server) installV2(r *mux.Router) { - s.handleFuncV2(r, "/v2/keys/{key:.*}", v2.GetHandler).Methods("GET") - s.handleFuncV2(r, "/v2/keys/{key:.*}", v2.PostHandler).Methods("POST") - s.handleFuncV2(r, "/v2/keys/{key:.*}", v2.PutHandler).Methods("PUT") - s.handleFuncV2(r, "/v2/keys/{key:.*}", v2.DeleteHandler).Methods("DELETE") - s.handleFunc(r, "/v2/leader", s.GetLeaderHandler).Methods("GET") - s.handleFunc(r, "/v2/machines", s.GetPeersHandler).Methods("GET") - s.handleFunc(r, "/v2/peers", s.GetPeersHandler).Methods("GET") - s.handleFunc(r, "/v2/stats/self", s.GetStatsHandler).Methods("GET") - s.handleFunc(r, "/v2/stats/leader", s.GetLeaderStatsHandler).Methods("GET") - s.handleFunc(r, "/v2/stats/store", s.GetStoreStatsHandler).Methods("GET") - s.handleFunc(r, "/v2/speedTest", s.SpeedTestHandler).Methods("GET") + r2 := mux.NewRouter() + r.PathPrefix("/v2").Handler(ehttp.NewLowerQueryParamsHandler(r2)) + + s.handleFuncV2(r2, "/v2/keys/{key:.*}", v2.GetHandler).Methods("GET") + s.handleFuncV2(r2, "/v2/keys/{key:.*}", v2.PostHandler).Methods("POST") + s.handleFuncV2(r2, "/v2/keys/{key:.*}", v2.PutHandler).Methods("PUT") + s.handleFuncV2(r2, "/v2/keys/{key:.*}", v2.DeleteHandler).Methods("DELETE") + s.handleFunc(r2, "/v2/leader", s.GetLeaderHandler).Methods("GET") + s.handleFunc(r2, "/v2/machines", s.GetPeersHandler).Methods("GET") + s.handleFunc(r2, "/v2/peers", s.GetPeersHandler).Methods("GET") + s.handleFunc(r2, "/v2/stats/self", s.GetStatsHandler).Methods("GET") + s.handleFunc(r2, "/v2/stats/leader", s.GetLeaderStatsHandler).Methods("GET") + s.handleFunc(r2, "/v2/stats/store", s.GetStoreStatsHandler).Methods("GET") + s.handleFunc(r2, "/v2/speedTest", s.SpeedTestHandler).Methods("GET") } func (s *Server) installMod(r *mux.Router) { diff --git a/test.sh b/test.sh index 0dbc2cd38..c6551cffa 100755 --- a/test.sh +++ b/test.sh @@ -2,6 +2,9 @@ . ./build +go test -i ./http +go test -v ./http + go test -i ./store go test -v ./store