etcdhttp: perform validation of query parameters

Add basic input validation of all query parameters supported by
serveKeys. Also restructures etcdhttp a bit to better facilitate
testing.

Test coverage is slightly improved.
This commit is contained in:
Jonathan Boulle
2014-09-10 12:00:20 -07:00
parent 3d272c2686
commit e736a11ac4
2 changed files with 220 additions and 117 deletions

View File

@@ -20,7 +20,7 @@ import (
"math/rand"
"github.com/coreos/etcd/elog"
etcderrors "github.com/coreos/etcd/error"
etcdErr "github.com/coreos/etcd/error"
"github.com/coreos/etcd/etcdserver"
"github.com/coreos/etcd/etcdserver/etcdserverpb"
"github.com/coreos/etcd/raft/raftpb"
@@ -33,6 +33,8 @@ const (
machinesPrefix = "/v2/machines"
)
var emptyReq = etcdserverpb.Request{}
type Peers map[int64][]string
func (ps Peers) Pick(id int64) string {
@@ -178,28 +180,32 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (h Handler) serveKeys(ctx context.Context, w http.ResponseWriter, r *http.Request) {
rr, err := parseRequest(r, genId())
if err != nil {
log.Println(err) // reading of body failed
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
resp, err := h.Server.Do(ctx, rr)
switch e := err.(type) {
case nil:
case *etcderrors.Error:
// TODO: gross. this should be handled in encodeResponse
log.Println(err)
e.Write(w)
return
default:
log.Println(err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
if err != nil {
writeInternalError(w, err)
return
}
if err := encodeResponse(ctx, w, resp); err != nil {
http.Error(w, "Timeout while waiting for response", http.StatusGatewayTimeout)
var ev *store.Event
switch {
case resp.Event != nil:
ev = resp.Event
case resp.Watcher != nil:
ev, err = waitForEvent(ctx, w, resp.Watcher)
if err != nil {
http.Error(w, err.Error(), http.StatusGatewayTimeout)
return
}
default:
writeInternalError(w, errors.New("received response with no Event/Watcher!"))
return
}
writeEvent(w, ev)
}
// serveMachines responds address list in the format '0.0.0.0, 1.1.1.1'.
@@ -249,38 +255,60 @@ func genId() int64 {
}
func parseRequest(r *http.Request, id int64) (etcdserverpb.Request, error) {
if err := r.ParseForm(); err != nil {
return etcdserverpb.Request{}, err
}
if !strings.HasPrefix(r.URL.Path, keysPrefix) {
return etcdserverpb.Request{}, errors.New("unexpected key prefix!")
var err error
if err = r.ParseForm(); err != nil {
return emptyReq, err
}
if !strings.HasPrefix(r.URL.Path, keysPrefix) {
return emptyReq, errors.New("unexpected key prefix!")
}
path := r.URL.Path[len(keysPrefix):]
q := r.URL.Query()
// TODO(jonboulle): perform strict validation of all parameters
// https://github.com/coreos/etcd/issues/1011
var pIdx, wIdx, ttl uint64
if pIdx, err = parseUint64(q.Get("prevIndex")); err != nil {
return emptyReq, errors.New("invalid value for prevIndex")
}
if wIdx, err = parseUint64(q.Get("waitIndex")); err != nil {
return emptyReq, errors.New("invalid value for waitIndex")
}
if ttl, err = parseUint64(q.Get("ttl")); err != nil {
return emptyReq, errors.New("invalid value for ttl")
}
var rec, sort, wait bool
if rec, err = parseBool(q.Get("recursive")); err != nil {
return emptyReq, errors.New("invalid value for recursive")
}
if sort, err = parseBool(q.Get("sorted")); err != nil {
return emptyReq, errors.New("invalid value for sorted")
}
if wait, err = parseBool(q.Get("wait")); err != nil {
return emptyReq, errors.New("invalid value for wait")
}
rr := etcdserverpb.Request{
Id: id,
Method: r.Method,
Val: r.FormValue("value"),
Path: r.URL.Path[len(keysPrefix):],
Path: path,
PrevValue: q.Get("prevValue"),
PrevIndex: parseUint64(q.Get("prevIndex")),
Recursive: parseBool(q.Get("recursive")),
Since: parseUint64(q.Get("waitIndex")),
Sorted: parseBool(q.Get("sorted")),
Wait: parseBool(q.Get("wait")),
PrevIndex: pIdx,
Recursive: rec,
Since: wIdx,
Sorted: sort,
Wait: wait,
}
// PrevExists is nullable, so we leave it null if prevExist wasn't
// specified.
_, ok := q["prevExists"]
if ok {
bv := parseBool(q.Get("prevExists"))
// prevExists is nullable, so leave it null if not specified
if _, ok := q["prevExists"]; ok {
bv, _ := parseBool(q.Get("prevExists"))
rr.PrevExists = &bv
}
ttl := parseUint64(q.Get("ttl"))
if ttl > 0 {
expr := time.Duration(ttl) * time.Second
// TODO(jonboulle): use fake clock instead of time module
@@ -291,32 +319,40 @@ func parseRequest(r *http.Request, id int64) (etcdserverpb.Request, error) {
return rr, nil
}
func parseBool(s string) bool {
v, _ := strconv.ParseBool(s)
return v
}
func parseUint64(s string) uint64 {
v, _ := strconv.ParseUint(s, 10, 64)
return v
}
// encodeResponse serializes the given etcdserver Response and writes the
// resulting JSON to the given ResponseWriter, utilizing the provided context
func encodeResponse(ctx context.Context, w http.ResponseWriter, resp etcdserver.Response) (err error) {
var ev *store.Event
switch {
case resp.Event != nil:
ev = resp.Event
case resp.Watcher != nil:
ev, err = waitForEvent(ctx, w, resp.Watcher)
if err != nil {
return err
}
default:
panic("should not be reachable")
func parseBool(s string) (bool, error) {
if s == "" {
return false, nil
}
return strconv.ParseBool(s)
}
func parseUint64(s string) (uint64, error) {
if s == "" {
return 0, nil
}
return strconv.ParseUint(s, 10, 64)
}
// writeInternalError logs and writes the given Error to the ResponseWriter
// If Error is an etcdErr, it is rendered to the ResponseWriter
func writeInternalError(w http.ResponseWriter, err error) {
if err == nil {
return
}
log.Println(err)
if e, ok := err.(*etcdErr.Error); ok {
e.Write(w)
} else {
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
}
}
// writeEvent serializes the given Event and writes the resulting JSON to the
// given ResponseWriter
func writeEvent(w http.ResponseWriter, ev *store.Event) {
if ev == nil {
return
}
w.Header().Set("Content-Type", "application/json")
w.Header().Add("X-Etcd-Index", fmt.Sprint(ev.Index()))
@@ -327,10 +363,9 @@ func encodeResponse(ctx context.Context, w http.ResponseWriter, resp etcdserver.
if err := json.NewEncoder(w).Encode(ev); err != nil {
panic(err) // should never be reached
}
return nil
}
// waitForEvent waits for a given watcher to return its associated
// waitForEvent waits for a given Watcher to return its associated
// event. It returns a non-nil error if the given Context times out
// or the given ResponseWriter triggers a CloseNotify.
func waitForEvent(ctx context.Context, w http.ResponseWriter, wa store.Watcher) (*store.Event, error) {
@@ -340,7 +375,6 @@ func waitForEvent(ctx context.Context, w http.ResponseWriter, wa store.Watcher)
if x, ok := w.(http.CloseNotifier); ok {
nch = x.CloseNotify()
}
select {
case ev := <-wa.EventChan():
return ev, nil

View File

@@ -1,6 +1,7 @@
package etcdhttp
import (
"errors"
"net/http"
"net/http/httptest"
"net/url"
@@ -9,7 +10,7 @@ import (
"sync"
"testing"
"github.com/coreos/etcd/etcdserver"
etcdErr "github.com/coreos/etcd/error"
"github.com/coreos/etcd/etcdserver/etcdserverpb"
"github.com/coreos/etcd/store"
"github.com/coreos/etcd/third_party/code.google.com/p/go.net/context"
@@ -25,6 +26,12 @@ func mustNewURL(t *testing.T, s string) *url.URL {
return u
}
func mustNewRequest(t *testing.T, p string) *http.Request {
return &http.Request{
URL: mustNewURL(t, path.Join(keysPrefix, p)),
}
}
func TestBadParseRequest(t *testing.T) {
tests := []struct {
in *http.Request
@@ -42,6 +49,47 @@ func TestBadParseRequest(t *testing.T) {
URL: mustNewURL(t, "/badprefix/"),
},
},
// bad values for prevIndex, waitIndex, ttl
{
mustNewRequest(t, "?prevIndex=foo"),
},
{
mustNewRequest(t, "?prevIndex=1.5"),
},
{
mustNewRequest(t, "?prevIndex=-1"),
},
{
mustNewRequest(t, "?waitIndex=garbage"),
},
{
mustNewRequest(t, "?waitIndex=??"),
},
{
mustNewRequest(t, "?ttl=-1"),
},
// bad values for recursive, sorted, wait
{
mustNewRequest(t, "?recursive=hahaha"),
},
{
mustNewRequest(t, "?recursive=1234"),
},
{
mustNewRequest(t, "?recursive=?"),
},
{
mustNewRequest(t, "?sorted=hahaha"),
},
{
mustNewRequest(t, "?sorted=!!"),
},
{
mustNewRequest(t, "?wait=notreally"),
},
{
mustNewRequest(t, "?wait=what!"),
},
}
for i, tt := range tests {
got, err := parseRequest(tt.in, 1234)
@@ -61,9 +109,7 @@ func TestGoodParseRequest(t *testing.T) {
}{
{
// good prefix, all other values default
&http.Request{
URL: mustNewURL(t, path.Join(keysPrefix, "foo")),
},
mustNewRequest(t, "foo"),
etcdserverpb.Request{
Id: 1234,
Path: "/foo",
@@ -71,9 +117,7 @@ func TestGoodParseRequest(t *testing.T) {
},
{
// value specified
&http.Request{
URL: mustNewURL(t, path.Join(keysPrefix, "foo?value=some_value")),
},
mustNewRequest(t, "foo?value=some_value"),
etcdserverpb.Request{
Id: 1234,
Val: "some_value",
@@ -82,9 +126,7 @@ func TestGoodParseRequest(t *testing.T) {
},
{
// prevIndex specified
&http.Request{
URL: mustNewURL(t, path.Join(keysPrefix, "foo?prevIndex=98765")),
},
mustNewRequest(t, "foo?prevIndex=98765"),
etcdserverpb.Request{
Id: 1234,
PrevIndex: 98765,
@@ -93,9 +135,7 @@ func TestGoodParseRequest(t *testing.T) {
},
{
// recursive specified
&http.Request{
URL: mustNewURL(t, path.Join(keysPrefix, "foo?recursive=true")),
},
mustNewRequest(t, "foo?recursive=true"),
etcdserverpb.Request{
Id: 1234,
Recursive: true,
@@ -104,9 +144,7 @@ func TestGoodParseRequest(t *testing.T) {
},
{
// sorted specified
&http.Request{
URL: mustNewURL(t, path.Join(keysPrefix, "foo?sorted=true")),
},
mustNewRequest(t, "foo?sorted=true"),
etcdserverpb.Request{
Id: 1234,
Sorted: true,
@@ -115,9 +153,7 @@ func TestGoodParseRequest(t *testing.T) {
},
{
// wait specified
&http.Request{
URL: mustNewURL(t, path.Join(keysPrefix, "foo?wait=true")),
},
mustNewRequest(t, "foo?wait=true"),
etcdserverpb.Request{
Id: 1234,
Wait: true,
@@ -126,9 +162,7 @@ func TestGoodParseRequest(t *testing.T) {
},
{
// prevExists should be non-null if specified
&http.Request{
URL: mustNewURL(t, path.Join(keysPrefix, "foo?prevExists=true")),
},
mustNewRequest(t, "foo?prevExists=true"),
etcdserverpb.Request{
Id: 1234,
PrevExists: boolp(true),
@@ -137,9 +171,7 @@ func TestGoodParseRequest(t *testing.T) {
},
{
// prevExists should be non-null if specified
&http.Request{
URL: mustNewURL(t, path.Join(keysPrefix, "foo?prevExists=false")),
},
mustNewRequest(t, "foo?prevExists=false"),
etcdserverpb.Request{
Id: 1234,
PrevExists: boolp(false),
@@ -177,22 +209,77 @@ func (w *eventingWatcher) EventChan() chan *store.Event {
func (w *eventingWatcher) Remove() {}
func TestEncodeResponse(t *testing.T) {
func TestWriteInternalError(t *testing.T) {
// nil error should not panic
rw := httptest.NewRecorder()
writeInternalError(rw, nil)
h := rw.Header()
if len(h) > 0 {
t.Fatalf("unexpected non-empty headers: %#v", h)
}
b := rw.Body.String()
if len(b) > 0 {
t.Fatalf("unexpected non-empty body: %q", b)
}
tests := []struct {
resp etcdserver.Response
err error
code int
idx string
}{
{
etcdErr.NewError(etcdErr.EcodeKeyNotFound, "/foo/bar", 123),
http.StatusNotFound,
"123",
},
{
etcdErr.NewError(etcdErr.EcodeTestFailed, "/foo/bar", 456),
http.StatusPreconditionFailed,
"456",
},
{
err: errors.New("something went wrong"),
code: http.StatusInternalServerError,
},
}
for i, tt := range tests {
rw := httptest.NewRecorder()
writeInternalError(rw, tt.err)
if code := rw.Code; code != tt.code {
t.Errorf("#%d: got %d, want %d", i, code, tt.code)
}
if idx := rw.Header().Get("X-Etcd-Index"); idx != tt.idx {
t.Errorf("#%d: got %q, want %q", i, idx, tt.idx)
}
}
}
func TestWriteEvent(t *testing.T) {
// nil event should not panic
rw := httptest.NewRecorder()
writeEvent(rw, nil)
h := rw.Header()
if len(h) > 0 {
t.Fatalf("unexpected non-empty headers: %#v", h)
}
b := rw.Body.String()
if len(b) > 0 {
t.Fatalf("unexpected non-empty body: %q", b)
}
tests := []struct {
ev *store.Event
idx string
code int
err error
}{
// standard case, standard 200 response
{
etcdserver.Response{
Event: &store.Event{
Action: store.Get,
Node: &store.NodeExtern{},
PrevNode: &store.NodeExtern{},
},
Watcher: nil,
&store.Event{
Action: store.Get,
Node: &store.NodeExtern{},
PrevNode: &store.NodeExtern{},
},
"0",
http.StatusOK,
@@ -200,21 +287,10 @@ func TestEncodeResponse(t *testing.T) {
},
// check new nodes return StatusCreated
{
etcdserver.Response{
Event: &store.Event{
Action: store.Create,
Node: &store.NodeExtern{},
PrevNode: &store.NodeExtern{},
},
Watcher: nil,
},
"0",
http.StatusCreated,
nil,
},
{
etcdserver.Response{
Watcher: &eventingWatcher{store.Create},
&store.Event{
Action: store.Create,
Node: &store.NodeExtern{},
PrevNode: &store.NodeExtern{},
},
"0",
http.StatusCreated,
@@ -224,20 +300,13 @@ func TestEncodeResponse(t *testing.T) {
for i, tt := range tests {
rw := httptest.NewRecorder()
err := encodeResponse(context.Background(), rw, tt.resp)
if err != tt.err {
t.Errorf("case %d: unexpected err: got %v, want %v", i, err, tt.err)
continue
}
writeEvent(rw, tt.ev)
if gct := rw.Header().Get("Content-Type"); gct != "application/json" {
t.Errorf("case %d: bad Content-Type: got %q, want application/json", i, gct)
}
if gei := rw.Header().Get("X-Etcd-Index"); gei != tt.idx {
t.Errorf("case %d: bad X-Etcd-Index header: got %s, want %s", i, gei, tt.idx)
}
if rw.Code != tt.code {
t.Errorf("case %d: bad response code: got %d, want %v", i, rw.Code, tt.code)
}