From 17459c7bfcda3bf5e7f2f9929cfe4d08ff9ffa4e Mon Sep 17 00:00:00 2001 From: Brian Waldon Date: Mon, 22 Sep 2014 16:48:12 -0700 Subject: [PATCH] transport: wrap net.Listener with TLSInfo --- main.go | 6 +-- test | 2 +- transport/listener.go | 87 ++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 89 insertions(+), 6 deletions(-) diff --git a/main.go b/main.go index 5e2f6070e..2ccffdc6b 100644 --- a/main.go +++ b/main.go @@ -168,7 +168,7 @@ func startEtcd() { Info: cors, } - l, err := transport.NewListener(*paddr) + l, err := transport.NewListener(*paddr, transport.TLSInfo{}) if err != nil { log.Fatal(err) } @@ -182,7 +182,7 @@ func startEtcd() { // Start a client server goroutine for each listen address for _, addr := range *addrs { addr := addr - l, err := transport.NewListener(addr) + l, err := transport.NewListener(addr, transport.TLSInfo{}) if err != nil { log.Fatal(err) } @@ -212,7 +212,7 @@ func startProxy() { // Start a proxy server goroutine for each listen address for _, addr := range *addrs { addr := addr - l, err := transport.NewListener(addr) + l, err := transport.NewListener(addr, transport.TLSInfo{}) if err != nil { log.Fatal(err) } diff --git a/test b/test index 0c468446e..ee93c6907 100755 --- a/test +++ b/test @@ -15,7 +15,7 @@ COVER=${COVER:-"-cover"} source ./build # Hack: gofmt ./ will recursively check the .git directory. So use *.go for gofmt. -TESTABLE_AND_FORMATTABLE="client etcdserver etcdserver/etcdhttp etcdserver/etcdserverpb functional proxy raft snap store wait wal" +TESTABLE_AND_FORMATTABLE="client etcdserver etcdserver/etcdhttp etcdserver/etcdserverpb functional proxy raft snap store wait wal transport" TESTABLE="$TESTABLE_AND_FORMATTABLE ./" FORMATTABLE="$TESTABLE_AND_FORMATTABLE *.go" diff --git a/transport/listener.go b/transport/listener.go index d73e22d5c..122f92d58 100644 --- a/transport/listener.go +++ b/transport/listener.go @@ -1,9 +1,92 @@ package transport import ( + "crypto/tls" + "crypto/x509" + "encoding/pem" + "fmt" + "io/ioutil" "net" ) -func NewListener(addr string) (net.Listener, error) { - return net.Listen("tcp", addr) +func NewListener(addr string, info TLSInfo) (net.Listener, error) { + l, err := net.Listen("tcp", addr) + if err != nil { + return nil, err + } + + if !info.Empty() { + cfg, err := info.ServerConfig() + if err != nil { + return nil, err + } + + l = tls.NewListener(l, cfg) + } + + return l, nil +} + +type TLSInfo struct { + CertFile string + KeyFile string + CAFile string +} + +func (info TLSInfo) Empty() bool { + return info.CertFile == "" && info.KeyFile == "" +} + +// Generates a tls.Config object for a server from the given files. +func (info TLSInfo) ServerConfig() (*tls.Config, error) { + // Both the key and cert must be present. + if info.KeyFile == "" || info.CertFile == "" { + return nil, fmt.Errorf("KeyFile and CertFile must both be present[key: %v, cert: %v]", info.KeyFile, info.CertFile) + } + + var cfg tls.Config + + tlsCert, err := tls.LoadX509KeyPair(info.CertFile, info.KeyFile) + if err != nil { + return nil, err + } + + cfg.Certificates = []tls.Certificate{tlsCert} + + if info.CAFile != "" { + cfg.ClientAuth = tls.RequireAndVerifyClientCert + cp, err := newCertPool(info.CAFile) + if err != nil { + return nil, err + } + + cfg.RootCAs = cp + cfg.ClientCAs = cp + } else { + cfg.ClientAuth = tls.NoClientCert + } + + return &cfg, nil +} + +// newCertPool creates x509 certPool with provided CA file +func newCertPool(CAFile string) (*x509.CertPool, error) { + certPool := x509.NewCertPool() + pemByte, err := ioutil.ReadFile(CAFile) + if err != nil { + return nil, err + } + + for { + var block *pem.Block + block, pemByte = pem.Decode(pemByte) + if block == nil { + return certPool, nil + } + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, err + } + certPool.AddCert(cert) + } }