mirror of
https://github.com/etcd-io/etcd.git
synced 2024-09-27 06:25:44 +00:00
embed: support "CORS" handler in v3 HTTP requests
Signed-off-by: Gyuho Lee <gyuhox@gmail.com>
This commit is contained in:
parent
c7cecca575
commit
9ea8be0c2b
@ -23,6 +23,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@ -33,7 +34,6 @@ import (
|
|||||||
"github.com/coreos/etcd/etcdserver/api/v2v3"
|
"github.com/coreos/etcd/etcdserver/api/v2v3"
|
||||||
"github.com/coreos/etcd/etcdserver/api/v3client"
|
"github.com/coreos/etcd/etcdserver/api/v3client"
|
||||||
"github.com/coreos/etcd/etcdserver/api/v3rpc"
|
"github.com/coreos/etcd/etcdserver/api/v3rpc"
|
||||||
"github.com/coreos/etcd/pkg/cors"
|
|
||||||
"github.com/coreos/etcd/pkg/debugutil"
|
"github.com/coreos/etcd/pkg/debugutil"
|
||||||
runtimeutil "github.com/coreos/etcd/pkg/runtime"
|
runtimeutil "github.com/coreos/etcd/pkg/runtime"
|
||||||
"github.com/coreos/etcd/pkg/transport"
|
"github.com/coreos/etcd/pkg/transport"
|
||||||
@ -168,6 +168,7 @@ func StartEtcd(inCfg *Config) (e *Etcd, err error) {
|
|||||||
StrictReconfigCheck: cfg.StrictReconfigCheck,
|
StrictReconfigCheck: cfg.StrictReconfigCheck,
|
||||||
ClientCertAuthEnabled: cfg.ClientTLSInfo.ClientCertAuth,
|
ClientCertAuthEnabled: cfg.ClientTLSInfo.ClientCertAuth,
|
||||||
AuthToken: cfg.AuthToken,
|
AuthToken: cfg.AuthToken,
|
||||||
|
CORS: cfg.CORS,
|
||||||
HostWhitelist: cfg.HostWhitelist,
|
HostWhitelist: cfg.HostWhitelist,
|
||||||
InitialCorruptCheck: cfg.ExperimentalInitialCorruptCheck,
|
InitialCorruptCheck: cfg.ExperimentalInitialCorruptCheck,
|
||||||
CorruptCheckTime: cfg.ExperimentalCorruptCheckTime,
|
CorruptCheckTime: cfg.ExperimentalCorruptCheckTime,
|
||||||
@ -473,8 +474,13 @@ func (e *Etcd) serveClients() (err error) {
|
|||||||
plog.Infof("ClientTLS: %s", e.cfg.ClientTLSInfo)
|
plog.Infof("ClientTLS: %s", e.cfg.ClientTLSInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
if e.cfg.CorsInfo.String() != "" {
|
if len(e.cfg.CORS) > 0 {
|
||||||
plog.Infof("cors = %s", e.cfg.CorsInfo)
|
ss := make([]string, 0, len(e.cfg.CORS))
|
||||||
|
for v := range e.cfg.CORS {
|
||||||
|
ss = append(ss, v)
|
||||||
|
}
|
||||||
|
sort.Strings(ss)
|
||||||
|
plog.Infof("cors = %q", ss)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start a client server goroutine for each listen address
|
// Start a client server goroutine for each listen address
|
||||||
@ -491,7 +497,6 @@ func (e *Etcd) serveClients() (err error) {
|
|||||||
etcdhttp.HandleBasic(mux, e.Server)
|
etcdhttp.HandleBasic(mux, e.Server)
|
||||||
h = mux
|
h = mux
|
||||||
}
|
}
|
||||||
h = http.Handler(&cors.CORSHandler{Handler: h, Info: e.cfg.CorsInfo})
|
|
||||||
|
|
||||||
gopts := []grpc.ServerOption{}
|
gopts := []grpc.ServerOption{}
|
||||||
if e.cfg.GRPCKeepAliveMinTime > time.Duration(0) {
|
if e.cfg.GRPCKeepAliveMinTime > time.Duration(0) {
|
||||||
|
@ -116,7 +116,7 @@ func (sctx *serveCtx) serve(
|
|||||||
httpmux := sctx.createMux(gwmux, handler)
|
httpmux := sctx.createMux(gwmux, handler)
|
||||||
|
|
||||||
srvhttp := &http.Server{
|
srvhttp := &http.Server{
|
||||||
Handler: wrapMux(s, httpmux),
|
Handler: createAccessController(s, httpmux),
|
||||||
ErrorLog: logger, // do not log user error
|
ErrorLog: logger, // do not log user error
|
||||||
}
|
}
|
||||||
httpl := m.Match(cmux.HTTP1())
|
httpl := m.Match(cmux.HTTP1())
|
||||||
@ -159,7 +159,7 @@ func (sctx *serveCtx) serve(
|
|||||||
httpmux := sctx.createMux(gwmux, handler)
|
httpmux := sctx.createMux(gwmux, handler)
|
||||||
|
|
||||||
srv := &http.Server{
|
srv := &http.Server{
|
||||||
Handler: wrapMux(s, httpmux),
|
Handler: createAccessController(s, httpmux),
|
||||||
TLSConfig: tlscfg,
|
TLSConfig: tlscfg,
|
||||||
ErrorLog: logger, // do not log user error
|
ErrorLog: logger, // do not log user error
|
||||||
}
|
}
|
||||||
@ -250,20 +250,20 @@ func (sctx *serveCtx) createMux(gwmux *gw.ServeMux, handler http.Handler) *http.
|
|||||||
return httpmux
|
return httpmux
|
||||||
}
|
}
|
||||||
|
|
||||||
// wrapMux wraps HTTP multiplexer:
|
// createAccessController wraps HTTP multiplexer:
|
||||||
// - mutate gRPC gateway request paths
|
// - mutate gRPC gateway request paths
|
||||||
// - check hostname whitelist
|
// - check hostname whitelist
|
||||||
// client HTTP requests goes here first
|
// client HTTP requests goes here first
|
||||||
func wrapMux(s *etcdserver.EtcdServer, mux *http.ServeMux) http.Handler {
|
func createAccessController(s *etcdserver.EtcdServer, mux *http.ServeMux) http.Handler {
|
||||||
return &httpWrapper{s: s, mux: mux}
|
return &accessController{s: s, mux: mux}
|
||||||
}
|
}
|
||||||
|
|
||||||
type httpWrapper struct {
|
type accessController struct {
|
||||||
s *etcdserver.EtcdServer
|
s *etcdserver.EtcdServer
|
||||||
mux *http.ServeMux
|
mux *http.ServeMux
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *httpWrapper) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
func (ac *accessController) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
// redirect for backward compatibilities
|
// redirect for backward compatibilities
|
||||||
if req != nil && req.URL != nil && strings.HasPrefix(req.URL.Path, "/v3beta/") {
|
if req != nil && req.URL != nil && strings.HasPrefix(req.URL.Path, "/v3beta/") {
|
||||||
req.URL.Path = strings.Replace(req.URL.Path, "/v3beta/", "/v3/", 1)
|
req.URL.Path = strings.Replace(req.URL.Path, "/v3beta/", "/v3/", 1)
|
||||||
@ -271,7 +271,7 @@ func (m *httpWrapper) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||||||
|
|
||||||
if req.TLS == nil { // check origin if client connection is not secure
|
if req.TLS == nil { // check origin if client connection is not secure
|
||||||
host := httputil.GetHostname(req)
|
host := httputil.GetHostname(req)
|
||||||
if !m.s.IsHostWhitelisted(host) {
|
if !ac.s.AccessController.IsHostWhitelisted(host) {
|
||||||
plog.Warningf("rejecting HTTP request from %q to prevent DNS rebinding attacks", host)
|
plog.Warningf("rejecting HTTP request from %q to prevent DNS rebinding attacks", host)
|
||||||
// TODO: use Go's "http.StatusMisdirectedRequest" (421)
|
// TODO: use Go's "http.StatusMisdirectedRequest" (421)
|
||||||
// https://github.com/golang/go/commit/4b8a7eafef039af1834ef9bfa879257c4a72b7b5
|
// https://github.com/golang/go/commit/4b8a7eafef039af1834ef9bfa879257c4a72b7b5
|
||||||
@ -280,7 +280,26 @@ func (m *httpWrapper) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
m.mux.ServeHTTP(rw, req)
|
// Write CORS header.
|
||||||
|
if ac.s.AccessController.OriginAllowed("*") {
|
||||||
|
addCORSHeader(rw, "*")
|
||||||
|
} else if origin := req.Header.Get("Origin"); ac.s.OriginAllowed(origin) {
|
||||||
|
addCORSHeader(rw, origin)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Method == "OPTIONS" {
|
||||||
|
rw.WriteHeader(http.StatusOK)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ac.mux.ServeHTTP(rw, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// addCORSHeader adds the correct cors headers given an origin
|
||||||
|
func addCORSHeader(w http.ResponseWriter, origin string) {
|
||||||
|
w.Header().Add("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
|
||||||
|
w.Header().Add("Access-Control-Allow-Origin", origin)
|
||||||
|
w.Header().Add("Access-Control-Allow-Headers", "accept, content-type, authorization")
|
||||||
}
|
}
|
||||||
|
|
||||||
// https://github.com/transmission/transmission/pull/468
|
// https://github.com/transmission/transmission/pull/468
|
||||||
@ -297,6 +316,35 @@ This requirement has been added to help prevent "DNS Rebinding" attacks (CVE-201
|
|||||||
`, host)
|
`, host)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WrapCORS wraps existing handler with CORS.
|
||||||
|
// TODO: deprecate this after v2 proxy deprecate
|
||||||
|
func WrapCORS(cors map[string]struct{}, h http.Handler) http.Handler {
|
||||||
|
return &corsHandler{
|
||||||
|
ac: &etcdserver.AccessController{CORS: cors},
|
||||||
|
h: h,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type corsHandler struct {
|
||||||
|
ac *etcdserver.AccessController
|
||||||
|
h http.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ch *corsHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
if ch.ac.OriginAllowed("*") {
|
||||||
|
addCORSHeader(rw, "*")
|
||||||
|
} else if origin := req.Header.Get("Origin"); ch.ac.OriginAllowed(origin) {
|
||||||
|
addCORSHeader(rw, origin)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Method == "OPTIONS" {
|
||||||
|
rw.WriteHeader(http.StatusOK)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ch.h.ServeHTTP(rw, req)
|
||||||
|
}
|
||||||
|
|
||||||
func (sctx *serveCtx) registerUserHandler(s string, h http.Handler) {
|
func (sctx *serveCtx) registerUserHandler(s string, h http.Handler) {
|
||||||
if sctx.userHandlers[s] != nil {
|
if sctx.userHandlers[s] != nil {
|
||||||
plog.Warningf("path %s already registered by user handler", s)
|
plog.Warningf("path %s already registered by user handler", s)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user