From 322976bedcd1e987f2b8bd4582a960fd5901e503 Mon Sep 17 00:00:00 2001 From: Anthony Romano Date: Thu, 15 Jun 2017 18:25:00 -0700 Subject: [PATCH] transport: CRL checking --- pkg/transport/listener.go | 5 ++- pkg/transport/listener_tls.go | 81 +++++++++++++++++++++++++++++------ 2 files changed, 71 insertions(+), 15 deletions(-) diff --git a/pkg/transport/listener.go b/pkg/transport/listener.go index 3b58b4154..12120beae 100644 --- a/pkg/transport/listener.go +++ b/pkg/transport/listener.go @@ -52,7 +52,7 @@ func wrapTLS(addr, scheme string, tlsinfo *TLSInfo, l net.Listener) (net.Listene if scheme != "https" && scheme != "unixs" { return l, nil } - return newTLSListener(l, tlsinfo) + return newTLSListener(l, tlsinfo, checkSAN) } type TLSInfo struct { @@ -61,6 +61,7 @@ type TLSInfo struct { CAFile string TrustedCAFile string ClientCertAuth bool + CRLFile string // ServerName ensures the cert matches the given host in case of discovery / virtual hosting ServerName string @@ -77,7 +78,7 @@ type TLSInfo struct { } func (info TLSInfo) String() string { - return fmt.Sprintf("cert = %s, key = %s, ca = %s, trusted-ca = %s, client-cert-auth = %v", info.CertFile, info.KeyFile, info.CAFile, info.TrustedCAFile, info.ClientCertAuth) + return fmt.Sprintf("cert = %s, key = %s, ca = %s, trusted-ca = %s, client-cert-auth = %v, crl-file = %s", info.CertFile, info.KeyFile, info.CAFile, info.TrustedCAFile, info.ClientCertAuth, info.CRLFile) } func (info TLSInfo) Empty() bool { diff --git a/pkg/transport/listener_tls.go b/pkg/transport/listener_tls.go index ecc124548..85630aab6 100644 --- a/pkg/transport/listener_tls.go +++ b/pkg/transport/listener_tls.go @@ -19,21 +19,32 @@ import ( "crypto/tls" "crypto/x509" "fmt" + "io/ioutil" "net" "sync" ) // tlsListener overrides a TLS listener so it will reject client -// certificates with insufficient SAN credentials. +// certificates with insufficient SAN credentials or CRL revoked +// certificates. type tlsListener struct { net.Listener connc chan net.Conn donec chan struct{} err error handshakeFailure func(*tls.Conn, error) + check tlsCheckFunc } -func newTLSListener(l net.Listener, tlsinfo *TLSInfo) (net.Listener, error) { +type tlsCheckFunc func(context.Context, *tls.Conn) error + +// NewTLSListener handshakes TLS connections and performs optional CRL checking. +func NewTLSListener(l net.Listener, tlsinfo *TLSInfo) (net.Listener, error) { + check := func(context.Context, *tls.Conn) error { return nil } + return newTLSListener(l, tlsinfo, check) +} + +func newTLSListener(l net.Listener, tlsinfo *TLSInfo, check tlsCheckFunc) (net.Listener, error) { if tlsinfo == nil || tlsinfo.Empty() { l.Close() return nil, fmt.Errorf("cannot listen on TLS for %s: KeyFile and CertFile are not presented", l.Addr().String()) @@ -47,11 +58,27 @@ func newTLSListener(l net.Listener, tlsinfo *TLSInfo) (net.Listener, error) { if hf == nil { hf = func(*tls.Conn, error) {} } + + if len(tlsinfo.CRLFile) > 0 { + prevCheck := check + check = func(ctx context.Context, tlsConn *tls.Conn) error { + if err := prevCheck(ctx, tlsConn); err != nil { + return err + } + st := tlsConn.ConnectionState() + if certs := st.PeerCertificates; len(certs) > 0 { + return checkCRL(tlsinfo.CRLFile, certs) + } + return nil + } + } + tlsl := &tlsListener{ Listener: tls.NewListener(l, tlscfg), connc: make(chan net.Conn), donec: make(chan struct{}), handshakeFailure: hf, + check: check, } go tlsl.acceptLoop() return tlsl, nil @@ -66,6 +93,15 @@ func (l *tlsListener) Accept() (net.Conn, error) { } } +func checkSAN(ctx context.Context, tlsConn *tls.Conn) error { + st := tlsConn.ConnectionState() + if certs := st.PeerCertificates; len(certs) > 0 { + addr := tlsConn.RemoteAddr().String() + return checkCertSAN(ctx, certs[0], addr) + } + return nil +} + // acceptLoop launches each TLS handshake in a separate goroutine // to prevent a hanging TLS connection from blocking other connections. func (l *tlsListener) acceptLoop() { @@ -110,20 +146,16 @@ func (l *tlsListener) acceptLoop() { pendingMu.Lock() delete(pending, conn) pendingMu.Unlock() + if herr != nil { l.handshakeFailure(tlsConn, herr) return } - - st := tlsConn.ConnectionState() - if len(st.PeerCertificates) > 0 { - cert := st.PeerCertificates[0] - addr := tlsConn.RemoteAddr().String() - if cerr := checkCert(ctx, cert, addr); cerr != nil { - l.handshakeFailure(tlsConn, cerr) - return - } + if err := l.check(ctx, tlsConn); err != nil { + l.handshakeFailure(tlsConn, err) + return } + select { case l.connc <- tlsConn: conn = nil @@ -133,11 +165,34 @@ func (l *tlsListener) acceptLoop() { } } -func checkCert(ctx context.Context, cert *x509.Certificate, remoteAddr string) error { - h, _, herr := net.SplitHostPort(remoteAddr) +func checkCRL(crlPath string, cert []*x509.Certificate) error { + // TODO: cache + crlBytes, err := ioutil.ReadFile(crlPath) + if err != nil { + return err + } + certList, err := x509.ParseCRL(crlBytes) + if err != nil { + return err + } + revokedSerials := make(map[string]struct{}) + for _, rc := range certList.TBSCertList.RevokedCertificates { + revokedSerials[string(rc.SerialNumber.Bytes())] = struct{}{} + } + for _, c := range cert { + serial := string(c.SerialNumber.Bytes()) + if _, ok := revokedSerials[serial]; ok { + return fmt.Errorf("transport: certificate serial %x revoked", serial) + } + } + return nil +} + +func checkCertSAN(ctx context.Context, cert *x509.Certificate, remoteAddr string) error { if len(cert.IPAddresses) == 0 && len(cert.DNSNames) == 0 { return nil } + h, _, herr := net.SplitHostPort(remoteAddr) if herr != nil { return herr }