mirror of
https://github.com/etcd-io/etcd.git
synced 2024-09-27 06:25:44 +00:00
refactor(server): drop Serve code; rename cors object
* server/cors.go renamed to http/cors.go * all CORS code removed from Server and PeerServer * Server and PeerServer fulfill http.Handler, now passed to http.Serve * non-HTTP code in PeerServer.Serve moved to PeerServer.Start
This commit is contained in:
parent
5c3a3db2d8
commit
0abd860f7e
12
etcd.go
12
etcd.go
@ -26,6 +26,7 @@ import (
|
|||||||
|
|
||||||
"github.com/coreos/raft"
|
"github.com/coreos/raft"
|
||||||
|
|
||||||
|
ehttp "github.com/coreos/etcd/http"
|
||||||
"github.com/coreos/etcd/log"
|
"github.com/coreos/etcd/log"
|
||||||
"github.com/coreos/etcd/metrics"
|
"github.com/coreos/etcd/metrics"
|
||||||
"github.com/coreos/etcd/server"
|
"github.com/coreos/etcd/server"
|
||||||
@ -102,7 +103,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Retrieve CORS configuration
|
// Retrieve CORS configuration
|
||||||
corsInfo, err := server.NewCORSInfo(config.CorsOrigins)
|
corsInfo, err := ehttp.NewCORSInfo(config.CorsOrigins)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("CORS:", err)
|
log.Fatal("CORS:", err)
|
||||||
}
|
}
|
||||||
@ -130,7 +131,6 @@ func main() {
|
|||||||
SnapshotCount: config.SnapshotCount,
|
SnapshotCount: config.SnapshotCount,
|
||||||
MaxClusterSize: config.MaxClusterSize,
|
MaxClusterSize: config.MaxClusterSize,
|
||||||
RetryTimes: config.MaxRetryAttempts,
|
RetryTimes: config.MaxRetryAttempts,
|
||||||
CORS: corsInfo,
|
|
||||||
}
|
}
|
||||||
ps := server.NewPeerServer(psConfig, registry, store, &mb, followersStats, serverStats)
|
ps := server.NewPeerServer(psConfig, registry, store, &mb, followersStats, serverStats)
|
||||||
|
|
||||||
@ -177,12 +177,16 @@ func main() {
|
|||||||
|
|
||||||
ps.SetServer(s)
|
ps.SetServer(s)
|
||||||
|
|
||||||
|
ps.Start(config.Snapshot, config.Peers)
|
||||||
|
|
||||||
// Run peer server in separate thread while the client server blocks.
|
// Run peer server in separate thread while the client server blocks.
|
||||||
go func() {
|
go func() {
|
||||||
log.Fatal(ps.Serve(psListener, config.Snapshot, config.Peers))
|
log.Infof("raft server [name %s, listen on %s, advertised url %s]", ps.Config.Name, psListener.Addr(), ps.Config.URL)
|
||||||
|
sHTTP := &ehttp.CORSHandler{ps, corsInfo}
|
||||||
|
log.Fatal(http.Serve(psListener, sHTTP))
|
||||||
}()
|
}()
|
||||||
|
|
||||||
log.Infof("etcd server [name %s, listen on %s, advertised url %s]", s.Config.Name, sListener.Addr(), s.Config.URL)
|
log.Infof("etcd server [name %s, listen on %s, advertised url %s]", s.Config.Name, sListener.Addr(), s.Config.URL)
|
||||||
sHTTP := &server.CORSHTTPMiddleware{s, corsInfo}
|
sHTTP := &ehttp.CORSHandler{s, corsInfo}
|
||||||
log.Fatal(http.Serve(sListener, sHTTP))
|
log.Fatal(http.Serve(sListener, sHTTP))
|
||||||
}
|
}
|
||||||
|
@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package server
|
package http
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -22,9 +22,9 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
)
|
)
|
||||||
|
|
||||||
type corsInfo map[string]bool
|
type CORSInfo map[string]bool
|
||||||
|
|
||||||
func NewCORSInfo(origins []string) (*corsInfo, error) {
|
func NewCORSInfo(origins []string) (*CORSInfo, error) {
|
||||||
// Construct a lookup of all origins.
|
// Construct a lookup of all origins.
|
||||||
m := make(map[string]bool)
|
m := make(map[string]bool)
|
||||||
for _, v := range origins {
|
for _, v := range origins {
|
||||||
@ -36,29 +36,29 @@ func NewCORSInfo(origins []string) (*corsInfo, error) {
|
|||||||
m[v] = true
|
m[v] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
info := corsInfo(m)
|
info := CORSInfo(m)
|
||||||
return &info, nil
|
return &info, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// OriginAllowed determines whether the server will allow a given CORS origin.
|
// OriginAllowed determines whether the server will allow a given CORS origin.
|
||||||
func (c corsInfo) OriginAllowed(origin string) bool {
|
func (c CORSInfo) OriginAllowed(origin string) bool {
|
||||||
return c["*"] || c[origin]
|
return c["*"] || c[origin]
|
||||||
}
|
}
|
||||||
|
|
||||||
type CORSHTTPMiddleware struct {
|
type CORSHandler struct {
|
||||||
Handler http.Handler
|
Handler http.Handler
|
||||||
Info *corsInfo
|
Info *CORSInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
// addHeader adds the correct cors headers given an origin
|
// addHeader adds the correct cors headers given an origin
|
||||||
func (h *CORSHTTPMiddleware) addHeader(w http.ResponseWriter, origin string) {
|
func (h *CORSHandler) addHeader(w http.ResponseWriter, origin string) {
|
||||||
w.Header().Add("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
|
w.Header().Add("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
|
||||||
w.Header().Add("Access-Control-Allow-Origin", origin)
|
w.Header().Add("Access-Control-Allow-Origin", origin)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServeHTTP adds the correct CORS headers based on the origin and returns immediatly
|
// ServeHTTP adds the correct CORS headers based on the origin and returns immediatly
|
||||||
// with a 200 OK if the method is OPTIONS.
|
// with a 200 OK if the method is OPTIONS.
|
||||||
func (h *CORSHTTPMiddleware) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
func (h *CORSHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||||
// Write CORS header.
|
// Write CORS header.
|
||||||
if h.Info.OriginAllowed("*") {
|
if h.Info.OriginAllowed("*") {
|
||||||
h.addHeader(w, "*")
|
h.addHeader(w, "*")
|
@ -6,7 +6,6 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -35,11 +34,11 @@ type PeerServerConfig struct {
|
|||||||
ElectionTimeout time.Duration
|
ElectionTimeout time.Duration
|
||||||
MaxClusterSize int
|
MaxClusterSize int
|
||||||
RetryTimes int
|
RetryTimes int
|
||||||
CORS *corsInfo
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type PeerServer struct {
|
type PeerServer struct {
|
||||||
Config PeerServerConfig
|
Config PeerServerConfig
|
||||||
|
handler http.Handler
|
||||||
raftServer raft.Server
|
raftServer raft.Server
|
||||||
server *Server
|
server *Server
|
||||||
joinIndex uint64
|
joinIndex uint64
|
||||||
@ -49,8 +48,6 @@ type PeerServer struct {
|
|||||||
store store.Store
|
store store.Store
|
||||||
snapConf *snapshotConf
|
snapConf *snapshotConf
|
||||||
|
|
||||||
listener net.Listener
|
|
||||||
|
|
||||||
closeChan chan bool
|
closeChan chan bool
|
||||||
timeoutThresholdChan chan interface{}
|
timeoutThresholdChan chan interface{}
|
||||||
|
|
||||||
@ -82,6 +79,9 @@ func NewPeerServer(psConfig PeerServerConfig, registry *Registry, store store.St
|
|||||||
|
|
||||||
metrics: mb,
|
metrics: mb,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.handler = s.buildHTTPHandler()
|
||||||
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -107,7 +107,7 @@ func (s *PeerServer) SetRaftServer(raftServer raft.Server) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Start the raft server
|
// Start the raft server
|
||||||
func (s *PeerServer) Serve(listener net.Listener, snapshot bool, cluster []string) error {
|
func (s *PeerServer) Start(snapshot bool, cluster []string) error {
|
||||||
// LoadSnapshot
|
// LoadSnapshot
|
||||||
if snapshot {
|
if snapshot {
|
||||||
err := s.raftServer.LoadSnapshot()
|
err := s.raftServer.LoadSnapshot()
|
||||||
@ -157,8 +157,18 @@ func (s *PeerServer) Serve(listener net.Listener, snapshot bool, cluster []strin
|
|||||||
go s.monitorSnapshot()
|
go s.monitorSnapshot()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PeerServer) Stop() {
|
||||||
|
if s.closeChan != nil {
|
||||||
|
close(s.closeChan)
|
||||||
|
s.closeChan = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *PeerServer) buildHTTPHandler() http.Handler {
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
httpServer := &http.Server{Handler: router}
|
|
||||||
|
|
||||||
// internal commands
|
// internal commands
|
||||||
router.HandleFunc("/name", s.NameHttpHandler)
|
router.HandleFunc("/name", s.NameHttpHandler)
|
||||||
@ -174,21 +184,11 @@ func (s *PeerServer) Serve(listener net.Listener, snapshot bool, cluster []strin
|
|||||||
router.HandleFunc("/snapshotRecovery", s.SnapshotRecoveryHttpHandler)
|
router.HandleFunc("/snapshotRecovery", s.SnapshotRecoveryHttpHandler)
|
||||||
router.HandleFunc("/etcdURL", s.EtcdURLHttpHandler)
|
router.HandleFunc("/etcdURL", s.EtcdURLHttpHandler)
|
||||||
|
|
||||||
s.listener = listener
|
return router
|
||||||
log.Infof("raft server [name %s, listen on %s, advertised url %s]", s.Config.Name, listener.Addr(), s.Config.URL)
|
|
||||||
httpServer.Serve(listener)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *PeerServer) Close() {
|
func (s *PeerServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
if s.closeChan != nil {
|
s.handler.ServeHTTP(w, r)
|
||||||
close(s.closeChan)
|
|
||||||
s.closeChan = nil
|
|
||||||
}
|
|
||||||
if s.listener != nil {
|
|
||||||
s.listener.Close()
|
|
||||||
s.listener = nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Retrieves the underlying Raft server.
|
// Retrieves the underlying Raft server.
|
||||||
|
@ -149,13 +149,6 @@ func (ps *PeerServer) EtcdURLHttpHandler(w http.ResponseWriter, req *http.Reques
|
|||||||
func (ps *PeerServer) JoinHttpHandler(w http.ResponseWriter, req *http.Request) {
|
func (ps *PeerServer) JoinHttpHandler(w http.ResponseWriter, req *http.Request) {
|
||||||
command := &JoinCommand{}
|
command := &JoinCommand{}
|
||||||
|
|
||||||
// Write CORS header.
|
|
||||||
if ps.Config.CORS.OriginAllowed("*") {
|
|
||||||
w.Header().Add("Access-Control-Allow-Origin", "*")
|
|
||||||
} else if ps.Config.CORS.OriginAllowed(req.Header.Get("Origin")) {
|
|
||||||
w.Header().Add("Access-Control-Allow-Origin", req.Header.Get("Origin"))
|
|
||||||
}
|
|
||||||
|
|
||||||
err := decodeJsonRequest(req, command)
|
err := decodeJsonRequest(req, command)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
@ -3,7 +3,6 @@ package server
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/pprof"
|
"net/http/pprof"
|
||||||
"strings"
|
"strings"
|
||||||
@ -30,13 +29,12 @@ type ServerConfig struct {
|
|||||||
// This is the default implementation of the Server interface.
|
// This is the default implementation of the Server interface.
|
||||||
type Server struct {
|
type Server struct {
|
||||||
Config ServerConfig
|
Config ServerConfig
|
||||||
|
handler http.Handler
|
||||||
peerServer *PeerServer
|
peerServer *PeerServer
|
||||||
registry *Registry
|
registry *Registry
|
||||||
store store.Store
|
store store.Store
|
||||||
metrics *metrics.Bucket
|
metrics *metrics.Bucket
|
||||||
|
|
||||||
listener net.Listener
|
|
||||||
|
|
||||||
trace bool
|
trace bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -50,6 +48,8 @@ func New(sConfig ServerConfig, peerServer *PeerServer, registry *Registry, store
|
|||||||
metrics: mb,
|
metrics: mb,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.handler = s.buildHTTPHandler()
|
||||||
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -172,7 +172,7 @@ func (s *Server) handleFunc(r *mux.Router, path string, f func(http.ResponseWrit
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) buildHTTPHandler() http.Handler {
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
|
|
||||||
// Install the routes.
|
// Install the routes.
|
||||||
@ -185,15 +185,11 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
s.installDebug(router)
|
s.installDebug(router)
|
||||||
}
|
}
|
||||||
|
|
||||||
router.ServeHTTP(w, r)
|
return router
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stops the server.
|
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
func (s *Server) Close() {
|
s.handler.ServeHTTP(w, r)
|
||||||
if s.listener != nil {
|
|
||||||
s.listener.Close()
|
|
||||||
s.listener = nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dispatch command to the current leader
|
// Dispatch command to the current leader
|
||||||
|
@ -2,6 +2,7 @@ package tests
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -27,7 +28,6 @@ func RunServer(f func(*server.Server)) {
|
|||||||
|
|
||||||
store := store.New()
|
store := store.New()
|
||||||
registry := server.NewRegistry(store)
|
registry := server.NewRegistry(store)
|
||||||
corsInfo, _ := server.NewCORSInfo([]string{})
|
|
||||||
|
|
||||||
serverStats := server.NewRaftServerStats(testName)
|
serverStats := server.NewRaftServerStats(testName)
|
||||||
followersStats := server.NewRaftFollowersStats(testName)
|
followersStats := server.NewRaftFollowersStats(testName)
|
||||||
@ -39,7 +39,6 @@ func RunServer(f func(*server.Server)) {
|
|||||||
Scheme: "http",
|
Scheme: "http",
|
||||||
SnapshotCount: testSnapshotCount,
|
SnapshotCount: testSnapshotCount,
|
||||||
MaxClusterSize: 9,
|
MaxClusterSize: 9,
|
||||||
CORS: corsInfo,
|
|
||||||
}
|
}
|
||||||
ps := server.NewPeerServer(psConfig, registry, store, nil, followersStats, serverStats)
|
ps := server.NewPeerServer(psConfig, registry, store, nil, followersStats, serverStats)
|
||||||
psListener, err := server.NewListener(testRaftURL)
|
psListener, err := server.NewListener(testRaftURL)
|
||||||
@ -63,7 +62,6 @@ func RunServer(f func(*server.Server)) {
|
|||||||
sConfig := server.ServerConfig{
|
sConfig := server.ServerConfig{
|
||||||
Name: testName,
|
Name: testName,
|
||||||
URL: "http://"+testClientURL,
|
URL: "http://"+testClientURL,
|
||||||
CORS: corsInfo,
|
|
||||||
}
|
}
|
||||||
s := server.New(sConfig, ps, registry, store, nil)
|
s := server.New(sConfig, ps, registry, store, nil)
|
||||||
sListener, err := server.NewListener(testClientURL)
|
sListener, err := server.NewListener(testClientURL)
|
||||||
@ -77,14 +75,15 @@ func RunServer(f func(*server.Server)) {
|
|||||||
c := make(chan bool)
|
c := make(chan bool)
|
||||||
go func() {
|
go func() {
|
||||||
c <- true
|
c <- true
|
||||||
ps.Serve(psListener, false, []string{})
|
ps.Start(false, []string{})
|
||||||
|
http.Serve(psListener, ps)
|
||||||
}()
|
}()
|
||||||
<-c
|
<-c
|
||||||
|
|
||||||
// Start up etcd server.
|
// Start up etcd server.
|
||||||
go func() {
|
go func() {
|
||||||
c <- true
|
c <- true
|
||||||
s.Serve(sListener)
|
http.Serve(sListener, s)
|
||||||
}()
|
}()
|
||||||
<-c
|
<-c
|
||||||
|
|
||||||
@ -95,6 +94,7 @@ func RunServer(f func(*server.Server)) {
|
|||||||
f(s)
|
f(s)
|
||||||
|
|
||||||
// Clean up servers.
|
// Clean up servers.
|
||||||
ps.Close()
|
ps.Stop()
|
||||||
s.Close()
|
psListener.Close()
|
||||||
|
sListener.Close()
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user