diff --git a/config/config.go b/config/config.go index 633877394..b30774102 100644 --- a/config/config.go +++ b/config/config.go @@ -393,24 +393,16 @@ func (c *Config) Reset() error { // Sanitize cleans the input fields. func (c *Config) Sanitize() error { - tlsConfig, err := c.TLSConfig() - if err != nil { - return err - } - - peerTlsConfig, err := c.PeerTLSConfig() - if err != nil { - return err - } + var err error // Sanitize the URLs first. - if c.Addr, err = sanitizeURL(c.Addr, tlsConfig.Scheme); err != nil { + if c.Addr, err = sanitizeURL(c.Addr, c.EtcdTLSInfo().Scheme()); err != nil { return fmt.Errorf("Advertised URL: %s", err) } if c.BindAddr, err = sanitizeBindAddr(c.BindAddr, c.Addr); err != nil { return fmt.Errorf("Listen Host: %s", err) } - if c.Peer.Addr, err = sanitizeURL(c.Peer.Addr, peerTlsConfig.Scheme); err != nil { + if c.Peer.Addr, err = sanitizeURL(c.Peer.Addr, c.PeerTLSInfo().Scheme()); err != nil { return fmt.Errorf("Peer Advertised URL: %s", err) } if c.Peer.BindAddr, err = sanitizeBindAddr(c.Peer.BindAddr, c.Peer.Addr); err != nil { @@ -430,34 +422,24 @@ func (c *Config) Sanitize() error { return nil } -// TLSInfo retrieves a TLSInfo object for the client server. -func (c *Config) TLSInfo() server.TLSInfo { +// EtcdTLSInfo retrieves a TLSInfo object for the etcd server +func (c *Config) EtcdTLSInfo() server.TLSInfo { return server.TLSInfo{ - CAFile: c.CAFile, - CertFile: c.CertFile, - KeyFile: c.KeyFile, + CAFile: c.CAFile, + CertFile: c.CertFile, + KeyFile: c.KeyFile, } } -// ClientTLSConfig generates the TLS configuration for the client server. -func (c *Config) TLSConfig() (server.TLSConfig, error) { - return c.TLSInfo().Config() -} - -// PeerTLSInfo retrieves a TLSInfo object for the peer server. +// PeerRaftInfo retrieves a TLSInfo object for the peer server. func (c *Config) PeerTLSInfo() server.TLSInfo { return server.TLSInfo{ - CAFile: c.Peer.CAFile, - CertFile: c.Peer.CertFile, - KeyFile: c.Peer.KeyFile, + CAFile: c.Peer.CAFile, + CertFile: c.Peer.CertFile, + KeyFile: c.Peer.KeyFile, } } -// PeerTLSConfig generates the TLS configuration for the peer server. -func (c *Config) PeerTLSConfig() (server.TLSConfig, error) { - return c.PeerTLSInfo().Config() -} - // MetricsBucketName generates the name that should be used for a // corresponding MetricsBucket object func (c *Config) MetricsBucketName() string { diff --git a/etcd.go b/etcd.go index c63f71fb6..7055fcabf 100644 --- a/etcd.go +++ b/etcd.go @@ -79,16 +79,6 @@ func main() { log.Warnf("All cached configuration is now ignored. The file %s can be removed.", info) } - // Retrieve TLS configuration. - tlsConfig, err := config.TLSInfo().Config() - if err != nil { - log.Fatal("Client TLS:", err) - } - peerTLSConfig, err := config.PeerTLSInfo().Config() - if err != nil { - log.Fatal("Peer TLS:", err) - } - var mbName string if config.Trace() { mbName = config.MetricsBucketName() @@ -124,10 +114,10 @@ func main() { dialTimeout := (3 * heartbeatTimeout) + electionTimeout responseHeaderTimeout := (3 * heartbeatTimeout) + electionTimeout - // Create peer server. + // Create peer server psConfig := server.PeerServerConfig{ Name: config.Name, - Scheme: peerTLSConfig.Scheme, + Scheme: config.PeerTLSInfo().Scheme(), URL: config.Peer.Addr, SnapshotCount: config.SnapshotCount, MaxClusterSize: config.MaxClusterSize, @@ -137,18 +127,30 @@ func main() { var psListener net.Listener if psConfig.Scheme == "https" { - psListener, err = server.NewTLSListener(&tlsConfig.Server, config.Peer.BindAddr, config.PeerTLSInfo().CertFile, config.PeerTLSInfo().KeyFile) + peerServerTLSConfig, err := config.PeerTLSInfo().ServerConfig() + if err != nil { + log.Fatal("peer server TLS error: ", err) + } + + psListener, err = server.NewTLSListener(config.Peer.BindAddr, peerServerTLSConfig) + if err != nil { + log.Fatal("Failed to create peer listener: ", err) + } } else { psListener, err = server.NewListener(config.Peer.BindAddr) - } - if err != nil { - panic(err) + if err != nil { + log.Fatal("Failed to create peer listener: ", err) + } } - // Create Raft transporter and server + // Create raft transporter and server raftTransporter := server.NewTransporter(followersStats, serverStats, registry, heartbeatTimeout, dialTimeout, responseHeaderTimeout) if psConfig.Scheme == "https" { - raftTransporter.SetTLSConfig(peerTLSConfig.Client) + raftClientTLSConfig, err := config.PeerTLSInfo().ClientConfig() + if err != nil { + log.Fatal("raft client TLS error: ", err) + } + raftTransporter.SetTLSConfig(*raftClientTLSConfig) } raftServer, err := raft.NewServer(config.Name, config.DataDir, raftTransporter, store, ps, "") if err != nil { @@ -158,7 +160,7 @@ func main() { raftServer.SetHeartbeatTimeout(heartbeatTimeout) ps.SetRaftServer(raftServer) - // Create client server. + // Create etcd server s := server.New(config.Name, config.Addr, ps, registry, store, &mb) if config.Trace() { @@ -166,22 +168,28 @@ func main() { } var sListener net.Listener - if tlsConfig.Scheme == "https" { - sListener, err = server.NewTLSListener(&tlsConfig.Server, config.BindAddr, config.TLSInfo().CertFile, config.TLSInfo().KeyFile) + if config.EtcdTLSInfo().Scheme() == "https" { + etcdServerTLSConfig, err := config.EtcdTLSInfo().ServerConfig() + if err != nil { + log.Fatal("etcd TLS error: ", err) + } + + sListener, err = server.NewTLSListener(config.BindAddr, etcdServerTLSConfig) + if err != nil { + log.Fatal("Failed to create TLS etcd listener: ", err) + } } else { sListener, err = server.NewListener(config.BindAddr) - } - if err != nil { - panic(err) + if err != nil { + log.Fatal("Failed to create etcd listener: ", err) + } } ps.SetServer(s) - ps.Start(config.Snapshot, config.Peers) - // Run peer server in separate thread while the client server blocks. go func() { - log.Infof("raft server [name %s, listen on %s, advertised url %s]", ps.Config.Name, psListener.Addr(), ps.Config.URL) + log.Infof("peer server [name %s, listen on %s, advertised url %s]", ps.Config.Name, psListener.Addr(), ps.Config.URL) sHTTP := &ehttp.CORSHandler{ps.HTTPHandler(), corsInfo} log.Fatal(http.Serve(psListener, sHTTP)) }() diff --git a/server/listener.go b/server/listener.go index f007f0cb3..93527d66c 100644 --- a/server/listener.go +++ b/server/listener.go @@ -16,28 +16,15 @@ func NewListener(addr string) (net.Listener, error) { return l, nil } -func NewTLSListener(config *tls.Config, addr, certFile, keyFile string) (net.Listener, error) { +func NewTLSListener(addr string, cfg *tls.Config) (net.Listener, error) { if addr == "" { addr = ":https" } - if config == nil { - config = &tls.Config{} - } - - config.NextProtos = []string{"http/1.1"} - - var err error - config.Certificates = make([]tls.Certificate, 1) - config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) - if err != nil { - return nil, err - } - conn, err := net.Listen("tcp", addr) if err != nil { return nil, err } - return tls.NewListener(conn, config), nil + return tls.NewListener(conn, cfg), nil } diff --git a/server/tls_config.go b/server/tls_config.go deleted file mode 100644 index 733f9c082..000000000 --- a/server/tls_config.go +++ /dev/null @@ -1,12 +0,0 @@ -package server - -import ( - "crypto/tls" -) - -// TLSConfig holds the TLS configuration. -type TLSConfig struct { - Scheme string // http or https - Server tls.Config // Used by the Raft or etcd Server transporter. - Client tls.Config // Used by the Raft peer client. -} diff --git a/server/tls_info.go b/server/tls_info.go index 6b16db013..bc2d1099b 100644 --- a/server/tls_info.go +++ b/server/tls_info.go @@ -15,62 +15,88 @@ type TLSInfo struct { CAFile string `json:"CAFile"` } -// Generates a TLS configuration from the given files. -func (info TLSInfo) Config() (TLSConfig, error) { - var t TLSConfig - t.Scheme = "http" - - // If the user do not specify key file, cert file and CA file, the type will be HTTP - if info.KeyFile == "" && info.CertFile == "" && info.CAFile == "" { - return t, nil +func (info TLSInfo) Scheme() string { + if info.KeyFile != "" && info.CertFile != "" { + return "https" + } else { + return "http" } +} +// Generates a tls.Config object for a server from the given files. +func (info TLSInfo) ServerConfig() (*tls.Config, error) { // Both the key and cert must be present. if info.KeyFile == "" || info.CertFile == "" { - return t, fmt.Errorf("KeyFile and CertFile must both be present[key: %v, cert: %v]", info.KeyFile, info.CertFile) + return nil, fmt.Errorf("KeyFile and CertFile must both be present[key: %v, cert: %v]", info.KeyFile, info.CertFile) + } + + var cfg tls.Config + + tlsCert, err := tls.LoadX509KeyPair(info.CertFile, info.KeyFile) + if err != nil { + return nil, err + } + + cfg.Certificates = []tls.Certificate{tlsCert} + + if info.CAFile != "" { + cfg.ClientAuth = tls.RequireAndVerifyClientCert + cp, err := newCertPool(info.CAFile) + if err != nil { + return nil, err + } + + cfg.RootCAs = cp + cfg.ClientCAs = cp + } else { + cfg.ClientAuth = tls.NoClientCert + } + + return &cfg, nil +} + +// Generates a tls.Config object for a client from the given files. +func (info TLSInfo) ClientConfig() (*tls.Config, error) { + var cfg tls.Config + + if info.KeyFile == "" || info.CertFile == "" { + return &cfg, nil } tlsCert, err := tls.LoadX509KeyPair(info.CertFile, info.KeyFile) if err != nil { - return t, err + return nil, err } - t.Scheme = "https" - t.Server.ClientAuth, t.Server.ClientCAs, err = newCertPool(info.CAFile) - if err != nil { - return t, err + cfg.Certificates = []tls.Certificate{tlsCert} + + if info.CAFile != "" { + cp, err := newCertPool(info.CAFile) + if err != nil { + return nil, err + } + + cfg.RootCAs = cp } - // The client should trust the RootCA that the Server uses since - // everyone is a peer in the network. - t.Client.Certificates = []tls.Certificate{tlsCert} - t.Client.RootCAs = t.Server.ClientCAs - - return t, nil + return &cfg, nil } -// newCertPool creates x509 certPool and corresponding Auth Type. -// If the given CAfile is valid, add the cert into the pool and verify the clients' -// certs against the cert in the pool. -// If the given CAfile is empty, do not verify the clients' cert. -// If the given CAfile is not valid, fatal. -func newCertPool(CAFile string) (tls.ClientAuthType, *x509.CertPool, error) { - if CAFile == "" { - return tls.NoClientCert, nil, nil - } +// newCertPool creates x509 certPool with provided CA file +func newCertPool(CAFile string) (*x509.CertPool, error) { pemByte, err := ioutil.ReadFile(CAFile) if err != nil { - return 0, nil, err + return nil, err } block, pemByte := pem.Decode(pemByte) cert, err := x509.ParseCertificate(block.Bytes) if err != nil { - return 0, nil, err + return nil, err } certPool := x509.NewCertPool() certPool.AddCert(cert) - return tls.RequireAndVerifyClientCert, certPool, nil + return certPool, nil } diff --git a/tests/functional/etcd_tls_test.go b/tests/functional/etcd_tls_test.go index 493dd7bd7..2f24cc74a 100644 --- a/tests/functional/etcd_tls_test.go +++ b/tests/functional/etcd_tls_test.go @@ -162,6 +162,8 @@ func startServer(extra []string) (*os.Process, error) { cmd := []string{"etcd", "-f", "-data-dir=/tmp/node1", "-name=node1"} cmd = append(cmd, extra...) + println(strings.Join(cmd, " ")) + return os.StartProcess(EtcdBinPath, cmd, procAttr) }