mirror of
https://github.com/etcd-io/etcd.git
synced 2024-09-27 06:25:44 +00:00
etcdhttp: perform validation of query parameters
Add basic input validation of all query parameters supported by serveKeys. Also restructures etcdhttp a bit to better facilitate testing. Test coverage is slightly improved.
This commit is contained in:
@@ -20,7 +20,7 @@ import (
|
||||
"math/rand"
|
||||
|
||||
"github.com/coreos/etcd/elog"
|
||||
etcderrors "github.com/coreos/etcd/error"
|
||||
etcdErr "github.com/coreos/etcd/error"
|
||||
"github.com/coreos/etcd/etcdserver"
|
||||
"github.com/coreos/etcd/etcdserver/etcdserverpb"
|
||||
"github.com/coreos/etcd/raft/raftpb"
|
||||
@@ -33,6 +33,8 @@ const (
|
||||
machinesPrefix = "/v2/machines"
|
||||
)
|
||||
|
||||
var emptyReq = etcdserverpb.Request{}
|
||||
|
||||
type Peers map[int64][]string
|
||||
|
||||
func (ps Peers) Pick(id int64) string {
|
||||
@@ -178,28 +180,32 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
func (h Handler) serveKeys(ctx context.Context, w http.ResponseWriter, r *http.Request) {
|
||||
rr, err := parseRequest(r, genId())
|
||||
if err != nil {
|
||||
log.Println(err) // reading of body failed
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := h.Server.Do(ctx, rr)
|
||||
switch e := err.(type) {
|
||||
case nil:
|
||||
case *etcderrors.Error:
|
||||
// TODO: gross. this should be handled in encodeResponse
|
||||
log.Println(err)
|
||||
e.Write(w)
|
||||
return
|
||||
default:
|
||||
log.Println(err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
if err != nil {
|
||||
writeInternalError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := encodeResponse(ctx, w, resp); err != nil {
|
||||
http.Error(w, "Timeout while waiting for response", http.StatusGatewayTimeout)
|
||||
var ev *store.Event
|
||||
switch {
|
||||
case resp.Event != nil:
|
||||
ev = resp.Event
|
||||
case resp.Watcher != nil:
|
||||
ev, err = waitForEvent(ctx, w, resp.Watcher)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusGatewayTimeout)
|
||||
return
|
||||
}
|
||||
default:
|
||||
writeInternalError(w, errors.New("received response with no Event/Watcher!"))
|
||||
return
|
||||
}
|
||||
|
||||
writeEvent(w, ev)
|
||||
}
|
||||
|
||||
// serveMachines responds address list in the format '0.0.0.0, 1.1.1.1'.
|
||||
@@ -249,38 +255,60 @@ func genId() int64 {
|
||||
}
|
||||
|
||||
func parseRequest(r *http.Request, id int64) (etcdserverpb.Request, error) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
return etcdserverpb.Request{}, err
|
||||
}
|
||||
if !strings.HasPrefix(r.URL.Path, keysPrefix) {
|
||||
return etcdserverpb.Request{}, errors.New("unexpected key prefix!")
|
||||
var err error
|
||||
|
||||
if err = r.ParseForm(); err != nil {
|
||||
return emptyReq, err
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(r.URL.Path, keysPrefix) {
|
||||
return emptyReq, errors.New("unexpected key prefix!")
|
||||
}
|
||||
path := r.URL.Path[len(keysPrefix):]
|
||||
|
||||
q := r.URL.Query()
|
||||
// TODO(jonboulle): perform strict validation of all parameters
|
||||
// https://github.com/coreos/etcd/issues/1011
|
||||
|
||||
var pIdx, wIdx, ttl uint64
|
||||
if pIdx, err = parseUint64(q.Get("prevIndex")); err != nil {
|
||||
return emptyReq, errors.New("invalid value for prevIndex")
|
||||
}
|
||||
if wIdx, err = parseUint64(q.Get("waitIndex")); err != nil {
|
||||
return emptyReq, errors.New("invalid value for waitIndex")
|
||||
}
|
||||
if ttl, err = parseUint64(q.Get("ttl")); err != nil {
|
||||
return emptyReq, errors.New("invalid value for ttl")
|
||||
}
|
||||
|
||||
var rec, sort, wait bool
|
||||
if rec, err = parseBool(q.Get("recursive")); err != nil {
|
||||
return emptyReq, errors.New("invalid value for recursive")
|
||||
}
|
||||
if sort, err = parseBool(q.Get("sorted")); err != nil {
|
||||
return emptyReq, errors.New("invalid value for sorted")
|
||||
}
|
||||
if wait, err = parseBool(q.Get("wait")); err != nil {
|
||||
return emptyReq, errors.New("invalid value for wait")
|
||||
}
|
||||
|
||||
rr := etcdserverpb.Request{
|
||||
Id: id,
|
||||
Method: r.Method,
|
||||
Val: r.FormValue("value"),
|
||||
Path: r.URL.Path[len(keysPrefix):],
|
||||
Path: path,
|
||||
PrevValue: q.Get("prevValue"),
|
||||
PrevIndex: parseUint64(q.Get("prevIndex")),
|
||||
Recursive: parseBool(q.Get("recursive")),
|
||||
Since: parseUint64(q.Get("waitIndex")),
|
||||
Sorted: parseBool(q.Get("sorted")),
|
||||
Wait: parseBool(q.Get("wait")),
|
||||
PrevIndex: pIdx,
|
||||
Recursive: rec,
|
||||
Since: wIdx,
|
||||
Sorted: sort,
|
||||
Wait: wait,
|
||||
}
|
||||
|
||||
// PrevExists is nullable, so we leave it null if prevExist wasn't
|
||||
// specified.
|
||||
_, ok := q["prevExists"]
|
||||
if ok {
|
||||
bv := parseBool(q.Get("prevExists"))
|
||||
// prevExists is nullable, so leave it null if not specified
|
||||
if _, ok := q["prevExists"]; ok {
|
||||
bv, _ := parseBool(q.Get("prevExists"))
|
||||
rr.PrevExists = &bv
|
||||
}
|
||||
|
||||
ttl := parseUint64(q.Get("ttl"))
|
||||
if ttl > 0 {
|
||||
expr := time.Duration(ttl) * time.Second
|
||||
// TODO(jonboulle): use fake clock instead of time module
|
||||
@@ -291,32 +319,40 @@ func parseRequest(r *http.Request, id int64) (etcdserverpb.Request, error) {
|
||||
return rr, nil
|
||||
}
|
||||
|
||||
func parseBool(s string) bool {
|
||||
v, _ := strconv.ParseBool(s)
|
||||
return v
|
||||
}
|
||||
|
||||
func parseUint64(s string) uint64 {
|
||||
v, _ := strconv.ParseUint(s, 10, 64)
|
||||
return v
|
||||
}
|
||||
|
||||
// encodeResponse serializes the given etcdserver Response and writes the
|
||||
// resulting JSON to the given ResponseWriter, utilizing the provided context
|
||||
func encodeResponse(ctx context.Context, w http.ResponseWriter, resp etcdserver.Response) (err error) {
|
||||
var ev *store.Event
|
||||
switch {
|
||||
case resp.Event != nil:
|
||||
ev = resp.Event
|
||||
case resp.Watcher != nil:
|
||||
ev, err = waitForEvent(ctx, w, resp.Watcher)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
panic("should not be reachable")
|
||||
func parseBool(s string) (bool, error) {
|
||||
if s == "" {
|
||||
return false, nil
|
||||
}
|
||||
return strconv.ParseBool(s)
|
||||
}
|
||||
|
||||
func parseUint64(s string) (uint64, error) {
|
||||
if s == "" {
|
||||
return 0, nil
|
||||
}
|
||||
return strconv.ParseUint(s, 10, 64)
|
||||
}
|
||||
|
||||
// writeInternalError logs and writes the given Error to the ResponseWriter
|
||||
// If Error is an etcdErr, it is rendered to the ResponseWriter
|
||||
func writeInternalError(w http.ResponseWriter, err error) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
log.Println(err)
|
||||
if e, ok := err.(*etcdErr.Error); ok {
|
||||
e.Write(w)
|
||||
} else {
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// writeEvent serializes the given Event and writes the resulting JSON to the
|
||||
// given ResponseWriter
|
||||
func writeEvent(w http.ResponseWriter, ev *store.Event) {
|
||||
if ev == nil {
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Add("X-Etcd-Index", fmt.Sprint(ev.Index()))
|
||||
|
||||
@@ -327,10 +363,9 @@ func encodeResponse(ctx context.Context, w http.ResponseWriter, resp etcdserver.
|
||||
if err := json.NewEncoder(w).Encode(ev); err != nil {
|
||||
panic(err) // should never be reached
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// waitForEvent waits for a given watcher to return its associated
|
||||
// waitForEvent waits for a given Watcher to return its associated
|
||||
// event. It returns a non-nil error if the given Context times out
|
||||
// or the given ResponseWriter triggers a CloseNotify.
|
||||
func waitForEvent(ctx context.Context, w http.ResponseWriter, wa store.Watcher) (*store.Event, error) {
|
||||
@@ -340,7 +375,6 @@ func waitForEvent(ctx context.Context, w http.ResponseWriter, wa store.Watcher)
|
||||
if x, ok := w.(http.CloseNotifier); ok {
|
||||
nch = x.CloseNotify()
|
||||
}
|
||||
|
||||
select {
|
||||
case ev := <-wa.EventChan():
|
||||
return ev, nil
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package etcdhttp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
@@ -9,7 +10,7 @@ import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/etcd/etcdserver"
|
||||
etcdErr "github.com/coreos/etcd/error"
|
||||
"github.com/coreos/etcd/etcdserver/etcdserverpb"
|
||||
"github.com/coreos/etcd/store"
|
||||
"github.com/coreos/etcd/third_party/code.google.com/p/go.net/context"
|
||||
@@ -25,6 +26,12 @@ func mustNewURL(t *testing.T, s string) *url.URL {
|
||||
return u
|
||||
}
|
||||
|
||||
func mustNewRequest(t *testing.T, p string) *http.Request {
|
||||
return &http.Request{
|
||||
URL: mustNewURL(t, path.Join(keysPrefix, p)),
|
||||
}
|
||||
}
|
||||
|
||||
func TestBadParseRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
in *http.Request
|
||||
@@ -42,6 +49,47 @@ func TestBadParseRequest(t *testing.T) {
|
||||
URL: mustNewURL(t, "/badprefix/"),
|
||||
},
|
||||
},
|
||||
// bad values for prevIndex, waitIndex, ttl
|
||||
{
|
||||
mustNewRequest(t, "?prevIndex=foo"),
|
||||
},
|
||||
{
|
||||
mustNewRequest(t, "?prevIndex=1.5"),
|
||||
},
|
||||
{
|
||||
mustNewRequest(t, "?prevIndex=-1"),
|
||||
},
|
||||
{
|
||||
mustNewRequest(t, "?waitIndex=garbage"),
|
||||
},
|
||||
{
|
||||
mustNewRequest(t, "?waitIndex=??"),
|
||||
},
|
||||
{
|
||||
mustNewRequest(t, "?ttl=-1"),
|
||||
},
|
||||
// bad values for recursive, sorted, wait
|
||||
{
|
||||
mustNewRequest(t, "?recursive=hahaha"),
|
||||
},
|
||||
{
|
||||
mustNewRequest(t, "?recursive=1234"),
|
||||
},
|
||||
{
|
||||
mustNewRequest(t, "?recursive=?"),
|
||||
},
|
||||
{
|
||||
mustNewRequest(t, "?sorted=hahaha"),
|
||||
},
|
||||
{
|
||||
mustNewRequest(t, "?sorted=!!"),
|
||||
},
|
||||
{
|
||||
mustNewRequest(t, "?wait=notreally"),
|
||||
},
|
||||
{
|
||||
mustNewRequest(t, "?wait=what!"),
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
got, err := parseRequest(tt.in, 1234)
|
||||
@@ -61,9 +109,7 @@ func TestGoodParseRequest(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
// good prefix, all other values default
|
||||
&http.Request{
|
||||
URL: mustNewURL(t, path.Join(keysPrefix, "foo")),
|
||||
},
|
||||
mustNewRequest(t, "foo"),
|
||||
etcdserverpb.Request{
|
||||
Id: 1234,
|
||||
Path: "/foo",
|
||||
@@ -71,9 +117,7 @@ func TestGoodParseRequest(t *testing.T) {
|
||||
},
|
||||
{
|
||||
// value specified
|
||||
&http.Request{
|
||||
URL: mustNewURL(t, path.Join(keysPrefix, "foo?value=some_value")),
|
||||
},
|
||||
mustNewRequest(t, "foo?value=some_value"),
|
||||
etcdserverpb.Request{
|
||||
Id: 1234,
|
||||
Val: "some_value",
|
||||
@@ -82,9 +126,7 @@ func TestGoodParseRequest(t *testing.T) {
|
||||
},
|
||||
{
|
||||
// prevIndex specified
|
||||
&http.Request{
|
||||
URL: mustNewURL(t, path.Join(keysPrefix, "foo?prevIndex=98765")),
|
||||
},
|
||||
mustNewRequest(t, "foo?prevIndex=98765"),
|
||||
etcdserverpb.Request{
|
||||
Id: 1234,
|
||||
PrevIndex: 98765,
|
||||
@@ -93,9 +135,7 @@ func TestGoodParseRequest(t *testing.T) {
|
||||
},
|
||||
{
|
||||
// recursive specified
|
||||
&http.Request{
|
||||
URL: mustNewURL(t, path.Join(keysPrefix, "foo?recursive=true")),
|
||||
},
|
||||
mustNewRequest(t, "foo?recursive=true"),
|
||||
etcdserverpb.Request{
|
||||
Id: 1234,
|
||||
Recursive: true,
|
||||
@@ -104,9 +144,7 @@ func TestGoodParseRequest(t *testing.T) {
|
||||
},
|
||||
{
|
||||
// sorted specified
|
||||
&http.Request{
|
||||
URL: mustNewURL(t, path.Join(keysPrefix, "foo?sorted=true")),
|
||||
},
|
||||
mustNewRequest(t, "foo?sorted=true"),
|
||||
etcdserverpb.Request{
|
||||
Id: 1234,
|
||||
Sorted: true,
|
||||
@@ -115,9 +153,7 @@ func TestGoodParseRequest(t *testing.T) {
|
||||
},
|
||||
{
|
||||
// wait specified
|
||||
&http.Request{
|
||||
URL: mustNewURL(t, path.Join(keysPrefix, "foo?wait=true")),
|
||||
},
|
||||
mustNewRequest(t, "foo?wait=true"),
|
||||
etcdserverpb.Request{
|
||||
Id: 1234,
|
||||
Wait: true,
|
||||
@@ -126,9 +162,7 @@ func TestGoodParseRequest(t *testing.T) {
|
||||
},
|
||||
{
|
||||
// prevExists should be non-null if specified
|
||||
&http.Request{
|
||||
URL: mustNewURL(t, path.Join(keysPrefix, "foo?prevExists=true")),
|
||||
},
|
||||
mustNewRequest(t, "foo?prevExists=true"),
|
||||
etcdserverpb.Request{
|
||||
Id: 1234,
|
||||
PrevExists: boolp(true),
|
||||
@@ -137,9 +171,7 @@ func TestGoodParseRequest(t *testing.T) {
|
||||
},
|
||||
{
|
||||
// prevExists should be non-null if specified
|
||||
&http.Request{
|
||||
URL: mustNewURL(t, path.Join(keysPrefix, "foo?prevExists=false")),
|
||||
},
|
||||
mustNewRequest(t, "foo?prevExists=false"),
|
||||
etcdserverpb.Request{
|
||||
Id: 1234,
|
||||
PrevExists: boolp(false),
|
||||
@@ -177,22 +209,77 @@ func (w *eventingWatcher) EventChan() chan *store.Event {
|
||||
|
||||
func (w *eventingWatcher) Remove() {}
|
||||
|
||||
func TestEncodeResponse(t *testing.T) {
|
||||
func TestWriteInternalError(t *testing.T) {
|
||||
// nil error should not panic
|
||||
rw := httptest.NewRecorder()
|
||||
writeInternalError(rw, nil)
|
||||
h := rw.Header()
|
||||
if len(h) > 0 {
|
||||
t.Fatalf("unexpected non-empty headers: %#v", h)
|
||||
}
|
||||
b := rw.Body.String()
|
||||
if len(b) > 0 {
|
||||
t.Fatalf("unexpected non-empty body: %q", b)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
resp etcdserver.Response
|
||||
err error
|
||||
code int
|
||||
idx string
|
||||
}{
|
||||
{
|
||||
etcdErr.NewError(etcdErr.EcodeKeyNotFound, "/foo/bar", 123),
|
||||
http.StatusNotFound,
|
||||
"123",
|
||||
},
|
||||
{
|
||||
etcdErr.NewError(etcdErr.EcodeTestFailed, "/foo/bar", 456),
|
||||
http.StatusPreconditionFailed,
|
||||
"456",
|
||||
},
|
||||
{
|
||||
err: errors.New("something went wrong"),
|
||||
code: http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
rw := httptest.NewRecorder()
|
||||
writeInternalError(rw, tt.err)
|
||||
if code := rw.Code; code != tt.code {
|
||||
t.Errorf("#%d: got %d, want %d", i, code, tt.code)
|
||||
}
|
||||
if idx := rw.Header().Get("X-Etcd-Index"); idx != tt.idx {
|
||||
t.Errorf("#%d: got %q, want %q", i, idx, tt.idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteEvent(t *testing.T) {
|
||||
// nil event should not panic
|
||||
rw := httptest.NewRecorder()
|
||||
writeEvent(rw, nil)
|
||||
h := rw.Header()
|
||||
if len(h) > 0 {
|
||||
t.Fatalf("unexpected non-empty headers: %#v", h)
|
||||
}
|
||||
b := rw.Body.String()
|
||||
if len(b) > 0 {
|
||||
t.Fatalf("unexpected non-empty body: %q", b)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
ev *store.Event
|
||||
idx string
|
||||
code int
|
||||
err error
|
||||
}{
|
||||
// standard case, standard 200 response
|
||||
{
|
||||
etcdserver.Response{
|
||||
Event: &store.Event{
|
||||
Action: store.Get,
|
||||
Node: &store.NodeExtern{},
|
||||
PrevNode: &store.NodeExtern{},
|
||||
},
|
||||
Watcher: nil,
|
||||
&store.Event{
|
||||
Action: store.Get,
|
||||
Node: &store.NodeExtern{},
|
||||
PrevNode: &store.NodeExtern{},
|
||||
},
|
||||
"0",
|
||||
http.StatusOK,
|
||||
@@ -200,21 +287,10 @@ func TestEncodeResponse(t *testing.T) {
|
||||
},
|
||||
// check new nodes return StatusCreated
|
||||
{
|
||||
etcdserver.Response{
|
||||
Event: &store.Event{
|
||||
Action: store.Create,
|
||||
Node: &store.NodeExtern{},
|
||||
PrevNode: &store.NodeExtern{},
|
||||
},
|
||||
Watcher: nil,
|
||||
},
|
||||
"0",
|
||||
http.StatusCreated,
|
||||
nil,
|
||||
},
|
||||
{
|
||||
etcdserver.Response{
|
||||
Watcher: &eventingWatcher{store.Create},
|
||||
&store.Event{
|
||||
Action: store.Create,
|
||||
Node: &store.NodeExtern{},
|
||||
PrevNode: &store.NodeExtern{},
|
||||
},
|
||||
"0",
|
||||
http.StatusCreated,
|
||||
@@ -224,20 +300,13 @@ func TestEncodeResponse(t *testing.T) {
|
||||
|
||||
for i, tt := range tests {
|
||||
rw := httptest.NewRecorder()
|
||||
err := encodeResponse(context.Background(), rw, tt.resp)
|
||||
if err != tt.err {
|
||||
t.Errorf("case %d: unexpected err: got %v, want %v", i, err, tt.err)
|
||||
continue
|
||||
}
|
||||
|
||||
writeEvent(rw, tt.ev)
|
||||
if gct := rw.Header().Get("Content-Type"); gct != "application/json" {
|
||||
t.Errorf("case %d: bad Content-Type: got %q, want application/json", i, gct)
|
||||
}
|
||||
|
||||
if gei := rw.Header().Get("X-Etcd-Index"); gei != tt.idx {
|
||||
t.Errorf("case %d: bad X-Etcd-Index header: got %s, want %s", i, gei, tt.idx)
|
||||
}
|
||||
|
||||
if rw.Code != tt.code {
|
||||
t.Errorf("case %d: bad response code: got %d, want %v", i, rw.Code, tt.code)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user