mirror of
https://github.com/etcd-io/etcd.git
synced 2024-09-27 06:25:44 +00:00
Merge pull request #7767 from heyitsanthony/transport-resolve-dnsnames
transport: resolve DNSNames when SAN checking
This commit is contained in:
commit
8fa4b8da6e
@ -15,7 +15,9 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
@ -40,11 +42,16 @@ func newTLSListener(l net.Listener, tlsinfo *TLSInfo) (net.Listener, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hf := tlsinfo.HandshakeFailure
|
||||
if hf == nil {
|
||||
hf = func(*tls.Conn, error) {}
|
||||
}
|
||||
tlsl := &tlsListener{
|
||||
Listener: tls.NewListener(l, tlscfg),
|
||||
connc: make(chan net.Conn),
|
||||
donec: make(chan struct{}),
|
||||
handshakeFailure: tlsinfo.HandshakeFailure,
|
||||
handshakeFailure: hf,
|
||||
}
|
||||
go tlsl.acceptLoop()
|
||||
return tlsl, nil
|
||||
@ -66,9 +73,9 @@ func (l *tlsListener) acceptLoop() {
|
||||
var pendingMu sync.Mutex
|
||||
|
||||
pending := make(map[net.Conn]struct{})
|
||||
stopc := make(chan struct{})
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer func() {
|
||||
close(stopc)
|
||||
cancel()
|
||||
pendingMu.Lock()
|
||||
for c := range pending {
|
||||
c.Close()
|
||||
@ -104,32 +111,58 @@ func (l *tlsListener) acceptLoop() {
|
||||
delete(pending, conn)
|
||||
pendingMu.Unlock()
|
||||
if herr != nil {
|
||||
if l.handshakeFailure != nil {
|
||||
l.handshakeFailure(tlsConn, herr)
|
||||
}
|
||||
l.handshakeFailure(tlsConn, herr)
|
||||
return
|
||||
}
|
||||
|
||||
st := tlsConn.ConnectionState()
|
||||
if len(st.PeerCertificates) > 0 {
|
||||
cert := st.PeerCertificates[0]
|
||||
if len(cert.IPAddresses) > 0 || len(cert.DNSNames) > 0 {
|
||||
addr := tlsConn.RemoteAddr().String()
|
||||
h, _, herr := net.SplitHostPort(addr)
|
||||
if herr != nil || cert.VerifyHostname(h) != nil {
|
||||
return
|
||||
}
|
||||
addr := tlsConn.RemoteAddr().String()
|
||||
if cerr := checkCert(ctx, cert, addr); cerr != nil {
|
||||
l.handshakeFailure(tlsConn, cerr)
|
||||
return
|
||||
}
|
||||
}
|
||||
select {
|
||||
case l.connc <- tlsConn:
|
||||
conn = nil
|
||||
case <-stopc:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func checkCert(ctx context.Context, cert *x509.Certificate, remoteAddr string) error {
|
||||
h, _, herr := net.SplitHostPort(remoteAddr)
|
||||
if len(cert.IPAddresses) == 0 && len(cert.DNSNames) == 0 {
|
||||
return nil
|
||||
}
|
||||
if herr != nil {
|
||||
return herr
|
||||
}
|
||||
if len(cert.IPAddresses) > 0 {
|
||||
if cerr := cert.VerifyHostname(h); cerr != nil && len(cert.DNSNames) == 0 {
|
||||
return cerr
|
||||
}
|
||||
}
|
||||
if len(cert.DNSNames) > 0 {
|
||||
for _, dns := range cert.DNSNames {
|
||||
addrs, lerr := net.DefaultResolver.LookupHost(ctx, dns)
|
||||
if lerr != nil {
|
||||
continue
|
||||
}
|
||||
for _, addr := range addrs {
|
||||
if addr == h {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("tls: %q does not match any of DNSNames %q", h, cert.DNSNames)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *tlsListener) Close() error {
|
||||
err := l.Listener.Close()
|
||||
<-l.donec
|
||||
|
Loading…
x
Reference in New Issue
Block a user