From a93d60be90bfe41eabc0a18f56a3b4144dda5f8e Mon Sep 17 00:00:00 2001 From: Brian Waldon Date: Fri, 17 Jan 2014 20:04:10 -0800 Subject: [PATCH] refactor(cors): Break apart CORS data and middleware --- etcd.go | 11 ++++++++--- server/cors.go | 34 +++++++++++++++++----------------- server/peer_server.go | 1 + server/peer_server_handlers.go | 4 ++-- server/server.go | 33 ++++++++++++--------------------- tests/server_utils.go | 3 +++ 6 files changed, 43 insertions(+), 43 deletions(-) diff --git a/etcd.go b/etcd.go index c60577085..26ec45ebc 100644 --- a/etcd.go +++ b/etcd.go @@ -98,6 +98,12 @@ func main() { } } + // Retrieve CORS configuration + corsInfo, err := server.NewCORSInfo(config.CorsOrigins) + if err != nil { + log.Fatal("CORS:", err) + } + // Create etcd key-value store and registry. store := store.New() registry := server.NewRegistry(store) @@ -113,6 +119,7 @@ func main() { ElectionTimeout: time.Duration(config.Peer.ElectionTimeout) * time.Millisecond, MaxClusterSize: config.MaxClusterSize, RetryTimes: config.MaxRetryAttempts, + CORS: corsInfo, } ps := server.NewPeerServer(psConfig, &peerTLSConfig, &info.RaftTLS, registry, store, &mb) @@ -121,11 +128,9 @@ func main() { Name: info.Name, URL: info.EtcdURL, BindAddr: info.EtcdListenHost, + CORS: corsInfo, } s := server.New(sConfig, &tlsConfig, &info.EtcdTLS, ps, registry, store, &mb) - if err := s.AllowOrigins(config.CorsOrigins); err != nil { - panic(err) - } if config.Trace() { s.EnableTracing() diff --git a/server/cors.go b/server/cors.go index fec3c4abc..a3728b689 100644 --- a/server/cors.go +++ b/server/cors.go @@ -20,50 +20,50 @@ import ( "fmt" "net/http" "net/url" - - "github.com/gorilla/mux" ) -type corsHandler struct { - router *mux.Router - corsOrigins map[string]bool +type corsInfo struct { + origins map[string]bool } -// AllowOrigins sets a comma-delimited list of origins that are allowed. -func (s *corsHandler) AllowOrigins(origins []string) error { +func NewCORSInfo(origins []string) (*corsInfo, error) { // Construct a lookup of all origins. m := make(map[string]bool) for _, v := range origins { if v != "*" { if _, err := url.Parse(v); err != nil { - return fmt.Errorf("Invalid CORS origin: %s", err) + return nil, fmt.Errorf("Invalid CORS origin: %s", err) } } m[v] = true } - s.corsOrigins = m - return nil + return &corsInfo{m}, nil } // OriginAllowed determines whether the server will allow a given CORS origin. -func (c *corsHandler) OriginAllowed(origin string) bool { - return c.corsOrigins["*"] || c.corsOrigins[origin] +func (c *corsInfo) OriginAllowed(origin string) bool { + return c.origins["*"] || c.origins[origin] +} + +type corsHTTPMiddleware struct { + next http.Handler + info *corsInfo } // addHeader adds the correct cors headers given an origin -func (h *corsHandler) addHeader(w http.ResponseWriter, origin string) { +func (h *corsHTTPMiddleware) addHeader(w http.ResponseWriter, origin string) { w.Header().Add("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE") w.Header().Add("Access-Control-Allow-Origin", origin) } // ServeHTTP adds the correct CORS headers based on the origin and returns immediatly // with a 200 OK if the method is OPTIONS. -func (h *corsHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { +func (h *corsHTTPMiddleware) ServeHTTP(w http.ResponseWriter, req *http.Request) { // Write CORS header. - if h.OriginAllowed("*") { + if h.info.OriginAllowed("*") { h.addHeader(w, "*") - } else if origin := req.Header.Get("Origin"); h.OriginAllowed(origin) { + } else if origin := req.Header.Get("Origin"); h.info.OriginAllowed(origin) { h.addHeader(w, origin) } @@ -72,5 +72,5 @@ func (h *corsHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { return } - h.router.ServeHTTP(w, req) + h.next.ServeHTTP(w, req) } diff --git a/server/peer_server.go b/server/peer_server.go index 56149a5f4..bb5f8eb3a 100644 --- a/server/peer_server.go +++ b/server/peer_server.go @@ -35,6 +35,7 @@ type PeerServerConfig struct { ElectionTimeout time.Duration MaxClusterSize int RetryTimes int + CORS *corsInfo } type PeerServer struct { diff --git a/server/peer_server_handlers.go b/server/peer_server_handlers.go index a4ef84710..a2c498101 100644 --- a/server/peer_server_handlers.go +++ b/server/peer_server_handlers.go @@ -150,9 +150,9 @@ func (ps *PeerServer) JoinHttpHandler(w http.ResponseWriter, req *http.Request) command := &JoinCommand{} // Write CORS header. - if ps.server.OriginAllowed("*") { + if ps.Config.CORS.OriginAllowed("*") { w.Header().Add("Access-Control-Allow-Origin", "*") - } else if ps.server.OriginAllowed(req.Header.Get("Origin")) { + } else if ps.Config.CORS.OriginAllowed(req.Header.Get("Origin")) { w.Header().Add("Access-Control-Allow-Origin", req.Header.Get("Origin")) } diff --git a/server/server.go b/server/server.go index a660bc593..7a64d17fa 100644 --- a/server/server.go +++ b/server/server.go @@ -26,27 +26,28 @@ type ServerConfig struct { Name string URL string BindAddr string + CORS *corsInfo } // This is the default implementation of the Server interface. type Server struct { http.Server - Config ServerConfig - peerServer *PeerServer - registry *Registry - listener net.Listener - store store.Store - tlsConf *TLSConfig - tlsInfo *TLSInfo - router *mux.Router - corsHandler *corsHandler + Config ServerConfig + peerServer *PeerServer + registry *Registry + listener net.Listener + store store.Store + tlsConf *TLSConfig + tlsInfo *TLSInfo + router *mux.Router + corsMiddleware *corsHTTPMiddleware metrics *metrics.Bucket } // Creates a new Server. func New(sConfig ServerConfig, tlsConf *TLSConfig, tlsInfo *TLSInfo, peerServer *PeerServer, registry *Registry, store store.Store, mb *metrics.Bucket) *Server { r := mux.NewRouter() - cors := &corsHandler{router: r} + cors := &corsHTTPMiddleware{r, sConfig.CORS} s := &Server{ Config: sConfig, @@ -61,7 +62,7 @@ func New(sConfig ServerConfig, tlsConf *TLSConfig, tlsInfo *TLSInfo, peerServer tlsInfo: tlsInfo, peerServer: peerServer, router: r, - corsHandler: cors, + corsMiddleware: cors, metrics: mb, } @@ -326,16 +327,6 @@ func (s *Server) Dispatch(c raft.Command, w http.ResponseWriter, req *http.Reque } } -// OriginAllowed determines whether the server will allow a given CORS origin. -func (s *Server) OriginAllowed(origin string) bool { - return s.corsHandler.OriginAllowed(origin) -} - -// AllowOrigins sets a comma-delimited list of origins that are allowed. -func (s *Server) AllowOrigins(origins []string) error { - return s.corsHandler.AllowOrigins(origins) -} - // Handler to return the current version of etcd. func (s *Server) GetVersionHandler(w http.ResponseWriter, req *http.Request) error { w.WriteHeader(http.StatusOK) diff --git a/tests/server_utils.go b/tests/server_utils.go index cc785d2e7..93f3795b9 100644 --- a/tests/server_utils.go +++ b/tests/server_utils.go @@ -25,6 +25,7 @@ func RunServer(f func(*server.Server)) { store := store.New() registry := server.NewRegistry(store) + corsInfo, _ := server.NewCORSInfo([]string{}) psConfig := server.PeerServerConfig{ Name: testName, @@ -35,6 +36,7 @@ func RunServer(f func(*server.Server)) { HeartbeatTimeout: testHeartbeatTimeout, ElectionTimeout: testElectionTimeout, MaxClusterSize: 9, + CORS: corsInfo, } ps := server.NewPeerServer(psConfig, &server.TLSConfig{Scheme: "http"}, &server.TLSInfo{}, registry, store, nil) @@ -42,6 +44,7 @@ func RunServer(f func(*server.Server)) { Name: testName, URL: "http://"+testClientURL, BindAddr: testClientURL, + CORS: corsInfo, } s := server.New(sConfig, &server.TLSConfig{Scheme: "http"}, &server.TLSInfo{}, ps, registry, store, nil)