simplify createTrans

This commit is contained in:
Xiang Li 2013-08-09 10:12:50 -07:00 committed by Brandon Philips
parent 7b38812575
commit 06fab60dd6

154
etcd.go
View File

@ -89,14 +89,8 @@ func init() {
// CONSTANTS // CONSTANTS
const ( const (
HTTP = iota RaftServer = iota
HTTPS EtcdServer
HTTPSANDVERIFY
)
const (
SERVER = iota
CLIENT
) )
const ( const (
@ -200,19 +194,20 @@ func main() {
info = getInfo(dirPath) info = getInfo(dirPath)
// security type raftTlsConfs, ok := tlsConf(RaftServer)
st := securityType(SERVER) if !ok {
fatal("Please specify cert and key file or cert and key file and CAFile or none of the three")
}
clientSt := securityType(CLIENT) etcdTlsConfs, ok := tlsConf(EtcdServer)
if !ok {
if st == -1 || clientSt == -1 {
fatal("Please specify cert and key file or cert and key file and CAFile or none of the three") fatal("Please specify cert and key file or cert and key file and CAFile or none of the three")
} }
// Create etcd key-value store // Create etcd key-value store
etcdStore = store.CreateStore(maxSize) etcdStore = store.CreateStore(maxSize)
startRaft(st) startRaft(raftTlsConfs)
if argInfo.WebPort != -1 { if argInfo.WebPort != -1 {
// start web // start web
@ -221,18 +216,18 @@ func main() {
go web.Start(raftServer, argInfo.WebPort) go web.Start(raftServer, argInfo.WebPort)
} }
startClientTransport(*info, clientSt) startEtcdTransport(*info, etcdTlsConfs[0])
} }
// Start the raft server // Start the raft server
func startRaft(securityType int) { func startRaft(tlsConfs []*tls.Config) {
var err error var err error
raftName := fmt.Sprintf("%s:%d", info.Hostname, info.RaftPort) raftName := fmt.Sprintf("%s:%d", info.Hostname, info.RaftPort)
// Create transporter for raft // Create transporter for raft
raftTransporter = createTransporter(securityType) raftTransporter = newTransporter(tlsConfs[1])
// Create raft server // Create raft server
raftServer, err = raft.NewServer(raftName, dirPath, raftTransporter, etcdStore, nil) raftServer, err = raft.NewServer(raftName, dirPath, raftTransporter, etcdStore, nil)
@ -328,44 +323,30 @@ func startRaft(securityType int) {
} }
// start to response to raft requests // start to response to raft requests
go startRaftTransport(*info, securityType) go startRaftTransport(*info, tlsConfs[0])
} }
// Create transporter using by raft server // Create transporter using by raft server
// Create http or https transporter based on // Create http or https transporter based on
// whether the user give the server cert and key // whether the user give the server cert and key
func createTransporter(st int) transporter { func newTransporter(tlsConf *tls.Config) transporter {
t := transporter{} t := transporter{}
switch st { if tlsConf == nil {
case HTTP:
t.scheme = "http://" t.scheme = "http://"
tr := &http.Transport{
Dial: dialTimeout,
}
t.client = &http.Client{ t.client = &http.Client{
Transport: tr, Transport: &http.Transport{
Dial: dialTimeout,
},
} }
case HTTPS: } else {
fallthrough
case HTTPSANDVERIFY:
t.scheme = "https://" t.scheme = "https://"
tlsCert, err := tls.LoadX509KeyPair(argInfo.ServerCertFile, argInfo.ServerKeyFile)
if err != nil {
fatal(err)
}
tr := &http.Transport{ tr := &http.Transport{
TLSClientConfig: &tls.Config{ TLSClientConfig: tlsConf,
Certificates: []tls.Certificate{tlsCert},
InsecureSkipVerify: true,
},
Dial: dialTimeout, Dial: dialTimeout,
DisableCompression: true, DisableCompression: true,
} }
@ -382,7 +363,7 @@ func dialTimeout(network, addr string) (net.Conn, error) {
} }
// Start to listen and response raft command // Start to listen and response raft command
func startRaftTransport(info Info, st int) { func startRaftTransport(info Info, tlsConf *tls.Config) {
// internal commands // internal commands
http.HandleFunc("/join", JoinHttpHandler) http.HandleFunc("/join", JoinHttpHandler)
@ -393,24 +374,14 @@ func startRaftTransport(info Info, st int) {
http.HandleFunc("/snapshotRecovery", SnapshotRecoveryHttpHandler) http.HandleFunc("/snapshotRecovery", SnapshotRecoveryHttpHandler)
http.HandleFunc("/client", ClientHttpHandler) http.HandleFunc("/client", ClientHttpHandler)
switch st { if tlsConf == nil {
case HTTP:
fmt.Printf("raft server [%s] listen on http port %v\n", info.Hostname, info.RaftPort) fmt.Printf("raft server [%s] listen on http port %v\n", info.Hostname, info.RaftPort)
fatal(http.ListenAndServe(fmt.Sprintf(":%d", info.RaftPort), nil)) fatal(http.ListenAndServe(fmt.Sprintf(":%d", info.RaftPort), nil))
case HTTPS: } else {
fmt.Printf("raft server [%s] listen on https port %v\n", info.Hostname, info.RaftPort)
fatal(http.ListenAndServeTLS(fmt.Sprintf(":%d", info.RaftPort), info.ServerCertFile, argInfo.ServerKeyFile, nil))
case HTTPSANDVERIFY:
server := &http.Server{ server := &http.Server{
TLSConfig: &tls.Config{ TLSConfig: tlsConf,
ClientAuth: tls.RequireAndVerifyClientCert, Addr: fmt.Sprintf(":%d", info.RaftPort),
ClientCAs: createCertPool(info.ServerCAFile),
},
Addr: fmt.Sprintf(":%d", info.RaftPort),
} }
fmt.Printf("raft server [%s] listen on https port %v\n", info.Hostname, info.RaftPort) fmt.Printf("raft server [%s] listen on https port %v\n", info.Hostname, info.RaftPort)
fatal(server.ListenAndServeTLS(info.ServerCertFile, argInfo.ServerKeyFile)) fatal(server.ListenAndServeTLS(info.ServerCertFile, argInfo.ServerKeyFile))
@ -419,7 +390,7 @@ func startRaftTransport(info Info, st int) {
} }
// Start to listen and response client command // Start to listen and response client command
func startClientTransport(info Info, st int) { func startEtcdTransport(info Info, tlsConf *tls.Config) {
// external commands // external commands
http.HandleFunc("/"+version+"/keys/", Multiplexer) http.HandleFunc("/"+version+"/keys/", Multiplexer)
http.HandleFunc("/"+version+"/watch/", WatchHttpHandler) http.HandleFunc("/"+version+"/watch/", WatchHttpHandler)
@ -429,24 +400,13 @@ func startClientTransport(info Info, st int) {
http.HandleFunc("/stats", StatsHttpHandler) http.HandleFunc("/stats", StatsHttpHandler)
http.HandleFunc("/test/", TestHttpHandler) http.HandleFunc("/test/", TestHttpHandler)
switch st { if tlsConf == nil {
case HTTP:
fmt.Printf("etcd [%s] listen on http port %v\n", info.Hostname, info.ClientPort) fmt.Printf("etcd [%s] listen on http port %v\n", info.Hostname, info.ClientPort)
fatal(http.ListenAndServe(fmt.Sprintf(":%d", info.ClientPort), nil)) fatal(http.ListenAndServe(fmt.Sprintf(":%d", info.ClientPort), nil))
} else {
case HTTPS:
fmt.Printf("etcd [%s] listen on https port %v\n", info.Hostname, info.ClientPort)
http.ListenAndServeTLS(fmt.Sprintf(":%d", info.ClientPort), info.ClientCertFile, info.ClientKeyFile, nil)
case HTTPSANDVERIFY:
server := &http.Server{ server := &http.Server{
TLSConfig: &tls.Config{ TLSConfig: tlsConf,
ClientAuth: tls.RequireAndVerifyClientCert, Addr: fmt.Sprintf(":%d", info.ClientPort),
ClientCAs: createCertPool(info.ClientCAFile),
},
Addr: fmt.Sprintf(":%d", info.ClientPort),
} }
fmt.Printf("etcd [%s] listen on https port %v\n", info.Hostname, info.ClientPort) fmt.Printf("etcd [%s] listen on https port %v\n", info.Hostname, info.ClientPort)
fatal(server.ListenAndServeTLS(info.ClientCertFile, info.ClientKeyFile)) fatal(server.ListenAndServeTLS(info.ClientCertFile, info.ClientKeyFile))
@ -456,20 +416,28 @@ func startClientTransport(info Info, st int) {
//-------------------------------------- //--------------------------------------
// Config // Config
//-------------------------------------- //--------------------------------------
func tlsConf(source int) ([]*tls.Config, bool) {
// Get the security type
func securityType(source int) int {
var keyFile, certFile, CAFile string var keyFile, certFile, CAFile string
var tlsCert tls.Certificate
var isAuth bool
var err error
switch source { switch source {
case SERVER: case RaftServer:
keyFile = info.ServerKeyFile keyFile = info.ServerKeyFile
certFile = info.ServerCertFile certFile = info.ServerCertFile
CAFile = info.ServerCAFile CAFile = info.ServerCAFile
case CLIENT: if keyFile != "" && certFile != "" {
tlsCert, err = tls.LoadX509KeyPair(certFile, keyFile)
if err == nil {
fatal(err)
}
isAuth = true
}
case EtcdServer:
keyFile = info.ClientKeyFile keyFile = info.ClientKeyFile
certFile = info.ClientCertFile certFile = info.ClientCertFile
CAFile = info.ClientCAFile CAFile = info.ClientCAFile
@ -478,25 +446,28 @@ func securityType(source int) int {
// If the user do not specify key file, cert file and // If the user do not specify key file, cert file and
// CA file, the type will be HTTP // CA file, the type will be HTTP
if keyFile == "" && certFile == "" && CAFile == "" { if keyFile == "" && certFile == "" && CAFile == "" {
return []*tls.Config{nil, nil}, true
return HTTP
} }
if keyFile != "" && certFile != "" { if keyFile != "" && certFile != "" {
if CAFile != "" { serverConf := &tls.Config{}
// If the user specify all the three file, the type serverConf.ClientAuth, serverConf.ClientCAs = newCertPool(CAFile)
// will be HTTPS with client cert auth
return HTTPSANDVERIFY if isAuth {
raftTransConf := &tls.Config{
Certificates: []tls.Certificate{tlsCert},
InsecureSkipVerify: true,
}
return []*tls.Config{serverConf, raftTransConf}, true
} }
// If the user specify key file and cert file but not
// CA file, the type will be HTTPS without client cert return []*tls.Config{serverConf, nil}, true
// auth
return HTTPS
} }
// bad specification // bad specification
return -1 return nil, false
} }
func parseInfo(path string) *Info { func parseInfo(path string) *Info {
@ -569,7 +540,10 @@ func getInfo(path string) *Info {
} }
// Create client auth certpool // Create client auth certpool
func createCertPool(CAFile string) *x509.CertPool { func newCertPool(CAFile string) (tls.ClientAuthType, *x509.CertPool) {
if CAFile == "" {
return tls.NoClientCert, nil
}
pemByte, _ := ioutil.ReadFile(CAFile) pemByte, _ := ioutil.ReadFile(CAFile)
block, pemByte := pem.Decode(pemByte) block, pemByte := pem.Decode(pemByte)
@ -584,7 +558,7 @@ func createCertPool(CAFile string) *x509.CertPool {
certPool.AddCert(cert) certPool.AddCert(cert)
return certPool return tls.RequireAndVerifyClientCert, certPool
} }
// Send join requests to the leader. // Send join requests to the leader.