From 8c09f98882e0e4db551eba73a783af24602d204b Mon Sep 17 00:00:00 2001 From: Brandon Philips Date: Sat, 10 Aug 2013 19:26:21 -0700 Subject: [PATCH] chore(etcd): cleanup TLS configuration the TLS configuration was getting rather complex with slices of tls.Config's being passed around and pointer nil checking for schema types. Introduce a new TLSInfo type that is in charge of holding the various TLS key/cert/CA filenames the user passes in. Then create a new TlsConfig type that has a Scheme and the Client and Server tls.Config objects inside of it. This is used by the two transport start methods which had been using a slice of tls.Config objects and guessing at the scheme based on the non-nil value of the Config. --- etcd.go | 176 ++++++++++++++++++++++---------------------------------- 1 file changed, 70 insertions(+), 106 deletions(-) diff --git a/etcd.go b/etcd.go index 1cf9ca41d..b583c09fd 100644 --- a/etcd.go +++ b/etcd.go @@ -63,13 +63,13 @@ func init() { flag.StringVar(&argInfo.RaftURL, "s", "127.0.0.1:7001", "the hostname:port for raft server communication") flag.StringVar(&argInfo.WebURL, "w", "", "the hostname:port of web interface") - flag.StringVar(&argInfo.ServerCAFile, "serverCAFile", "", "the path of the CAFile") - flag.StringVar(&argInfo.ServerCertFile, "serverCert", "", "the cert file of the server") - flag.StringVar(&argInfo.ServerKeyFile, "serverKey", "", "the key file of the server") + flag.StringVar(&argInfo.RaftTLS.CAFile, "serverCAFile", "", "the path of the CAFile") + flag.StringVar(&argInfo.RaftTLS.CertFile, "serverCert", "", "the cert file of the server") + flag.StringVar(&argInfo.RaftTLS.KeyFile, "serverKey", "", "the key file of the server") - flag.StringVar(&argInfo.ClientCAFile, "clientCAFile", "", "the path of the client CAFile") - flag.StringVar(&argInfo.ClientCertFile, "clientCert", "", "the cert file of the client") - flag.StringVar(&argInfo.ClientKeyFile, "clientKey", "", "the key file of the client") + flag.StringVar(&argInfo.EtcdTLS.CAFile, "clientCAFile", "", "the path of the client CAFile") + flag.StringVar(&argInfo.EtcdTLS.CertFile, "clientCert", "", "the cert file of the client") + flag.StringVar(&argInfo.EtcdTLS.KeyFile, "clientKey", "", "the key file of the client") flag.StringVar(&dirPath, "d", ".", "the directory to store log and snapshot") @@ -86,12 +86,6 @@ func init() { flag.StringVar(&cpuprofile, "cpuprofile", "", "write cpu profile to file") } -// CONSTANTS -const ( - RaftServer = iota - EtcdServer -) - const ( ELECTIONTIMEOUT = 200 * time.Millisecond HEARTBEATTIMEOUT = 50 * time.Millisecond @@ -109,6 +103,12 @@ const ( // //------------------------------------------------------------------------------ +type TLSInfo struct { + CertFile string `json:"serverCertFile"` + KeyFile string `json:"serverKeyFile"` + CAFile string `json:"serverCAFile"` +} + type Info struct { Name string `json:"name"` @@ -116,13 +116,8 @@ type Info struct { EtcdURL string `json:"etcdURL"` WebURL string `json:"webURL"` - ServerCertFile string `json:"serverCertFile"` - ServerKeyFile string `json:"serverKeyFile"` - ServerCAFile string `json:"serverCAFile"` - - ClientCertFile string `json:"clientCertFile"` - ClientKeyFile string `json:"clientKeyFile"` - ClientCAFile string `json:"clientCAFile"` + RaftTLS TLSInfo `json:"raftTLS"` + EtcdTLS TLSInfo `json:"raftTLS"` } //------------------------------------------------------------------------------ @@ -208,35 +203,23 @@ func main() { cluster = strings.Split(string(b), ",") } - raftTlsConfs, ok := tlsConf(RaftServer) + raftTLSConfig, ok := tlsConfigFromInfo(argInfo.RaftTLS) if !ok { fatal("Please specify cert and key file or cert and key file and CAFile or none of the three") } - raftDefaultScheme := "http" - if raftTlsConfs[0] != nil { - raftDefaultScheme = "https" - } - - etcdTlsConfs, ok := tlsConf(EtcdServer) + etcdTLSConfig, ok := tlsConfigFromInfo(argInfo.EtcdTLS) if !ok { fatal("Please specify cert and key file or cert and key file and CAFile or none of the three") } - etcdDefaultScheme := "http" - if etcdTlsConfs[0] != nil { - raftDefaultScheme = "https" - } - - // Otherwise ask user for info and write it to file. argInfo.Name = strings.TrimSpace(argInfo.Name) - if argInfo.Name == "" { fatal("ERROR: server name required. e.g. '-n=server_name'") } - argInfo.RaftURL = sanitizeURL(argInfo.RaftURL, raftTlsConfig.Scheme) - argInfo.EtcdURL = sanitizeURL(argInfo.EtcdURL, etcdTlsConfig.Scheme) + argInfo.RaftURL = sanitizeURL(argInfo.RaftURL, raftTLSConfig.Scheme) + argInfo.EtcdURL = sanitizeURL(argInfo.EtcdURL, etcdTLSConfig.Scheme) argInfo.WebURL = sanitizeURL(argInfo.WebURL, "http") // Setup commands. @@ -252,27 +235,27 @@ func main() { // Create etcd key-value store etcdStore = store.CreateStore(maxSize) - startRaft(raftTlsConfs) + startRaft(raftTLSConfig) if argInfo.WebURL != "" { // start web - argInfo.WebURL = checkURL(argInfo.WebURL, "http") + argInfo.WebURL = sanitizeURL(argInfo.WebURL, "http") go webHelper() go web.Start(raftServer, argInfo.WebURL) } - startEtcdTransport(*info, etcdTlsConfs[0]) + startEtcdTransport(*info, etcdTLSConfig.Scheme, etcdTLSConfig.Server) } // Start the raft server -func startRaft(tlsConfs []*tls.Config) { +func startRaft(tlsConfig TLSConfig) { var err error raftName := info.Name // Create transporter for raft - raftTransporter = newTransporter(tlsConfs[1]) + raftTransporter = newTransporter(tlsConfig.Scheme, tlsConfig.Client) // Create raft server raftServer, err = raft.NewServer(raftName, dirPath, raftTransporter, etcdStore, nil) @@ -367,37 +350,29 @@ func startRaft(tlsConfs []*tls.Config) { } // start to response to raft requests - go startRaftTransport(*info, tlsConfs[0]) + go startRaftTransport(*info, tlsConfig.Scheme, tlsConfig.Server) } // Create transporter using by raft server // Create http or https transporter based on // whether the user give the server cert and key -func newTransporter(tlsConf *tls.Config) transporter { +func newTransporter(scheme string, tlsConf tls.Config) transporter { t := transporter{} - if tlsConf == nil { - t.scheme = "http://" + t.scheme = scheme - t.client = &http.Client{ - Transport: &http.Transport{ - Dial: dialTimeout, - }, - } - - } else { - t.scheme = "https://" - - tr := &http.Transport{ - TLSClientConfig: tlsConf, - Dial: dialTimeout, - DisableCompression: true, - } - - t.client = &http.Client{Transport: tr} + tr := &http.Transport{ + Dial: dialTimeout, } + if scheme == "https" { + tr.TLSClientConfig = &tlsConf + tr.DisableCompression = true + } + + t.client = &http.Client{Transport: tr} + return t } @@ -407,7 +382,7 @@ func dialTimeout(network, addr string) (net.Conn, error) { } // Start to listen and response raft command -func startRaftTransport(info Info, tlsConf *tls.Config) { +func startRaftTransport(info Info, scheme string, tlsConf tls.Config) { u, _ := url.Parse(info.RaftURL) fmt.Printf("raft server [%s] listening on %s\n", info.Name, u) @@ -415,7 +390,7 @@ func startRaftTransport(info Info, tlsConf *tls.Config) { server := &http.Server{ Handler: raftMux, - TLSConfig: tlsConf, + TLSConfig: &tlsConf, Addr: u.Host, } @@ -429,16 +404,16 @@ func startRaftTransport(info Info, tlsConf *tls.Config) { raftMux.HandleFunc("/snapshotRecovery", SnapshotRecoveryHttpHandler) raftMux.HandleFunc("/etcdURL", EtcdURLHttpHandler) - if tlsConf == nil { + if scheme == "http" { fatal(server.ListenAndServe()) } else { - fatal(server.ListenAndServeTLS(info.ServerCertFile, argInfo.ServerKeyFile)) + fatal(server.ListenAndServeTLS(info.RaftTLS.CertFile, info.RaftTLS.KeyFile)) } } // Start to listen and response client command -func startEtcdTransport(info Info, tlsConf *tls.Config) { +func startEtcdTransport(info Info, scheme string, tlsConf tls.Config) { u, _ := url.Parse(info.EtcdURL) fmt.Printf("etcd server [%s] listening on %s\n", info.Name, u) @@ -446,7 +421,7 @@ func startEtcdTransport(info Info, tlsConf *tls.Config) { server := &http.Server{ Handler: etcdMux, - TLSConfig: tlsConf, + TLSConfig: &tlsConf, Addr: u.Host, } @@ -459,68 +434,57 @@ func startEtcdTransport(info Info, tlsConf *tls.Config) { etcdMux.HandleFunc("/stats", StatsHttpHandler) etcdMux.HandleFunc("/test/", TestHttpHandler) - if tlsConf == nil { + if scheme == "http" { fatal(server.ListenAndServe()) } else { - fatal(server.ListenAndServeTLS(info.ClientCertFile, info.ClientKeyFile)) + fatal(server.ListenAndServeTLS(info.EtcdTLS.CertFile, info.EtcdTLS.KeyFile)) } } //-------------------------------------- // Config //-------------------------------------- -func tlsConf(source int) ([]*tls.Config, bool) { + +type TLSConfig struct { + Scheme string + Server tls.Config + Client tls.Config +} + +func tlsConfigFromInfo(info TLSInfo) (t TLSConfig, ok bool) { var keyFile, certFile, CAFile string var tlsCert tls.Certificate - var isAuth bool var err error - switch source { + t.Scheme = "http" - case RaftServer: - keyFile = info.ServerKeyFile - certFile = info.ServerCertFile - CAFile = info.ServerCAFile - - if keyFile != "" && certFile != "" { - tlsCert, err = tls.LoadX509KeyPair(certFile, keyFile) - if err == nil { - fatal(err) - } - isAuth = true - } - - case EtcdServer: - keyFile = info.ClientKeyFile - certFile = info.ClientCertFile - CAFile = info.ClientCAFile - } + keyFile = info.KeyFile + certFile = info.CertFile + CAFile = info.CAFile // If the user do not specify key file, cert file and // CA file, the type will be HTTP if keyFile == "" && certFile == "" && CAFile == "" { - return []*tls.Config{nil, nil}, true + return t, true } - if keyFile != "" && certFile != "" { - serverConf := &tls.Config{} - serverConf.ClientAuth, serverConf.ClientCAs = newCertPool(CAFile) - - if isAuth { - raftTransConf := &tls.Config{ - Certificates: []tls.Certificate{tlsCert}, - InsecureSkipVerify: true, - } - return []*tls.Config{serverConf, raftTransConf}, true - } - - return []*tls.Config{serverConf, nil}, true - + // both the key and cert must be present + if keyFile == "" || certFile == "" { + return t, false } - // bad specification - return nil, false + tlsCert, err = tls.LoadX509KeyPair(certFile, keyFile) + if err == nil { + fatal(err) + } + t.Scheme = "https" + t.Server.Certificates = []tls.Certificate{tlsCert} + t.Server.InsecureSkipVerify = true + + t.Client.ClientAuth, t.Client.ClientCAs = newCertPool(CAFile) + + return t, true } func parseInfo(path string) *Info {