From 0abd860f7e655903a3dee37964b3126c94fee13a Mon Sep 17 00:00:00 2001 From: Brian Waldon Date: Mon, 20 Jan 2014 17:22:09 -0800 Subject: [PATCH] refactor(server): drop Serve code; rename cors object * server/cors.go renamed to http/cors.go * all CORS code removed from Server and PeerServer * Server and PeerServer fulfill http.Handler, now passed to http.Serve * non-HTTP code in PeerServer.Serve moved to PeerServer.Start --- etcd.go | 12 +++++++---- {server => http}/cors.go | 18 ++++++++-------- server/peer_server.go | 38 +++++++++++++++++----------------- server/peer_server_handlers.go | 7 ------- server/server.go | 18 +++++++--------- tests/server_utils.go | 14 ++++++------- 6 files changed, 50 insertions(+), 57 deletions(-) rename {server => http}/cors.go (80%) diff --git a/etcd.go b/etcd.go index 8f98fa34e..385fc47c2 100644 --- a/etcd.go +++ b/etcd.go @@ -26,6 +26,7 @@ import ( "github.com/coreos/raft" + ehttp "github.com/coreos/etcd/http" "github.com/coreos/etcd/log" "github.com/coreos/etcd/metrics" "github.com/coreos/etcd/server" @@ -102,7 +103,7 @@ func main() { } // Retrieve CORS configuration - corsInfo, err := server.NewCORSInfo(config.CorsOrigins) + corsInfo, err := ehttp.NewCORSInfo(config.CorsOrigins) if err != nil { log.Fatal("CORS:", err) } @@ -130,7 +131,6 @@ func main() { SnapshotCount: config.SnapshotCount, MaxClusterSize: config.MaxClusterSize, RetryTimes: config.MaxRetryAttempts, - CORS: corsInfo, } ps := server.NewPeerServer(psConfig, registry, store, &mb, followersStats, serverStats) @@ -177,12 +177,16 @@ func main() { ps.SetServer(s) + ps.Start(config.Snapshot, config.Peers) + // Run peer server in separate thread while the client server blocks. go func() { - log.Fatal(ps.Serve(psListener, config.Snapshot, config.Peers)) + log.Infof("raft server [name %s, listen on %s, advertised url %s]", ps.Config.Name, psListener.Addr(), ps.Config.URL) + sHTTP := &ehttp.CORSHandler{ps, corsInfo} + log.Fatal(http.Serve(psListener, sHTTP)) }() log.Infof("etcd server [name %s, listen on %s, advertised url %s]", s.Config.Name, sListener.Addr(), s.Config.URL) - sHTTP := &server.CORSHTTPMiddleware{s, corsInfo} + sHTTP := &ehttp.CORSHandler{s, corsInfo} log.Fatal(http.Serve(sListener, sHTTP)) } diff --git a/server/cors.go b/http/cors.go similarity index 80% rename from server/cors.go rename to http/cors.go index d585fbfc9..16e616eed 100644 --- a/server/cors.go +++ b/http/cors.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package server +package http import ( "fmt" @@ -22,9 +22,9 @@ import ( "net/url" ) -type corsInfo map[string]bool +type CORSInfo map[string]bool -func NewCORSInfo(origins []string) (*corsInfo, error) { +func NewCORSInfo(origins []string) (*CORSInfo, error) { // Construct a lookup of all origins. m := make(map[string]bool) for _, v := range origins { @@ -36,29 +36,29 @@ func NewCORSInfo(origins []string) (*corsInfo, error) { m[v] = true } - info := corsInfo(m) + info := CORSInfo(m) return &info, nil } // OriginAllowed determines whether the server will allow a given CORS origin. -func (c corsInfo) OriginAllowed(origin string) bool { +func (c CORSInfo) OriginAllowed(origin string) bool { return c["*"] || c[origin] } -type CORSHTTPMiddleware struct { +type CORSHandler struct { Handler http.Handler - Info *corsInfo + Info *CORSInfo } // addHeader adds the correct cors headers given an origin -func (h *CORSHTTPMiddleware) addHeader(w http.ResponseWriter, origin string) { +func (h *CORSHandler) addHeader(w http.ResponseWriter, origin string) { w.Header().Add("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE") w.Header().Add("Access-Control-Allow-Origin", origin) } // ServeHTTP adds the correct CORS headers based on the origin and returns immediatly // with a 200 OK if the method is OPTIONS. -func (h *CORSHTTPMiddleware) ServeHTTP(w http.ResponseWriter, req *http.Request) { +func (h *CORSHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { // Write CORS header. if h.Info.OriginAllowed("*") { h.addHeader(w, "*") diff --git a/server/peer_server.go b/server/peer_server.go index 5576cc6d7..705cd9c13 100644 --- a/server/peer_server.go +++ b/server/peer_server.go @@ -6,7 +6,6 @@ import ( "encoding/json" "fmt" "io/ioutil" - "net" "net/http" "net/url" "strconv" @@ -35,11 +34,11 @@ type PeerServerConfig struct { ElectionTimeout time.Duration MaxClusterSize int RetryTimes int - CORS *corsInfo } type PeerServer struct { Config PeerServerConfig + handler http.Handler raftServer raft.Server server *Server joinIndex uint64 @@ -49,8 +48,6 @@ type PeerServer struct { store store.Store snapConf *snapshotConf - listener net.Listener - closeChan chan bool timeoutThresholdChan chan interface{} @@ -82,6 +79,9 @@ func NewPeerServer(psConfig PeerServerConfig, registry *Registry, store store.St metrics: mb, } + + s.handler = s.buildHTTPHandler() + return s } @@ -107,7 +107,7 @@ func (s *PeerServer) SetRaftServer(raftServer raft.Server) { } // Start the raft server -func (s *PeerServer) Serve(listener net.Listener, snapshot bool, cluster []string) error { +func (s *PeerServer) Start(snapshot bool, cluster []string) error { // LoadSnapshot if snapshot { err := s.raftServer.LoadSnapshot() @@ -157,8 +157,18 @@ func (s *PeerServer) Serve(listener net.Listener, snapshot bool, cluster []strin go s.monitorSnapshot() } + return nil +} + +func (s *PeerServer) Stop() { + if s.closeChan != nil { + close(s.closeChan) + s.closeChan = nil + } +} + +func (s *PeerServer) buildHTTPHandler() http.Handler { router := mux.NewRouter() - httpServer := &http.Server{Handler: router} // internal commands router.HandleFunc("/name", s.NameHttpHandler) @@ -174,21 +184,11 @@ func (s *PeerServer) Serve(listener net.Listener, snapshot bool, cluster []strin router.HandleFunc("/snapshotRecovery", s.SnapshotRecoveryHttpHandler) router.HandleFunc("/etcdURL", s.EtcdURLHttpHandler) - s.listener = listener - log.Infof("raft server [name %s, listen on %s, advertised url %s]", s.Config.Name, listener.Addr(), s.Config.URL) - httpServer.Serve(listener) - return nil + return router } -func (s *PeerServer) Close() { - if s.closeChan != nil { - close(s.closeChan) - s.closeChan = nil - } - if s.listener != nil { - s.listener.Close() - s.listener = nil - } +func (s *PeerServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + s.handler.ServeHTTP(w, r) } // Retrieves the underlying Raft server. diff --git a/server/peer_server_handlers.go b/server/peer_server_handlers.go index a2c498101..fdcbf3df5 100644 --- a/server/peer_server_handlers.go +++ b/server/peer_server_handlers.go @@ -149,13 +149,6 @@ func (ps *PeerServer) EtcdURLHttpHandler(w http.ResponseWriter, req *http.Reques func (ps *PeerServer) JoinHttpHandler(w http.ResponseWriter, req *http.Request) { command := &JoinCommand{} - // Write CORS header. - if ps.Config.CORS.OriginAllowed("*") { - w.Header().Add("Access-Control-Allow-Origin", "*") - } else if ps.Config.CORS.OriginAllowed(req.Header.Get("Origin")) { - w.Header().Add("Access-Control-Allow-Origin", req.Header.Get("Origin")) - } - err := decodeJsonRequest(req, command) if err != nil { w.WriteHeader(http.StatusInternalServerError) diff --git a/server/server.go b/server/server.go index dd6d08ff0..f4b823355 100644 --- a/server/server.go +++ b/server/server.go @@ -3,7 +3,6 @@ package server import ( "encoding/json" "fmt" - "net" "net/http" "net/http/pprof" "strings" @@ -30,13 +29,12 @@ type ServerConfig struct { // This is the default implementation of the Server interface. type Server struct { Config ServerConfig + handler http.Handler peerServer *PeerServer registry *Registry store store.Store metrics *metrics.Bucket - listener net.Listener - trace bool } @@ -50,6 +48,8 @@ func New(sConfig ServerConfig, peerServer *PeerServer, registry *Registry, store metrics: mb, } + s.handler = s.buildHTTPHandler() + return s } @@ -172,7 +172,7 @@ func (s *Server) handleFunc(r *mux.Router, path string, f func(http.ResponseWrit }) } -func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { +func (s *Server) buildHTTPHandler() http.Handler { router := mux.NewRouter() // Install the routes. @@ -185,15 +185,11 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.installDebug(router) } - router.ServeHTTP(w, r) + return router } -// Stops the server. -func (s *Server) Close() { - if s.listener != nil { - s.listener.Close() - s.listener = nil - } +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + s.handler.ServeHTTP(w, r) } // Dispatch command to the current leader diff --git a/tests/server_utils.go b/tests/server_utils.go index cbe11aa1e..5a0cb5031 100644 --- a/tests/server_utils.go +++ b/tests/server_utils.go @@ -2,6 +2,7 @@ package tests import ( "io/ioutil" + "net/http" "os" "time" @@ -27,7 +28,6 @@ func RunServer(f func(*server.Server)) { store := store.New() registry := server.NewRegistry(store) - corsInfo, _ := server.NewCORSInfo([]string{}) serverStats := server.NewRaftServerStats(testName) followersStats := server.NewRaftFollowersStats(testName) @@ -39,7 +39,6 @@ func RunServer(f func(*server.Server)) { Scheme: "http", SnapshotCount: testSnapshotCount, MaxClusterSize: 9, - CORS: corsInfo, } ps := server.NewPeerServer(psConfig, registry, store, nil, followersStats, serverStats) psListener, err := server.NewListener(testRaftURL) @@ -63,7 +62,6 @@ func RunServer(f func(*server.Server)) { sConfig := server.ServerConfig{ Name: testName, URL: "http://"+testClientURL, - CORS: corsInfo, } s := server.New(sConfig, ps, registry, store, nil) sListener, err := server.NewListener(testClientURL) @@ -77,14 +75,15 @@ func RunServer(f func(*server.Server)) { c := make(chan bool) go func() { c <- true - ps.Serve(psListener, false, []string{}) + ps.Start(false, []string{}) + http.Serve(psListener, ps) }() <-c // Start up etcd server. go func() { c <- true - s.Serve(sListener) + http.Serve(sListener, s) }() <-c @@ -95,6 +94,7 @@ func RunServer(f func(*server.Server)) { f(s) // Clean up servers. - ps.Close() - s.Close() + ps.Stop() + psListener.Close() + sListener.Close() }