mirror of
https://github.com/etcd-io/etcd.git
synced 2024-09-27 06:25:44 +00:00
refactor(server): drop Serve code; rename cors object
* server/cors.go renamed to http/cors.go * all CORS code removed from Server and PeerServer * Server and PeerServer fulfill http.Handler, now passed to http.Serve * non-HTTP code in PeerServer.Serve moved to PeerServer.Start
This commit is contained in:
@@ -1,75 +0,0 @@
|
||||
/*
|
||||
Copyright 2013 CoreOS Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
type corsInfo map[string]bool
|
||||
|
||||
func NewCORSInfo(origins []string) (*corsInfo, error) {
|
||||
// Construct a lookup of all origins.
|
||||
m := make(map[string]bool)
|
||||
for _, v := range origins {
|
||||
if v != "*" {
|
||||
if _, err := url.Parse(v); err != nil {
|
||||
return nil, fmt.Errorf("Invalid CORS origin: %s", err)
|
||||
}
|
||||
}
|
||||
m[v] = true
|
||||
}
|
||||
|
||||
info := corsInfo(m)
|
||||
return &info, nil
|
||||
}
|
||||
|
||||
// OriginAllowed determines whether the server will allow a given CORS origin.
|
||||
func (c corsInfo) OriginAllowed(origin string) bool {
|
||||
return c["*"] || c[origin]
|
||||
}
|
||||
|
||||
type CORSHTTPMiddleware struct {
|
||||
Handler http.Handler
|
||||
Info *corsInfo
|
||||
}
|
||||
|
||||
// addHeader adds the correct cors headers given an origin
|
||||
func (h *CORSHTTPMiddleware) addHeader(w http.ResponseWriter, origin string) {
|
||||
w.Header().Add("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
|
||||
w.Header().Add("Access-Control-Allow-Origin", origin)
|
||||
}
|
||||
|
||||
// ServeHTTP adds the correct CORS headers based on the origin and returns immediatly
|
||||
// with a 200 OK if the method is OPTIONS.
|
||||
func (h *CORSHTTPMiddleware) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
// Write CORS header.
|
||||
if h.Info.OriginAllowed("*") {
|
||||
h.addHeader(w, "*")
|
||||
} else if origin := req.Header.Get("Origin"); h.Info.OriginAllowed(origin) {
|
||||
h.addHeader(w, origin)
|
||||
}
|
||||
|
||||
if req.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
h.Handler.ServeHTTP(w, req)
|
||||
}
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
@@ -35,11 +34,11 @@ type PeerServerConfig struct {
|
||||
ElectionTimeout time.Duration
|
||||
MaxClusterSize int
|
||||
RetryTimes int
|
||||
CORS *corsInfo
|
||||
}
|
||||
|
||||
type PeerServer struct {
|
||||
Config PeerServerConfig
|
||||
handler http.Handler
|
||||
raftServer raft.Server
|
||||
server *Server
|
||||
joinIndex uint64
|
||||
@@ -49,8 +48,6 @@ type PeerServer struct {
|
||||
store store.Store
|
||||
snapConf *snapshotConf
|
||||
|
||||
listener net.Listener
|
||||
|
||||
closeChan chan bool
|
||||
timeoutThresholdChan chan interface{}
|
||||
|
||||
@@ -82,6 +79,9 @@ func NewPeerServer(psConfig PeerServerConfig, registry *Registry, store store.St
|
||||
|
||||
metrics: mb,
|
||||
}
|
||||
|
||||
s.handler = s.buildHTTPHandler()
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
@@ -107,7 +107,7 @@ func (s *PeerServer) SetRaftServer(raftServer raft.Server) {
|
||||
}
|
||||
|
||||
// Start the raft server
|
||||
func (s *PeerServer) Serve(listener net.Listener, snapshot bool, cluster []string) error {
|
||||
func (s *PeerServer) Start(snapshot bool, cluster []string) error {
|
||||
// LoadSnapshot
|
||||
if snapshot {
|
||||
err := s.raftServer.LoadSnapshot()
|
||||
@@ -157,8 +157,18 @@ func (s *PeerServer) Serve(listener net.Listener, snapshot bool, cluster []strin
|
||||
go s.monitorSnapshot()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PeerServer) Stop() {
|
||||
if s.closeChan != nil {
|
||||
close(s.closeChan)
|
||||
s.closeChan = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PeerServer) buildHTTPHandler() http.Handler {
|
||||
router := mux.NewRouter()
|
||||
httpServer := &http.Server{Handler: router}
|
||||
|
||||
// internal commands
|
||||
router.HandleFunc("/name", s.NameHttpHandler)
|
||||
@@ -174,21 +184,11 @@ func (s *PeerServer) Serve(listener net.Listener, snapshot bool, cluster []strin
|
||||
router.HandleFunc("/snapshotRecovery", s.SnapshotRecoveryHttpHandler)
|
||||
router.HandleFunc("/etcdURL", s.EtcdURLHttpHandler)
|
||||
|
||||
s.listener = listener
|
||||
log.Infof("raft server [name %s, listen on %s, advertised url %s]", s.Config.Name, listener.Addr(), s.Config.URL)
|
||||
httpServer.Serve(listener)
|
||||
return nil
|
||||
return router
|
||||
}
|
||||
|
||||
func (s *PeerServer) Close() {
|
||||
if s.closeChan != nil {
|
||||
close(s.closeChan)
|
||||
s.closeChan = nil
|
||||
}
|
||||
if s.listener != nil {
|
||||
s.listener.Close()
|
||||
s.listener = nil
|
||||
}
|
||||
func (s *PeerServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
s.handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// Retrieves the underlying Raft server.
|
||||
|
||||
@@ -149,13 +149,6 @@ func (ps *PeerServer) EtcdURLHttpHandler(w http.ResponseWriter, req *http.Reques
|
||||
func (ps *PeerServer) JoinHttpHandler(w http.ResponseWriter, req *http.Request) {
|
||||
command := &JoinCommand{}
|
||||
|
||||
// Write CORS header.
|
||||
if ps.Config.CORS.OriginAllowed("*") {
|
||||
w.Header().Add("Access-Control-Allow-Origin", "*")
|
||||
} else if ps.Config.CORS.OriginAllowed(req.Header.Get("Origin")) {
|
||||
w.Header().Add("Access-Control-Allow-Origin", req.Header.Get("Origin"))
|
||||
}
|
||||
|
||||
err := decodeJsonRequest(req, command)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
@@ -3,7 +3,6 @@ package server
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/pprof"
|
||||
"strings"
|
||||
@@ -30,13 +29,12 @@ type ServerConfig struct {
|
||||
// This is the default implementation of the Server interface.
|
||||
type Server struct {
|
||||
Config ServerConfig
|
||||
handler http.Handler
|
||||
peerServer *PeerServer
|
||||
registry *Registry
|
||||
store store.Store
|
||||
metrics *metrics.Bucket
|
||||
|
||||
listener net.Listener
|
||||
|
||||
trace bool
|
||||
}
|
||||
|
||||
@@ -50,6 +48,8 @@ func New(sConfig ServerConfig, peerServer *PeerServer, registry *Registry, store
|
||||
metrics: mb,
|
||||
}
|
||||
|
||||
s.handler = s.buildHTTPHandler()
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
@@ -172,7 +172,7 @@ func (s *Server) handleFunc(r *mux.Router, path string, f func(http.ResponseWrit
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *Server) buildHTTPHandler() http.Handler {
|
||||
router := mux.NewRouter()
|
||||
|
||||
// Install the routes.
|
||||
@@ -185,15 +185,11 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
s.installDebug(router)
|
||||
}
|
||||
|
||||
router.ServeHTTP(w, r)
|
||||
return router
|
||||
}
|
||||
|
||||
// Stops the server.
|
||||
func (s *Server) Close() {
|
||||
if s.listener != nil {
|
||||
s.listener.Close()
|
||||
s.listener = nil
|
||||
}
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
s.handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// Dispatch command to the current leader
|
||||
|
||||
Reference in New Issue
Block a user