refactor(server): drop Serve code; rename cors object

* server/cors.go renamed to http/cors.go
* all CORS code removed from Server and PeerServer
* Server and PeerServer fulfill http.Handler, now passed to http.Serve
* non-HTTP code in PeerServer.Serve moved to PeerServer.Start
This commit is contained in:
Brian Waldon 2014-01-20 17:22:09 -08:00
parent 5c3a3db2d8
commit 0abd860f7e
6 changed files with 50 additions and 57 deletions

12
etcd.go
View File

@ -26,6 +26,7 @@ import (
"github.com/coreos/raft"
ehttp "github.com/coreos/etcd/http"
"github.com/coreos/etcd/log"
"github.com/coreos/etcd/metrics"
"github.com/coreos/etcd/server"
@ -102,7 +103,7 @@ func main() {
}
// Retrieve CORS configuration
corsInfo, err := server.NewCORSInfo(config.CorsOrigins)
corsInfo, err := ehttp.NewCORSInfo(config.CorsOrigins)
if err != nil {
log.Fatal("CORS:", err)
}
@ -130,7 +131,6 @@ func main() {
SnapshotCount: config.SnapshotCount,
MaxClusterSize: config.MaxClusterSize,
RetryTimes: config.MaxRetryAttempts,
CORS: corsInfo,
}
ps := server.NewPeerServer(psConfig, registry, store, &mb, followersStats, serverStats)
@ -177,12 +177,16 @@ func main() {
ps.SetServer(s)
ps.Start(config.Snapshot, config.Peers)
// Run peer server in separate thread while the client server blocks.
go func() {
log.Fatal(ps.Serve(psListener, config.Snapshot, config.Peers))
log.Infof("raft server [name %s, listen on %s, advertised url %s]", ps.Config.Name, psListener.Addr(), ps.Config.URL)
sHTTP := &ehttp.CORSHandler{ps, corsInfo}
log.Fatal(http.Serve(psListener, sHTTP))
}()
log.Infof("etcd server [name %s, listen on %s, advertised url %s]", s.Config.Name, sListener.Addr(), s.Config.URL)
sHTTP := &server.CORSHTTPMiddleware{s, corsInfo}
sHTTP := &ehttp.CORSHandler{s, corsInfo}
log.Fatal(http.Serve(sListener, sHTTP))
}

View File

@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
package server
package http
import (
"fmt"
@ -22,9 +22,9 @@ import (
"net/url"
)
type corsInfo map[string]bool
type CORSInfo map[string]bool
func NewCORSInfo(origins []string) (*corsInfo, error) {
func NewCORSInfo(origins []string) (*CORSInfo, error) {
// Construct a lookup of all origins.
m := make(map[string]bool)
for _, v := range origins {
@ -36,29 +36,29 @@ func NewCORSInfo(origins []string) (*corsInfo, error) {
m[v] = true
}
info := corsInfo(m)
info := CORSInfo(m)
return &info, nil
}
// OriginAllowed determines whether the server will allow a given CORS origin.
func (c corsInfo) OriginAllowed(origin string) bool {
func (c CORSInfo) OriginAllowed(origin string) bool {
return c["*"] || c[origin]
}
type CORSHTTPMiddleware struct {
type CORSHandler struct {
Handler http.Handler
Info *corsInfo
Info *CORSInfo
}
// addHeader adds the correct cors headers given an origin
func (h *CORSHTTPMiddleware) addHeader(w http.ResponseWriter, origin string) {
func (h *CORSHandler) addHeader(w http.ResponseWriter, origin string) {
w.Header().Add("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
w.Header().Add("Access-Control-Allow-Origin", origin)
}
// ServeHTTP adds the correct CORS headers based on the origin and returns immediatly
// with a 200 OK if the method is OPTIONS.
func (h *CORSHTTPMiddleware) ServeHTTP(w http.ResponseWriter, req *http.Request) {
func (h *CORSHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// Write CORS header.
if h.Info.OriginAllowed("*") {
h.addHeader(w, "*")

View File

@ -6,7 +6,6 @@ import (
"encoding/json"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/url"
"strconv"
@ -35,11 +34,11 @@ type PeerServerConfig struct {
ElectionTimeout time.Duration
MaxClusterSize int
RetryTimes int
CORS *corsInfo
}
type PeerServer struct {
Config PeerServerConfig
handler http.Handler
raftServer raft.Server
server *Server
joinIndex uint64
@ -49,8 +48,6 @@ type PeerServer struct {
store store.Store
snapConf *snapshotConf
listener net.Listener
closeChan chan bool
timeoutThresholdChan chan interface{}
@ -82,6 +79,9 @@ func NewPeerServer(psConfig PeerServerConfig, registry *Registry, store store.St
metrics: mb,
}
s.handler = s.buildHTTPHandler()
return s
}
@ -107,7 +107,7 @@ func (s *PeerServer) SetRaftServer(raftServer raft.Server) {
}
// Start the raft server
func (s *PeerServer) Serve(listener net.Listener, snapshot bool, cluster []string) error {
func (s *PeerServer) Start(snapshot bool, cluster []string) error {
// LoadSnapshot
if snapshot {
err := s.raftServer.LoadSnapshot()
@ -157,8 +157,18 @@ func (s *PeerServer) Serve(listener net.Listener, snapshot bool, cluster []strin
go s.monitorSnapshot()
}
return nil
}
func (s *PeerServer) Stop() {
if s.closeChan != nil {
close(s.closeChan)
s.closeChan = nil
}
}
func (s *PeerServer) buildHTTPHandler() http.Handler {
router := mux.NewRouter()
httpServer := &http.Server{Handler: router}
// internal commands
router.HandleFunc("/name", s.NameHttpHandler)
@ -174,21 +184,11 @@ func (s *PeerServer) Serve(listener net.Listener, snapshot bool, cluster []strin
router.HandleFunc("/snapshotRecovery", s.SnapshotRecoveryHttpHandler)
router.HandleFunc("/etcdURL", s.EtcdURLHttpHandler)
s.listener = listener
log.Infof("raft server [name %s, listen on %s, advertised url %s]", s.Config.Name, listener.Addr(), s.Config.URL)
httpServer.Serve(listener)
return nil
return router
}
func (s *PeerServer) Close() {
if s.closeChan != nil {
close(s.closeChan)
s.closeChan = nil
}
if s.listener != nil {
s.listener.Close()
s.listener = nil
}
func (s *PeerServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.handler.ServeHTTP(w, r)
}
// Retrieves the underlying Raft server.

View File

@ -149,13 +149,6 @@ func (ps *PeerServer) EtcdURLHttpHandler(w http.ResponseWriter, req *http.Reques
func (ps *PeerServer) JoinHttpHandler(w http.ResponseWriter, req *http.Request) {
command := &JoinCommand{}
// Write CORS header.
if ps.Config.CORS.OriginAllowed("*") {
w.Header().Add("Access-Control-Allow-Origin", "*")
} else if ps.Config.CORS.OriginAllowed(req.Header.Get("Origin")) {
w.Header().Add("Access-Control-Allow-Origin", req.Header.Get("Origin"))
}
err := decodeJsonRequest(req, command)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)

View File

@ -3,7 +3,6 @@ package server
import (
"encoding/json"
"fmt"
"net"
"net/http"
"net/http/pprof"
"strings"
@ -30,13 +29,12 @@ type ServerConfig struct {
// This is the default implementation of the Server interface.
type Server struct {
Config ServerConfig
handler http.Handler
peerServer *PeerServer
registry *Registry
store store.Store
metrics *metrics.Bucket
listener net.Listener
trace bool
}
@ -50,6 +48,8 @@ func New(sConfig ServerConfig, peerServer *PeerServer, registry *Registry, store
metrics: mb,
}
s.handler = s.buildHTTPHandler()
return s
}
@ -172,7 +172,7 @@ func (s *Server) handleFunc(r *mux.Router, path string, f func(http.ResponseWrit
})
}
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (s *Server) buildHTTPHandler() http.Handler {
router := mux.NewRouter()
// Install the routes.
@ -185,15 +185,11 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.installDebug(router)
}
router.ServeHTTP(w, r)
return router
}
// Stops the server.
func (s *Server) Close() {
if s.listener != nil {
s.listener.Close()
s.listener = nil
}
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.handler.ServeHTTP(w, r)
}
// Dispatch command to the current leader

View File

@ -2,6 +2,7 @@ package tests
import (
"io/ioutil"
"net/http"
"os"
"time"
@ -27,7 +28,6 @@ func RunServer(f func(*server.Server)) {
store := store.New()
registry := server.NewRegistry(store)
corsInfo, _ := server.NewCORSInfo([]string{})
serverStats := server.NewRaftServerStats(testName)
followersStats := server.NewRaftFollowersStats(testName)
@ -39,7 +39,6 @@ func RunServer(f func(*server.Server)) {
Scheme: "http",
SnapshotCount: testSnapshotCount,
MaxClusterSize: 9,
CORS: corsInfo,
}
ps := server.NewPeerServer(psConfig, registry, store, nil, followersStats, serverStats)
psListener, err := server.NewListener(testRaftURL)
@ -63,7 +62,6 @@ func RunServer(f func(*server.Server)) {
sConfig := server.ServerConfig{
Name: testName,
URL: "http://"+testClientURL,
CORS: corsInfo,
}
s := server.New(sConfig, ps, registry, store, nil)
sListener, err := server.NewListener(testClientURL)
@ -77,14 +75,15 @@ func RunServer(f func(*server.Server)) {
c := make(chan bool)
go func() {
c <- true
ps.Serve(psListener, false, []string{})
ps.Start(false, []string{})
http.Serve(psListener, ps)
}()
<-c
// Start up etcd server.
go func() {
c <- true
s.Serve(sListener)
http.Serve(sListener, s)
}()
<-c
@ -95,6 +94,7 @@ func RunServer(f func(*server.Server)) {
f(s)
// Clean up servers.
ps.Close()
s.Close()
ps.Stop()
psListener.Close()
sListener.Close()
}