Merge pull request #3579 from gyuho/etcdserver/etcdhttp/httptypes/errors.go-WriteTo-returns-error

httptypes: WriteTo to return error
This commit is contained in:
Xiang Li 2015-09-25 14:31:48 -07:00
commit c9be719d92
8 changed files with 109 additions and 83 deletions

View File

@ -84,14 +84,16 @@ func isCapabilityEnabled(c capability) bool {
func capabilityHandler(c capability, fn func(http.ResponseWriter, *http.Request)) http.HandlerFunc { func capabilityHandler(c capability, fn func(http.ResponseWriter, *http.Request)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if !isCapabilityEnabled(c) { if !isCapabilityEnabled(c) {
notCapable(w, c) notCapable(w, r, c)
return return
} }
fn(w, r) fn(w, r)
} }
} }
func notCapable(w http.ResponseWriter, c capability) { func notCapable(w http.ResponseWriter, r *http.Request, c capability) {
herr := httptypes.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Not capable of accessing %s feature during rolling upgrades.", c)) herr := httptypes.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Not capable of accessing %s feature during rolling upgrades.", c))
herr.WriteTo(w) if err := herr.WriteTo(w); err != nil {
plog.Debugf("error writing HTTPError (%v) to %s", err, r.RemoteAddr)
}
} }

View File

@ -186,7 +186,7 @@ func (h *membersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
if !hasWriteRootAccess(h.sec, r) { if !hasWriteRootAccess(h.sec, r) {
writeNoAuth(w) writeNoAuth(w, r)
return return
} }
w.Header().Set("X-Etcd-Cluster-ID", h.cluster.ID().String()) w.Header().Set("X-Etcd-Cluster-ID", h.cluster.ID().String())
@ -206,7 +206,7 @@ func (h *membersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
case "leader": case "leader":
id := h.server.Leader() id := h.server.Leader()
if id == 0 { if id == 0 {
writeError(w, httptypes.NewHTTPError(http.StatusServiceUnavailable, "During election")) writeError(w, r, httptypes.NewHTTPError(http.StatusServiceUnavailable, "During election"))
return return
} }
m := newMember(h.cluster.Member(id)) m := newMember(h.cluster.Member(id))
@ -215,7 +215,7 @@ func (h *membersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
plog.Warningf("failed to encode members response (%v)", err) plog.Warningf("failed to encode members response (%v)", err)
} }
default: default:
writeError(w, httptypes.NewHTTPError(http.StatusNotFound, "Not found")) writeError(w, r, httptypes.NewHTTPError(http.StatusNotFound, "Not found"))
} }
case "POST": case "POST":
req := httptypes.MemberCreateRequest{} req := httptypes.MemberCreateRequest{}
@ -227,11 +227,11 @@ func (h *membersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
err := h.server.AddMember(ctx, *m) err := h.server.AddMember(ctx, *m)
switch { switch {
case err == etcdserver.ErrIDExists || err == etcdserver.ErrPeerURLexists: case err == etcdserver.ErrIDExists || err == etcdserver.ErrPeerURLexists:
writeError(w, httptypes.NewHTTPError(http.StatusConflict, err.Error())) writeError(w, r, httptypes.NewHTTPError(http.StatusConflict, err.Error()))
return return
case err != nil: case err != nil:
plog.Errorf("error adding member %s (%v)", m.ID, err) plog.Errorf("error adding member %s (%v)", m.ID, err)
writeError(w, err) writeError(w, r, err)
return return
} }
res := newMember(m) res := newMember(m)
@ -248,12 +248,12 @@ func (h *membersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
err := h.server.RemoveMember(ctx, uint64(id)) err := h.server.RemoveMember(ctx, uint64(id))
switch { switch {
case err == etcdserver.ErrIDRemoved: case err == etcdserver.ErrIDRemoved:
writeError(w, httptypes.NewHTTPError(http.StatusGone, fmt.Sprintf("Member permanently removed: %s", id))) writeError(w, r, httptypes.NewHTTPError(http.StatusGone, fmt.Sprintf("Member permanently removed: %s", id)))
case err == etcdserver.ErrIDNotFound: case err == etcdserver.ErrIDNotFound:
writeError(w, httptypes.NewHTTPError(http.StatusNotFound, fmt.Sprintf("No such member: %s", id))) writeError(w, r, httptypes.NewHTTPError(http.StatusNotFound, fmt.Sprintf("No such member: %s", id)))
case err != nil: case err != nil:
plog.Errorf("error removing member %s (%v)", id, err) plog.Errorf("error removing member %s (%v)", id, err)
writeError(w, err) writeError(w, r, err)
default: default:
w.WriteHeader(http.StatusNoContent) w.WriteHeader(http.StatusNoContent)
} }
@ -273,12 +273,12 @@ func (h *membersHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
err := h.server.UpdateMember(ctx, m) err := h.server.UpdateMember(ctx, m)
switch { switch {
case err == etcdserver.ErrPeerURLexists: case err == etcdserver.ErrPeerURLexists:
writeError(w, httptypes.NewHTTPError(http.StatusConflict, err.Error())) writeError(w, r, httptypes.NewHTTPError(http.StatusConflict, err.Error()))
case err == etcdserver.ErrIDNotFound: case err == etcdserver.ErrIDNotFound:
writeError(w, httptypes.NewHTTPError(http.StatusNotFound, fmt.Sprintf("No such member: %s", id))) writeError(w, r, httptypes.NewHTTPError(http.StatusNotFound, fmt.Sprintf("No such member: %s", id)))
case err != nil: case err != nil:
plog.Errorf("error updating member %s (%v)", m.ID, err) plog.Errorf("error updating member %s (%v)", m.ID, err)
writeError(w, err) writeError(w, r, err)
default: default:
w.WriteHeader(http.StatusNoContent) w.WriteHeader(http.StatusNoContent)
} }
@ -311,7 +311,7 @@ func (h *statsHandler) serveLeader(w http.ResponseWriter, r *http.Request) {
} }
stats := h.stats.LeaderStats() stats := h.stats.LeaderStats()
if stats == nil { if stats == nil {
writeError(w, httptypes.NewHTTPError(http.StatusForbidden, "not current leader")) writeError(w, r, httptypes.NewHTTPError(http.StatusForbidden, "not current leader"))
return return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
@ -397,13 +397,13 @@ func logHandleFunc(w http.ResponseWriter, r *http.Request) {
d := json.NewDecoder(r.Body) d := json.NewDecoder(r.Body)
if err := d.Decode(&in); err != nil { if err := d.Decode(&in); err != nil {
writeError(w, httptypes.NewHTTPError(http.StatusBadRequest, "Invalid json body")) writeError(w, r, httptypes.NewHTTPError(http.StatusBadRequest, "Invalid json body"))
return return
} }
logl, err := capnslog.ParseLevel(strings.ToUpper(in.Level)) logl, err := capnslog.ParseLevel(strings.ToUpper(in.Level))
if err != nil { if err != nil {
writeError(w, httptypes.NewHTTPError(http.StatusBadRequest, "Invalid log level "+in.Level)) writeError(w, r, httptypes.NewHTTPError(http.StatusBadRequest, "Invalid log level "+in.Level))
return return
} }
@ -683,16 +683,16 @@ func trimErrorPrefix(err error, prefix string) error {
func unmarshalRequest(r *http.Request, req json.Unmarshaler, w http.ResponseWriter) bool { func unmarshalRequest(r *http.Request, req json.Unmarshaler, w http.ResponseWriter) bool {
ctype := r.Header.Get("Content-Type") ctype := r.Header.Get("Content-Type")
if ctype != "application/json" { if ctype != "application/json" {
writeError(w, httptypes.NewHTTPError(http.StatusUnsupportedMediaType, fmt.Sprintf("Bad Content-Type %s, accept application/json", ctype))) writeError(w, r, httptypes.NewHTTPError(http.StatusUnsupportedMediaType, fmt.Sprintf("Bad Content-Type %s, accept application/json", ctype)))
return false return false
} }
b, err := ioutil.ReadAll(r.Body) b, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
writeError(w, httptypes.NewHTTPError(http.StatusBadRequest, err.Error())) writeError(w, r, httptypes.NewHTTPError(http.StatusBadRequest, err.Error()))
return false return false
} }
if err := req.UnmarshalJSON(b); err != nil { if err := req.UnmarshalJSON(b); err != nil {
writeError(w, httptypes.NewHTTPError(http.StatusBadRequest, err.Error())) writeError(w, r, httptypes.NewHTTPError(http.StatusBadRequest, err.Error()))
return false return false
} }
return true return true
@ -706,7 +706,7 @@ func getID(p string, w http.ResponseWriter) (types.ID, bool) {
} }
id, err := types.IDFromString(idStr) id, err := types.IDFromString(idStr)
if err != nil { if err != nil {
writeError(w, httptypes.NewHTTPError(http.StatusNotFound, fmt.Sprintf("No such member: %s", idStr))) writeError(w, nil, httptypes.NewHTTPError(http.StatusNotFound, fmt.Sprintf("No such member: %s", idStr)))
return 0, false return 0, false
} }
return id, true return id, true

View File

@ -126,9 +126,11 @@ func hasGuestAccess(sec auth.Store, r *http.Request, key string) bool {
return false return false
} }
func writeNoAuth(w http.ResponseWriter) { func writeNoAuth(w http.ResponseWriter, r *http.Request) {
herr := httptypes.NewHTTPError(http.StatusUnauthorized, "Insufficient credentials") herr := httptypes.NewHTTPError(http.StatusUnauthorized, "Insufficient credentials")
herr.WriteTo(w) if err := herr.WriteTo(w); err != nil {
plog.Debugf("error writing HTTPError (%v) to %s", err, r.RemoteAddr)
}
} }
func handleAuth(mux *http.ServeMux, sh *authHandler) { func handleAuth(mux *http.ServeMux, sh *authHandler) {
@ -144,7 +146,7 @@ func (sh *authHandler) baseRoles(w http.ResponseWriter, r *http.Request) {
return return
} }
if !hasRootAccess(sh.sec, r) { if !hasRootAccess(sh.sec, r) {
writeNoAuth(w) writeNoAuth(w, r)
return return
} }
@ -153,7 +155,7 @@ func (sh *authHandler) baseRoles(w http.ResponseWriter, r *http.Request) {
roles, err := sh.sec.AllRoles() roles, err := sh.sec.AllRoles()
if err != nil { if err != nil {
writeError(w, err) writeError(w, r, err)
return return
} }
if roles == nil { if roles == nil {
@ -162,7 +164,7 @@ func (sh *authHandler) baseRoles(w http.ResponseWriter, r *http.Request) {
err = r.ParseForm() err = r.ParseForm()
if err != nil { if err != nil {
writeError(w, err) writeError(w, r, err)
return return
} }
@ -173,7 +175,7 @@ func (sh *authHandler) baseRoles(w http.ResponseWriter, r *http.Request) {
var role auth.Role var role auth.Role
role, err = sh.sec.GetRole(roleName) role, err = sh.sec.GetRole(roleName)
if err != nil { if err != nil {
writeError(w, err) writeError(w, r, err)
return return
} }
rolesCollections.Roles = append(rolesCollections.Roles, role) rolesCollections.Roles = append(rolesCollections.Roles, role)
@ -182,7 +184,7 @@ func (sh *authHandler) baseRoles(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
plog.Warningf("baseRoles error encoding on %s", r.URL) plog.Warningf("baseRoles error encoding on %s", r.URL)
writeError(w, err) writeError(w, r, err)
return return
} }
} }
@ -197,7 +199,7 @@ func (sh *authHandler) handleRoles(w http.ResponseWriter, r *http.Request) {
return return
} }
if len(pieces) != 3 { if len(pieces) != 3 {
writeError(w, httptypes.NewHTTPError(http.StatusBadRequest, "Invalid path")) writeError(w, r, httptypes.NewHTTPError(http.StatusBadRequest, "Invalid path"))
return return
} }
sh.forRole(w, r, pieces[2]) sh.forRole(w, r, pieces[2])
@ -208,7 +210,7 @@ func (sh *authHandler) forRole(w http.ResponseWriter, r *http.Request, role stri
return return
} }
if !hasRootAccess(sh.sec, r) { if !hasRootAccess(sh.sec, r) {
writeNoAuth(w) writeNoAuth(w, r)
return return
} }
w.Header().Set("X-Etcd-Cluster-ID", sh.cluster.ID().String()) w.Header().Set("X-Etcd-Cluster-ID", sh.cluster.ID().String())
@ -218,7 +220,7 @@ func (sh *authHandler) forRole(w http.ResponseWriter, r *http.Request, role stri
case "GET": case "GET":
data, err := sh.sec.GetRole(role) data, err := sh.sec.GetRole(role)
if err != nil { if err != nil {
writeError(w, err) writeError(w, r, err)
return return
} }
err = json.NewEncoder(w).Encode(data) err = json.NewEncoder(w).Encode(data)
@ -231,11 +233,11 @@ func (sh *authHandler) forRole(w http.ResponseWriter, r *http.Request, role stri
var in auth.Role var in auth.Role
err := json.NewDecoder(r.Body).Decode(&in) err := json.NewDecoder(r.Body).Decode(&in)
if err != nil { if err != nil {
writeError(w, httptypes.NewHTTPError(http.StatusBadRequest, "Invalid JSON in request body.")) writeError(w, r, httptypes.NewHTTPError(http.StatusBadRequest, "Invalid JSON in request body."))
return return
} }
if in.Role != role { if in.Role != role {
writeError(w, httptypes.NewHTTPError(http.StatusBadRequest, "Role JSON name does not match the name in the URL")) writeError(w, r, httptypes.NewHTTPError(http.StatusBadRequest, "Role JSON name does not match the name in the URL"))
return return
} }
@ -245,19 +247,19 @@ func (sh *authHandler) forRole(w http.ResponseWriter, r *http.Request, role stri
if in.Grant.IsEmpty() && in.Revoke.IsEmpty() { if in.Grant.IsEmpty() && in.Revoke.IsEmpty() {
err = sh.sec.CreateRole(in) err = sh.sec.CreateRole(in)
if err != nil { if err != nil {
writeError(w, err) writeError(w, r, err)
return return
} }
w.WriteHeader(http.StatusCreated) w.WriteHeader(http.StatusCreated)
out = in out = in
} else { } else {
if !in.Permissions.IsEmpty() { if !in.Permissions.IsEmpty() {
writeError(w, httptypes.NewHTTPError(http.StatusBadRequest, "Role JSON contains both permissions and grant/revoke")) writeError(w, r, httptypes.NewHTTPError(http.StatusBadRequest, "Role JSON contains both permissions and grant/revoke"))
return return
} }
out, err = sh.sec.UpdateRole(in) out, err = sh.sec.UpdateRole(in)
if err != nil { if err != nil {
writeError(w, err) writeError(w, r, err)
return return
} }
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
@ -272,7 +274,7 @@ func (sh *authHandler) forRole(w http.ResponseWriter, r *http.Request, role stri
case "DELETE": case "DELETE":
err := sh.sec.DeleteRole(role) err := sh.sec.DeleteRole(role)
if err != nil { if err != nil {
writeError(w, err) writeError(w, r, err)
return return
} }
} }
@ -288,7 +290,7 @@ func (sh *authHandler) baseUsers(w http.ResponseWriter, r *http.Request) {
return return
} }
if !hasRootAccess(sh.sec, r) { if !hasRootAccess(sh.sec, r) {
writeNoAuth(w) writeNoAuth(w, r)
return return
} }
w.Header().Set("X-Etcd-Cluster-ID", sh.cluster.ID().String()) w.Header().Set("X-Etcd-Cluster-ID", sh.cluster.ID().String())
@ -296,7 +298,7 @@ func (sh *authHandler) baseUsers(w http.ResponseWriter, r *http.Request) {
users, err := sh.sec.AllUsers() users, err := sh.sec.AllUsers()
if err != nil { if err != nil {
writeError(w, err) writeError(w, r, err)
return return
} }
if users == nil { if users == nil {
@ -305,7 +307,7 @@ func (sh *authHandler) baseUsers(w http.ResponseWriter, r *http.Request) {
err = r.ParseForm() err = r.ParseForm()
if err != nil { if err != nil {
writeError(w, err) writeError(w, r, err)
return return
} }
@ -316,7 +318,7 @@ func (sh *authHandler) baseUsers(w http.ResponseWriter, r *http.Request) {
var user auth.User var user auth.User
user, err = sh.sec.GetUser(userName) user, err = sh.sec.GetUser(userName)
if err != nil { if err != nil {
writeError(w, err) writeError(w, r, err)
return return
} }
@ -325,7 +327,7 @@ func (sh *authHandler) baseUsers(w http.ResponseWriter, r *http.Request) {
var role auth.Role var role auth.Role
role, err = sh.sec.GetRole(roleName) role, err = sh.sec.GetRole(roleName)
if err != nil { if err != nil {
writeError(w, err) writeError(w, r, err)
return return
} }
uwr.Roles = append(uwr.Roles, role) uwr.Roles = append(uwr.Roles, role)
@ -337,7 +339,7 @@ func (sh *authHandler) baseUsers(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
plog.Warningf("baseUsers error encoding on %s", r.URL) plog.Warningf("baseUsers error encoding on %s", r.URL)
writeError(w, err) writeError(w, r, err)
return return
} }
} }
@ -352,7 +354,7 @@ func (sh *authHandler) handleUsers(w http.ResponseWriter, r *http.Request) {
return return
} }
if len(pieces) != 3 { if len(pieces) != 3 {
writeError(w, httptypes.NewHTTPError(http.StatusBadRequest, "Invalid path")) writeError(w, r, httptypes.NewHTTPError(http.StatusBadRequest, "Invalid path"))
return return
} }
sh.forUser(w, r, pieces[2]) sh.forUser(w, r, pieces[2])
@ -363,7 +365,7 @@ func (sh *authHandler) forUser(w http.ResponseWriter, r *http.Request, user stri
return return
} }
if !hasRootAccess(sh.sec, r) { if !hasRootAccess(sh.sec, r) {
writeNoAuth(w) writeNoAuth(w, r)
return return
} }
w.Header().Set("X-Etcd-Cluster-ID", sh.cluster.ID().String()) w.Header().Set("X-Etcd-Cluster-ID", sh.cluster.ID().String())
@ -373,13 +375,13 @@ func (sh *authHandler) forUser(w http.ResponseWriter, r *http.Request, user stri
case "GET": case "GET":
u, err := sh.sec.GetUser(user) u, err := sh.sec.GetUser(user)
if err != nil { if err != nil {
writeError(w, err) writeError(w, r, err)
return return
} }
err = r.ParseForm() err = r.ParseForm()
if err != nil { if err != nil {
writeError(w, err) writeError(w, r, err)
return return
} }
@ -388,7 +390,7 @@ func (sh *authHandler) forUser(w http.ResponseWriter, r *http.Request, user stri
var role auth.Role var role auth.Role
role, err = sh.sec.GetRole(roleName) role, err = sh.sec.GetRole(roleName)
if err != nil { if err != nil {
writeError(w, err) writeError(w, r, err)
return return
} }
uwr.Roles = append(uwr.Roles, role) uwr.Roles = append(uwr.Roles, role)
@ -404,11 +406,11 @@ func (sh *authHandler) forUser(w http.ResponseWriter, r *http.Request, user stri
var u auth.User var u auth.User
err := json.NewDecoder(r.Body).Decode(&u) err := json.NewDecoder(r.Body).Decode(&u)
if err != nil { if err != nil {
writeError(w, httptypes.NewHTTPError(http.StatusBadRequest, "Invalid JSON in request body.")) writeError(w, r, httptypes.NewHTTPError(http.StatusBadRequest, "Invalid JSON in request body."))
return return
} }
if u.User != user { if u.User != user {
writeError(w, httptypes.NewHTTPError(http.StatusBadRequest, "User JSON name does not match the name in the URL")) writeError(w, r, httptypes.NewHTTPError(http.StatusBadRequest, "User JSON name does not match the name in the URL"))
return return
} }
@ -428,18 +430,18 @@ func (sh *authHandler) forUser(w http.ResponseWriter, r *http.Request, user stri
} }
if err != nil { if err != nil {
writeError(w, err) writeError(w, r, err)
return return
} }
} else { } else {
// update case // update case
if len(u.Roles) != 0 { if len(u.Roles) != 0 {
writeError(w, httptypes.NewHTTPError(http.StatusBadRequest, "User JSON contains both roles and grant/revoke")) writeError(w, r, httptypes.NewHTTPError(http.StatusBadRequest, "User JSON contains both roles and grant/revoke"))
return return
} }
out, err = sh.sec.UpdateUser(u) out, err = sh.sec.UpdateUser(u)
if err != nil { if err != nil {
writeError(w, err) writeError(w, r, err)
return return
} }
} }
@ -461,7 +463,7 @@ func (sh *authHandler) forUser(w http.ResponseWriter, r *http.Request, user stri
case "DELETE": case "DELETE":
err := sh.sec.DeleteUser(user) err := sh.sec.DeleteUser(user)
if err != nil { if err != nil {
writeError(w, err) writeError(w, r, err)
return return
} }
} }
@ -476,7 +478,7 @@ func (sh *authHandler) enableDisable(w http.ResponseWriter, r *http.Request) {
return return
} }
if !hasWriteRootAccess(sh.sec, r) { if !hasWriteRootAccess(sh.sec, r) {
writeNoAuth(w) writeNoAuth(w, r)
return return
} }
w.Header().Set("X-Etcd-Cluster-ID", sh.cluster.ID().String()) w.Header().Set("X-Etcd-Cluster-ID", sh.cluster.ID().String())
@ -492,13 +494,13 @@ func (sh *authHandler) enableDisable(w http.ResponseWriter, r *http.Request) {
case "PUT": case "PUT":
err := sh.sec.EnableAuth() err := sh.sec.EnableAuth()
if err != nil { if err != nil {
writeError(w, err) writeError(w, r, err)
return return
} }
case "DELETE": case "DELETE":
err := sh.sec.DisableAuth() err := sh.sec.DisableAuth()
if err != nil { if err != nil {
writeError(w, err) writeError(w, r, err)
return return
} }
} }

View File

@ -41,7 +41,7 @@ var (
// writeError logs and writes the given Error to the ResponseWriter // writeError logs and writes the given Error to the ResponseWriter
// If Error is an etcdErr, it is rendered to the ResponseWriter // If Error is an etcdErr, it is rendered to the ResponseWriter
// Otherwise, it is assumed to be an InternalServerError // Otherwise, it is assumed to be an InternalServerError
func writeError(w http.ResponseWriter, err error) { func writeError(w http.ResponseWriter, r *http.Request, err error) {
if err == nil { if err == nil {
return return
} }
@ -49,10 +49,14 @@ func writeError(w http.ResponseWriter, err error) {
case *etcdErr.Error: case *etcdErr.Error:
e.WriteTo(w) e.WriteTo(w)
case *httptypes.HTTPError: case *httptypes.HTTPError:
e.WriteTo(w) if et := e.WriteTo(w); et != nil {
plog.Debugf("error writing HTTPError (%v) to %s", et, r.RemoteAddr)
}
case auth.Error: case auth.Error:
herr := httptypes.NewHTTPError(e.HTTPStatus(), e.Error()) herr := httptypes.NewHTTPError(e.HTTPStatus(), e.Error())
herr.WriteTo(w) if et := herr.WriteTo(w); et != nil {
plog.Debugf("error writing HTTPError (%v) to %s", et, r.RemoteAddr)
}
default: default:
switch err { switch err {
case etcdserver.ErrTimeoutDueToLeaderFail, etcdserver.ErrTimeoutDueToConnectionLost: case etcdserver.ErrTimeoutDueToLeaderFail, etcdserver.ErrTimeoutDueToConnectionLost:
@ -61,7 +65,9 @@ func writeError(w http.ResponseWriter, err error) {
plog.Errorf("got unexpected response error (%v)", err) plog.Errorf("got unexpected response error (%v)", err)
} }
herr := httptypes.NewHTTPError(http.StatusInternalServerError, "Internal Server Error") herr := httptypes.NewHTTPError(http.StatusInternalServerError, "Internal Server Error")
herr.WriteTo(w) if et := herr.WriteTo(w); et != nil {
plog.Debugf("error writing HTTPError (%v) to %s", et, r.RemoteAddr)
}
} }
} }

View File

@ -81,7 +81,8 @@ func (fs *errServer) ClusterVersion() *semver.Version { return nil }
func TestWriteError(t *testing.T) { func TestWriteError(t *testing.T) {
// nil error should not panic // nil error should not panic
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
writeError(rec, nil) r := new(http.Request)
writeError(rec, r, nil)
h := rec.Header() h := rec.Header()
if len(h) > 0 { if len(h) > 0 {
t.Fatalf("unexpected non-empty headers: %#v", h) t.Fatalf("unexpected non-empty headers: %#v", h)
@ -114,7 +115,7 @@ func TestWriteError(t *testing.T) {
for i, tt := range tests { for i, tt := range tests {
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
writeError(rw, tt.err) writeError(rw, r, tt.err)
if code := rw.Code; code != tt.wcode { if code := rw.Code; code != tt.wcode {
t.Errorf("#%d: code=%d, want %d", i, code, tt.wcode) t.Errorf("#%d: code=%d, want %d", i, code, tt.wcode)
} }

View File

@ -35,15 +35,17 @@ func (e HTTPError) Error() string {
return e.Message return e.Message
} }
// TODO(xiangli): handle http write errors func (e HTTPError) WriteTo(w http.ResponseWriter) error {
func (e HTTPError) WriteTo(w http.ResponseWriter) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(e.Code) w.WriteHeader(e.Code)
b, err := json.Marshal(e) b, err := json.Marshal(e)
if err != nil { if err != nil {
plog.Panicf("marshal HTTPError should never fail (%v)", err) plog.Panicf("marshal HTTPError should never fail (%v)", err)
} }
w.Write(b) if _, err := w.Write(b); err != nil {
return err
}
return nil
} }
func NewHTTPError(code int, m string) *HTTPError { func NewHTTPError(code int, m string) *HTTPError {

View File

@ -24,7 +24,9 @@ import (
func TestHTTPErrorWriteTo(t *testing.T) { func TestHTTPErrorWriteTo(t *testing.T) {
err := NewHTTPError(http.StatusBadRequest, "what a bad request you made!") err := NewHTTPError(http.StatusBadRequest, "what a bad request you made!")
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
err.WriteTo(rr) if e := err.WriteTo(rr); e != nil {
t.Fatalf("HTTPError.WriteTo error (%v)", e)
}
wcode := http.StatusBadRequest wcode := http.StatusBadRequest
wheader := http.Header(map[string][]string{ wheader := http.Header(map[string][]string{

View File

@ -28,23 +28,28 @@ import (
"time" "time"
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/coreos/pkg/capnslog"
"github.com/coreos/etcd/etcdserver/etcdhttp/httptypes" "github.com/coreos/etcd/etcdserver/etcdhttp/httptypes"
"github.com/coreos/etcd/pkg/httputil" "github.com/coreos/etcd/pkg/httputil"
) )
// Hop-by-hop headers. These are removed when sent to the backend. var (
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html plog = capnslog.NewPackageLogger("github.com/coreos/etcd", "proxy")
// This list of headers borrowed from stdlib httputil.ReverseProxy
var singleHopHeaders = []string{ // Hop-by-hop headers. These are removed when sent to the backend.
"Connection", // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
"Keep-Alive", // This list of headers borrowed from stdlib httputil.ReverseProxy
"Proxy-Authenticate", singleHopHeaders = []string{
"Proxy-Authorization", "Connection",
"Te", // canonicalized version of "TE" "Keep-Alive",
"Trailers", "Proxy-Authenticate",
"Transfer-Encoding", "Proxy-Authorization",
"Upgrade", "Te", // canonicalized version of "TE"
} "Trailers",
"Transfer-Encoding",
"Upgrade",
}
)
func removeSingleHopHeaders(hdrs *http.Header) { func removeSingleHopHeaders(hdrs *http.Header) {
for _, h := range singleHopHeaders { for _, h := range singleHopHeaders {
@ -72,7 +77,9 @@ func (p *reverseProxy) ServeHTTP(rw http.ResponseWriter, clientreq *http.Request
if err != nil { if err != nil {
msg := fmt.Sprintf("proxy: failed to read request body: %v", err) msg := fmt.Sprintf("proxy: failed to read request body: %v", err)
e := httptypes.NewHTTPError(http.StatusInternalServerError, msg) e := httptypes.NewHTTPError(http.StatusInternalServerError, msg)
e.WriteTo(rw) if we := e.WriteTo(rw); we != nil {
plog.Debugf("error writing HTTPError (%v) to %s", we, clientreq.RemoteAddr)
}
return return
} }
} }
@ -93,7 +100,9 @@ func (p *reverseProxy) ServeHTTP(rw http.ResponseWriter, clientreq *http.Request
// TODO: limit the rate of the error logging. // TODO: limit the rate of the error logging.
log.Printf(msg) log.Printf(msg)
e := httptypes.NewHTTPError(http.StatusServiceUnavailable, msg) e := httptypes.NewHTTPError(http.StatusServiceUnavailable, msg)
e.WriteTo(rw) if we := e.WriteTo(rw); we != nil {
plog.Debugf("error writing HTTPError (%v) to %s", we, clientreq.RemoteAddr)
}
return return
} }
@ -145,7 +154,9 @@ func (p *reverseProxy) ServeHTTP(rw http.ResponseWriter, clientreq *http.Request
reportRequestDropped(clientreq, failedGettingResponse) reportRequestDropped(clientreq, failedGettingResponse)
log.Printf(msg) log.Printf(msg)
e := httptypes.NewHTTPError(http.StatusBadGateway, msg) e := httptypes.NewHTTPError(http.StatusBadGateway, msg)
e.WriteTo(rw) if we := e.WriteTo(rw); we != nil {
plog.Debugf("error writing HTTPError (%v) to %s", we, clientreq.RemoteAddr)
}
return return
} }