transport: resolve DNSNames when SAN checking

The current transport client TLS checking will pass an IP address into
VerifyHostnames if there is DNSNames SAN. However, the go runtime will
not resolve the DNS names to match the client IP. Intead, resolve the
names when checking.
This commit is contained in:
Anthony Romano 2017-04-18 12:56:15 -07:00
parent 8fdf8f752b
commit 05582ad5b2

View File

@ -15,7 +15,9 @@
package transport
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"sync"
@ -40,11 +42,16 @@ func newTLSListener(l net.Listener, tlsinfo *TLSInfo) (net.Listener, error) {
if err != nil {
return nil, err
}
hf := tlsinfo.HandshakeFailure
if hf == nil {
hf = func(*tls.Conn, error) {}
}
tlsl := &tlsListener{
Listener: tls.NewListener(l, tlscfg),
connc: make(chan net.Conn),
donec: make(chan struct{}),
handshakeFailure: tlsinfo.HandshakeFailure,
handshakeFailure: hf,
}
go tlsl.acceptLoop()
return tlsl, nil
@ -66,9 +73,9 @@ func (l *tlsListener) acceptLoop() {
var pendingMu sync.Mutex
pending := make(map[net.Conn]struct{})
stopc := make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
defer func() {
close(stopc)
cancel()
pendingMu.Lock()
for c := range pending {
c.Close()
@ -104,32 +111,58 @@ func (l *tlsListener) acceptLoop() {
delete(pending, conn)
pendingMu.Unlock()
if herr != nil {
if l.handshakeFailure != nil {
l.handshakeFailure(tlsConn, herr)
}
l.handshakeFailure(tlsConn, herr)
return
}
st := tlsConn.ConnectionState()
if len(st.PeerCertificates) > 0 {
cert := st.PeerCertificates[0]
if len(cert.IPAddresses) > 0 || len(cert.DNSNames) > 0 {
addr := tlsConn.RemoteAddr().String()
h, _, herr := net.SplitHostPort(addr)
if herr != nil || cert.VerifyHostname(h) != nil {
return
}
addr := tlsConn.RemoteAddr().String()
if cerr := checkCert(ctx, cert, addr); cerr != nil {
l.handshakeFailure(tlsConn, cerr)
return
}
}
select {
case l.connc <- tlsConn:
conn = nil
case <-stopc:
case <-ctx.Done():
}
}()
}
}
func checkCert(ctx context.Context, cert *x509.Certificate, remoteAddr string) error {
h, _, herr := net.SplitHostPort(remoteAddr)
if len(cert.IPAddresses) == 0 && len(cert.DNSNames) == 0 {
return nil
}
if herr != nil {
return herr
}
if len(cert.IPAddresses) > 0 {
if cerr := cert.VerifyHostname(h); cerr != nil && len(cert.DNSNames) == 0 {
return cerr
}
}
if len(cert.DNSNames) > 0 {
for _, dns := range cert.DNSNames {
addrs, lerr := net.DefaultResolver.LookupHost(ctx, dns)
if lerr != nil {
continue
}
for _, addr := range addrs {
if addr == h {
return nil
}
}
}
return fmt.Errorf("tls: %q does not match any of DNSNames %q", h, cert.DNSNames)
}
return nil
}
func (l *tlsListener) Close() error {
err := l.Listener.Close()
<-l.donec