refactor(server): drop Serve code; rename cors object

* server/cors.go renamed to http/cors.go
* all CORS code removed from Server and PeerServer
* Server and PeerServer fulfill http.Handler, now passed to http.Serve
* non-HTTP code in PeerServer.Serve moved to PeerServer.Start
This commit is contained in:
Brian Waldon 2014-01-20 17:22:09 -08:00
parent 5c3a3db2d8
commit 0abd860f7e
6 changed files with 50 additions and 57 deletions

12
etcd.go
View File

@ -26,6 +26,7 @@ import (
"github.com/coreos/raft" "github.com/coreos/raft"
ehttp "github.com/coreos/etcd/http"
"github.com/coreos/etcd/log" "github.com/coreos/etcd/log"
"github.com/coreos/etcd/metrics" "github.com/coreos/etcd/metrics"
"github.com/coreos/etcd/server" "github.com/coreos/etcd/server"
@ -102,7 +103,7 @@ func main() {
} }
// Retrieve CORS configuration // Retrieve CORS configuration
corsInfo, err := server.NewCORSInfo(config.CorsOrigins) corsInfo, err := ehttp.NewCORSInfo(config.CorsOrigins)
if err != nil { if err != nil {
log.Fatal("CORS:", err) log.Fatal("CORS:", err)
} }
@ -130,7 +131,6 @@ func main() {
SnapshotCount: config.SnapshotCount, SnapshotCount: config.SnapshotCount,
MaxClusterSize: config.MaxClusterSize, MaxClusterSize: config.MaxClusterSize,
RetryTimes: config.MaxRetryAttempts, RetryTimes: config.MaxRetryAttempts,
CORS: corsInfo,
} }
ps := server.NewPeerServer(psConfig, registry, store, &mb, followersStats, serverStats) ps := server.NewPeerServer(psConfig, registry, store, &mb, followersStats, serverStats)
@ -177,12 +177,16 @@ func main() {
ps.SetServer(s) ps.SetServer(s)
ps.Start(config.Snapshot, config.Peers)
// Run peer server in separate thread while the client server blocks. // Run peer server in separate thread while the client server blocks.
go func() { go func() {
log.Fatal(ps.Serve(psListener, config.Snapshot, config.Peers)) log.Infof("raft server [name %s, listen on %s, advertised url %s]", ps.Config.Name, psListener.Addr(), ps.Config.URL)
sHTTP := &ehttp.CORSHandler{ps, corsInfo}
log.Fatal(http.Serve(psListener, sHTTP))
}() }()
log.Infof("etcd server [name %s, listen on %s, advertised url %s]", s.Config.Name, sListener.Addr(), s.Config.URL) log.Infof("etcd server [name %s, listen on %s, advertised url %s]", s.Config.Name, sListener.Addr(), s.Config.URL)
sHTTP := &server.CORSHTTPMiddleware{s, corsInfo} sHTTP := &ehttp.CORSHandler{s, corsInfo}
log.Fatal(http.Serve(sListener, sHTTP)) log.Fatal(http.Serve(sListener, sHTTP))
} }

View File

@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
package server package http
import ( import (
"fmt" "fmt"
@ -22,9 +22,9 @@ import (
"net/url" "net/url"
) )
type corsInfo map[string]bool type CORSInfo map[string]bool
func NewCORSInfo(origins []string) (*corsInfo, error) { func NewCORSInfo(origins []string) (*CORSInfo, error) {
// Construct a lookup of all origins. // Construct a lookup of all origins.
m := make(map[string]bool) m := make(map[string]bool)
for _, v := range origins { for _, v := range origins {
@ -36,29 +36,29 @@ func NewCORSInfo(origins []string) (*corsInfo, error) {
m[v] = true m[v] = true
} }
info := corsInfo(m) info := CORSInfo(m)
return &info, nil return &info, nil
} }
// OriginAllowed determines whether the server will allow a given CORS origin. // OriginAllowed determines whether the server will allow a given CORS origin.
func (c corsInfo) OriginAllowed(origin string) bool { func (c CORSInfo) OriginAllowed(origin string) bool {
return c["*"] || c[origin] return c["*"] || c[origin]
} }
type CORSHTTPMiddleware struct { type CORSHandler struct {
Handler http.Handler Handler http.Handler
Info *corsInfo Info *CORSInfo
} }
// addHeader adds the correct cors headers given an origin // addHeader adds the correct cors headers given an origin
func (h *CORSHTTPMiddleware) addHeader(w http.ResponseWriter, origin string) { func (h *CORSHandler) addHeader(w http.ResponseWriter, origin string) {
w.Header().Add("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE") w.Header().Add("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
w.Header().Add("Access-Control-Allow-Origin", origin) w.Header().Add("Access-Control-Allow-Origin", origin)
} }
// ServeHTTP adds the correct CORS headers based on the origin and returns immediatly // ServeHTTP adds the correct CORS headers based on the origin and returns immediatly
// with a 200 OK if the method is OPTIONS. // with a 200 OK if the method is OPTIONS.
func (h *CORSHTTPMiddleware) ServeHTTP(w http.ResponseWriter, req *http.Request) { func (h *CORSHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// Write CORS header. // Write CORS header.
if h.Info.OriginAllowed("*") { if h.Info.OriginAllowed("*") {
h.addHeader(w, "*") h.addHeader(w, "*")

View File

@ -6,7 +6,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net"
"net/http" "net/http"
"net/url" "net/url"
"strconv" "strconv"
@ -35,11 +34,11 @@ type PeerServerConfig struct {
ElectionTimeout time.Duration ElectionTimeout time.Duration
MaxClusterSize int MaxClusterSize int
RetryTimes int RetryTimes int
CORS *corsInfo
} }
type PeerServer struct { type PeerServer struct {
Config PeerServerConfig Config PeerServerConfig
handler http.Handler
raftServer raft.Server raftServer raft.Server
server *Server server *Server
joinIndex uint64 joinIndex uint64
@ -49,8 +48,6 @@ type PeerServer struct {
store store.Store store store.Store
snapConf *snapshotConf snapConf *snapshotConf
listener net.Listener
closeChan chan bool closeChan chan bool
timeoutThresholdChan chan interface{} timeoutThresholdChan chan interface{}
@ -82,6 +79,9 @@ func NewPeerServer(psConfig PeerServerConfig, registry *Registry, store store.St
metrics: mb, metrics: mb,
} }
s.handler = s.buildHTTPHandler()
return s return s
} }
@ -107,7 +107,7 @@ func (s *PeerServer) SetRaftServer(raftServer raft.Server) {
} }
// Start the raft server // Start the raft server
func (s *PeerServer) Serve(listener net.Listener, snapshot bool, cluster []string) error { func (s *PeerServer) Start(snapshot bool, cluster []string) error {
// LoadSnapshot // LoadSnapshot
if snapshot { if snapshot {
err := s.raftServer.LoadSnapshot() err := s.raftServer.LoadSnapshot()
@ -157,8 +157,18 @@ func (s *PeerServer) Serve(listener net.Listener, snapshot bool, cluster []strin
go s.monitorSnapshot() go s.monitorSnapshot()
} }
return nil
}
func (s *PeerServer) Stop() {
if s.closeChan != nil {
close(s.closeChan)
s.closeChan = nil
}
}
func (s *PeerServer) buildHTTPHandler() http.Handler {
router := mux.NewRouter() router := mux.NewRouter()
httpServer := &http.Server{Handler: router}
// internal commands // internal commands
router.HandleFunc("/name", s.NameHttpHandler) router.HandleFunc("/name", s.NameHttpHandler)
@ -174,21 +184,11 @@ func (s *PeerServer) Serve(listener net.Listener, snapshot bool, cluster []strin
router.HandleFunc("/snapshotRecovery", s.SnapshotRecoveryHttpHandler) router.HandleFunc("/snapshotRecovery", s.SnapshotRecoveryHttpHandler)
router.HandleFunc("/etcdURL", s.EtcdURLHttpHandler) router.HandleFunc("/etcdURL", s.EtcdURLHttpHandler)
s.listener = listener return router
log.Infof("raft server [name %s, listen on %s, advertised url %s]", s.Config.Name, listener.Addr(), s.Config.URL)
httpServer.Serve(listener)
return nil
} }
func (s *PeerServer) Close() { func (s *PeerServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if s.closeChan != nil { s.handler.ServeHTTP(w, r)
close(s.closeChan)
s.closeChan = nil
}
if s.listener != nil {
s.listener.Close()
s.listener = nil
}
} }
// Retrieves the underlying Raft server. // Retrieves the underlying Raft server.

View File

@ -149,13 +149,6 @@ func (ps *PeerServer) EtcdURLHttpHandler(w http.ResponseWriter, req *http.Reques
func (ps *PeerServer) JoinHttpHandler(w http.ResponseWriter, req *http.Request) { func (ps *PeerServer) JoinHttpHandler(w http.ResponseWriter, req *http.Request) {
command := &JoinCommand{} command := &JoinCommand{}
// Write CORS header.
if ps.Config.CORS.OriginAllowed("*") {
w.Header().Add("Access-Control-Allow-Origin", "*")
} else if ps.Config.CORS.OriginAllowed(req.Header.Get("Origin")) {
w.Header().Add("Access-Control-Allow-Origin", req.Header.Get("Origin"))
}
err := decodeJsonRequest(req, command) err := decodeJsonRequest(req, command)
if err != nil { if err != nil {
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)

View File

@ -3,7 +3,6 @@ package server
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net"
"net/http" "net/http"
"net/http/pprof" "net/http/pprof"
"strings" "strings"
@ -30,13 +29,12 @@ type ServerConfig struct {
// This is the default implementation of the Server interface. // This is the default implementation of the Server interface.
type Server struct { type Server struct {
Config ServerConfig Config ServerConfig
handler http.Handler
peerServer *PeerServer peerServer *PeerServer
registry *Registry registry *Registry
store store.Store store store.Store
metrics *metrics.Bucket metrics *metrics.Bucket
listener net.Listener
trace bool trace bool
} }
@ -50,6 +48,8 @@ func New(sConfig ServerConfig, peerServer *PeerServer, registry *Registry, store
metrics: mb, metrics: mb,
} }
s.handler = s.buildHTTPHandler()
return s return s
} }
@ -172,7 +172,7 @@ func (s *Server) handleFunc(r *mux.Router, path string, f func(http.ResponseWrit
}) })
} }
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (s *Server) buildHTTPHandler() http.Handler {
router := mux.NewRouter() router := mux.NewRouter()
// Install the routes. // Install the routes.
@ -185,15 +185,11 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.installDebug(router) s.installDebug(router)
} }
router.ServeHTTP(w, r) return router
} }
// Stops the server. func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (s *Server) Close() { s.handler.ServeHTTP(w, r)
if s.listener != nil {
s.listener.Close()
s.listener = nil
}
} }
// Dispatch command to the current leader // Dispatch command to the current leader

View File

@ -2,6 +2,7 @@ package tests
import ( import (
"io/ioutil" "io/ioutil"
"net/http"
"os" "os"
"time" "time"
@ -27,7 +28,6 @@ func RunServer(f func(*server.Server)) {
store := store.New() store := store.New()
registry := server.NewRegistry(store) registry := server.NewRegistry(store)
corsInfo, _ := server.NewCORSInfo([]string{})
serverStats := server.NewRaftServerStats(testName) serverStats := server.NewRaftServerStats(testName)
followersStats := server.NewRaftFollowersStats(testName) followersStats := server.NewRaftFollowersStats(testName)
@ -39,7 +39,6 @@ func RunServer(f func(*server.Server)) {
Scheme: "http", Scheme: "http",
SnapshotCount: testSnapshotCount, SnapshotCount: testSnapshotCount,
MaxClusterSize: 9, MaxClusterSize: 9,
CORS: corsInfo,
} }
ps := server.NewPeerServer(psConfig, registry, store, nil, followersStats, serverStats) ps := server.NewPeerServer(psConfig, registry, store, nil, followersStats, serverStats)
psListener, err := server.NewListener(testRaftURL) psListener, err := server.NewListener(testRaftURL)
@ -63,7 +62,6 @@ func RunServer(f func(*server.Server)) {
sConfig := server.ServerConfig{ sConfig := server.ServerConfig{
Name: testName, Name: testName,
URL: "http://"+testClientURL, URL: "http://"+testClientURL,
CORS: corsInfo,
} }
s := server.New(sConfig, ps, registry, store, nil) s := server.New(sConfig, ps, registry, store, nil)
sListener, err := server.NewListener(testClientURL) sListener, err := server.NewListener(testClientURL)
@ -77,14 +75,15 @@ func RunServer(f func(*server.Server)) {
c := make(chan bool) c := make(chan bool)
go func() { go func() {
c <- true c <- true
ps.Serve(psListener, false, []string{}) ps.Start(false, []string{})
http.Serve(psListener, ps)
}() }()
<-c <-c
// Start up etcd server. // Start up etcd server.
go func() { go func() {
c <- true c <- true
s.Serve(sListener) http.Serve(sListener, s)
}() }()
<-c <-c
@ -95,6 +94,7 @@ func RunServer(f func(*server.Server)) {
f(s) f(s)
// Clean up servers. // Clean up servers.
ps.Close() ps.Stop()
s.Close() psListener.Close()
sListener.Close()
} }