mirror of
https://github.com/etcd-io/etcd.git
synced 2024-09-27 06:25:44 +00:00
Merge pull request #7687 from heyitsanthony/deny-tls-ipsan
transport: deny incoming peer certs with wrong IP SAN
This commit is contained in:
commit
1153e1e7d9
@ -201,7 +201,6 @@ func startPeerListeners(cfg *Config) (plns []net.Listener, err error) {
|
||||
}()
|
||||
|
||||
for i, u := range cfg.LPUrls {
|
||||
var tlscfg *tls.Config
|
||||
if u.Scheme == "http" {
|
||||
if !cfg.PeerTLSInfo.Empty() {
|
||||
plog.Warningf("The scheme of peer url %s is HTTP while peer key/cert files are presented. Ignored peer key/cert files.", u.String())
|
||||
@ -210,12 +209,7 @@ func startPeerListeners(cfg *Config) (plns []net.Listener, err error) {
|
||||
plog.Warningf("The scheme of peer url %s is HTTP while client cert auth (--peer-client-cert-auth) is enabled. Ignored client cert auth for this url.", u.String())
|
||||
}
|
||||
}
|
||||
if !cfg.PeerTLSInfo.Empty() {
|
||||
if tlscfg, err = cfg.PeerTLSInfo.ServerConfig(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if plns[i], err = rafthttp.NewListener(u, tlscfg); err != nil {
|
||||
if plns[i], err = rafthttp.NewListener(u, &cfg.PeerTLSInfo); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
plog.Info("listening for peers on ", u.String())
|
||||
|
@ -19,7 +19,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@ -305,18 +304,7 @@ func startProxy(cfg *config) error {
|
||||
}
|
||||
// Start a proxy server goroutine for each listen address
|
||||
for _, u := range cfg.LCUrls {
|
||||
var (
|
||||
l net.Listener
|
||||
tlscfg *tls.Config
|
||||
)
|
||||
if !cfg.ClientTLSInfo.Empty() {
|
||||
tlscfg, err = cfg.ClientTLSInfo.ServerConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
l, err := transport.NewListener(u.Host, u.Scheme, tlscfg)
|
||||
l, err := transport.NewListener(u.Host, u.Scheme, &cfg.ClientTLSInfo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -369,6 +357,11 @@ func identifyDataDirOrDie(dir string) dirType {
|
||||
}
|
||||
|
||||
func setupLogging(cfg *config) {
|
||||
cfg.ClientTLSInfo.HandshakeFailure = func(conn *tls.Conn, err error) {
|
||||
plog.Infof("rejected connection from %q (%v)", conn.RemoteAddr().String(), err)
|
||||
}
|
||||
cfg.PeerTLSInfo.HandshakeFailure = cfg.ClientTLSInfo.HandshakeFailure
|
||||
|
||||
capnslog.SetGlobalLogLevel(capnslog.INFO)
|
||||
if cfg.Debug {
|
||||
capnslog.SetGlobalLogLevel(capnslog.DEBUG)
|
||||
|
@ -18,7 +18,6 @@ import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@ -50,12 +49,12 @@ func TestNewKeepAliveListener(t *testing.T) {
|
||||
}
|
||||
|
||||
// tls
|
||||
tmp, err := createTempFile([]byte("XXX"))
|
||||
tlsinfo, del, err := createSelfCert()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create tmpfile: %v", err)
|
||||
}
|
||||
defer os.Remove(tmp)
|
||||
tlsInfo := TLSInfo{CertFile: tmp, KeyFile: tmp}
|
||||
defer del()
|
||||
tlsInfo := TLSInfo{CertFile: tlsinfo.CertFile, KeyFile: tlsinfo.KeyFile}
|
||||
tlsInfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil)
|
||||
tlscfg, err := tlsInfo.ServerConfig()
|
||||
if err != nil {
|
||||
|
@ -33,11 +33,11 @@ import (
|
||||
"github.com/coreos/etcd/pkg/tlsutil"
|
||||
)
|
||||
|
||||
func NewListener(addr, scheme string, tlscfg *tls.Config) (l net.Listener, err error) {
|
||||
func NewListener(addr, scheme string, tlsinfo *TLSInfo) (l net.Listener, err error) {
|
||||
if l, err = newListener(addr, scheme); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return wrapTLS(addr, scheme, tlscfg, l)
|
||||
return wrapTLS(addr, scheme, tlsinfo, l)
|
||||
}
|
||||
|
||||
func newListener(addr string, scheme string) (net.Listener, error) {
|
||||
@ -48,15 +48,11 @@ func newListener(addr string, scheme string) (net.Listener, error) {
|
||||
return net.Listen("tcp", addr)
|
||||
}
|
||||
|
||||
func wrapTLS(addr, scheme string, tlscfg *tls.Config, l net.Listener) (net.Listener, error) {
|
||||
func wrapTLS(addr, scheme string, tlsinfo *TLSInfo, l net.Listener) (net.Listener, error) {
|
||||
if scheme != "https" && scheme != "unixs" {
|
||||
return l, nil
|
||||
}
|
||||
if tlscfg == nil {
|
||||
l.Close()
|
||||
return nil, fmt.Errorf("cannot listen on TLS for %s: KeyFile and CertFile are not presented", scheme+"://"+addr)
|
||||
}
|
||||
return tls.NewListener(l, tlscfg), nil
|
||||
return newTLSListener(l, tlsinfo)
|
||||
}
|
||||
|
||||
type TLSInfo struct {
|
||||
@ -69,6 +65,10 @@ type TLSInfo struct {
|
||||
// ServerName ensures the cert matches the given host in case of discovery / virtual hosting
|
||||
ServerName string
|
||||
|
||||
// HandshakeFailure is optinally called when a connection fails to handshake. The
|
||||
// connection will be closed immediately afterwards.
|
||||
HandshakeFailure func(*tls.Conn, error)
|
||||
|
||||
selfCert bool
|
||||
|
||||
// parseFunc exists to simplify testing. Typically, parseFunc
|
||||
|
@ -24,18 +24,16 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func createTempFile(b []byte) (string, error) {
|
||||
f, err := ioutil.TempFile("", "etcd-test-tls-")
|
||||
func createSelfCert() (*TLSInfo, func(), error) {
|
||||
d, terr := ioutil.TempDir("", "etcd-test-tls-")
|
||||
if terr != nil {
|
||||
return nil, nil, terr
|
||||
}
|
||||
info, err := SelfCert(d, []string{"127.0.0.1"})
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if _, err = f.Write(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return f.Name(), nil
|
||||
return &info, func() { os.RemoveAll(d) }, nil
|
||||
}
|
||||
|
||||
func fakeCertificateParserFunc(cert tls.Certificate, err error) func(certPEMBlock, keyPEMBlock []byte) (tls.Certificate, error) {
|
||||
@ -47,28 +45,25 @@ func fakeCertificateParserFunc(cert tls.Certificate, err error) func(certPEMBloc
|
||||
// TestNewListenerTLSInfo tests that NewListener with valid TLSInfo returns
|
||||
// a TLS listener that accepts TLS connections.
|
||||
func TestNewListenerTLSInfo(t *testing.T) {
|
||||
tmp, err := createTempFile([]byte("XXX"))
|
||||
tlsInfo, del, err := createSelfCert()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create tmpfile: %v", err)
|
||||
t.Fatalf("unable to create cert: %v", err)
|
||||
}
|
||||
defer os.Remove(tmp)
|
||||
tlsInfo := TLSInfo{CertFile: tmp, KeyFile: tmp}
|
||||
tlsInfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil)
|
||||
testNewListenerTLSInfoAccept(t, tlsInfo)
|
||||
defer del()
|
||||
testNewListenerTLSInfoAccept(t, *tlsInfo)
|
||||
}
|
||||
|
||||
func testNewListenerTLSInfoAccept(t *testing.T, tlsInfo TLSInfo) {
|
||||
tlscfg, err := tlsInfo.ServerConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected serverConfig error: %v", err)
|
||||
}
|
||||
ln, err := NewListener("127.0.0.1:0", "https", tlscfg)
|
||||
ln, err := NewListener("127.0.0.1:0", "https", &tlsInfo)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected NewListener error: %v", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
go http.Get("https://" + ln.Addr().String())
|
||||
tr := &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
|
||||
cli := &http.Client{Transport: tr}
|
||||
go cli.Get("https://" + ln.Addr().String())
|
||||
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected Accept error: %v", err)
|
||||
@ -87,25 +82,25 @@ func TestNewListenerTLSEmptyInfo(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewTransportTLSInfo(t *testing.T) {
|
||||
tmp, err := createTempFile([]byte("XXX"))
|
||||
tlsinfo, del, err := createSelfCert()
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to prepare tmpfile: %v", err)
|
||||
t.Fatalf("unable to create cert: %v", err)
|
||||
}
|
||||
defer os.Remove(tmp)
|
||||
defer del()
|
||||
|
||||
tests := []TLSInfo{
|
||||
{},
|
||||
{
|
||||
CertFile: tmp,
|
||||
KeyFile: tmp,
|
||||
CertFile: tlsinfo.CertFile,
|
||||
KeyFile: tlsinfo.KeyFile,
|
||||
},
|
||||
{
|
||||
CertFile: tmp,
|
||||
KeyFile: tmp,
|
||||
CAFile: tmp,
|
||||
CertFile: tlsinfo.CertFile,
|
||||
KeyFile: tlsinfo.KeyFile,
|
||||
CAFile: tlsinfo.CAFile,
|
||||
},
|
||||
{
|
||||
CAFile: tmp,
|
||||
CAFile: tlsinfo.CAFile,
|
||||
},
|
||||
}
|
||||
|
||||
@ -159,17 +154,17 @@ func TestTLSInfoEmpty(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTLSInfoMissingFields(t *testing.T) {
|
||||
tmp, err := createTempFile([]byte("XXX"))
|
||||
tlsinfo, del, err := createSelfCert()
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to prepare tmpfile: %v", err)
|
||||
t.Fatalf("unable to create cert: %v", err)
|
||||
}
|
||||
defer os.Remove(tmp)
|
||||
defer del()
|
||||
|
||||
tests := []TLSInfo{
|
||||
{CertFile: tmp},
|
||||
{KeyFile: tmp},
|
||||
{CertFile: tmp, CAFile: tmp},
|
||||
{KeyFile: tmp, CAFile: tmp},
|
||||
{CertFile: tlsinfo.CertFile},
|
||||
{KeyFile: tlsinfo.KeyFile},
|
||||
{CertFile: tlsinfo.CertFile, CAFile: tlsinfo.CAFile},
|
||||
{KeyFile: tlsinfo.KeyFile, CAFile: tlsinfo.CAFile},
|
||||
}
|
||||
|
||||
for i, info := range tests {
|
||||
@ -184,30 +179,29 @@ func TestTLSInfoMissingFields(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTLSInfoParseFuncError(t *testing.T) {
|
||||
tmp, err := createTempFile([]byte("XXX"))
|
||||
tlsinfo, del, err := createSelfCert()
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to prepare tmpfile: %v", err)
|
||||
t.Fatalf("unable to create cert: %v", err)
|
||||
}
|
||||
defer os.Remove(tmp)
|
||||
defer del()
|
||||
|
||||
info := TLSInfo{CertFile: tmp, KeyFile: tmp, CAFile: tmp}
|
||||
info.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, errors.New("fake"))
|
||||
tlsinfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, errors.New("fake"))
|
||||
|
||||
if _, err = info.ServerConfig(); err == nil {
|
||||
if _, err = tlsinfo.ServerConfig(); err == nil {
|
||||
t.Errorf("expected non-nil error from ServerConfig()")
|
||||
}
|
||||
|
||||
if _, err = info.ClientConfig(); err == nil {
|
||||
if _, err = tlsinfo.ClientConfig(); err == nil {
|
||||
t.Errorf("expected non-nil error from ClientConfig()")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSInfoConfigFuncs(t *testing.T) {
|
||||
tmp, err := createTempFile([]byte("XXX"))
|
||||
tlsinfo, del, err := createSelfCert()
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to prepare tmpfile: %v", err)
|
||||
t.Fatalf("unable to create cert: %v", err)
|
||||
}
|
||||
defer os.Remove(tmp)
|
||||
defer del()
|
||||
|
||||
tests := []struct {
|
||||
info TLSInfo
|
||||
@ -215,13 +209,13 @@ func TestTLSInfoConfigFuncs(t *testing.T) {
|
||||
wantCAs bool
|
||||
}{
|
||||
{
|
||||
info: TLSInfo{CertFile: tmp, KeyFile: tmp},
|
||||
info: TLSInfo{CertFile: tlsinfo.CertFile, KeyFile: tlsinfo.KeyFile},
|
||||
clientAuth: tls.NoClientCert,
|
||||
wantCAs: false,
|
||||
},
|
||||
|
||||
{
|
||||
info: TLSInfo{CertFile: tmp, KeyFile: tmp, CAFile: tmp},
|
||||
info: TLSInfo{CertFile: tlsinfo.CertFile, KeyFile: tlsinfo.KeyFile, CAFile: tlsinfo.CertFile},
|
||||
clientAuth: tls.RequireAndVerifyClientCert,
|
||||
wantCAs: true,
|
||||
},
|
||||
|
137
pkg/transport/listener_tls.go
Normal file
137
pkg/transport/listener_tls.go
Normal file
@ -0,0 +1,137 @@
|
||||
// Copyright 2017 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// tlsListener overrides a TLS listener so it will reject client
|
||||
// certificates with insufficient SAN credentials.
|
||||
type tlsListener struct {
|
||||
net.Listener
|
||||
connc chan net.Conn
|
||||
donec chan struct{}
|
||||
err error
|
||||
handshakeFailure func(*tls.Conn, error)
|
||||
}
|
||||
|
||||
func newTLSListener(l net.Listener, tlsinfo *TLSInfo) (net.Listener, error) {
|
||||
if tlsinfo == nil || tlsinfo.Empty() {
|
||||
l.Close()
|
||||
return nil, fmt.Errorf("cannot listen on TLS for %s: KeyFile and CertFile are not presented", l.Addr().String())
|
||||
}
|
||||
tlscfg, err := tlsinfo.ServerConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tlsl := &tlsListener{
|
||||
Listener: tls.NewListener(l, tlscfg),
|
||||
connc: make(chan net.Conn),
|
||||
donec: make(chan struct{}),
|
||||
handshakeFailure: tlsinfo.HandshakeFailure,
|
||||
}
|
||||
go tlsl.acceptLoop()
|
||||
return tlsl, nil
|
||||
}
|
||||
|
||||
func (l *tlsListener) Accept() (net.Conn, error) {
|
||||
select {
|
||||
case conn := <-l.connc:
|
||||
return conn, nil
|
||||
case <-l.donec:
|
||||
return nil, l.err
|
||||
}
|
||||
}
|
||||
|
||||
// acceptLoop launches each TLS handshake in a separate goroutine
|
||||
// to prevent a hanging TLS connection from blocking other connections.
|
||||
func (l *tlsListener) acceptLoop() {
|
||||
var wg sync.WaitGroup
|
||||
var pendingMu sync.Mutex
|
||||
|
||||
pending := make(map[net.Conn]struct{})
|
||||
stopc := make(chan struct{})
|
||||
defer func() {
|
||||
close(stopc)
|
||||
pendingMu.Lock()
|
||||
for c := range pending {
|
||||
c.Close()
|
||||
}
|
||||
pendingMu.Unlock()
|
||||
wg.Wait()
|
||||
close(l.donec)
|
||||
}()
|
||||
|
||||
for {
|
||||
conn, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
l.err = err
|
||||
return
|
||||
}
|
||||
|
||||
pendingMu.Lock()
|
||||
pending[conn] = struct{}{}
|
||||
pendingMu.Unlock()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer func() {
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
tlsConn := conn.(*tls.Conn)
|
||||
herr := tlsConn.Handshake()
|
||||
pendingMu.Lock()
|
||||
delete(pending, conn)
|
||||
pendingMu.Unlock()
|
||||
if herr != nil {
|
||||
if l.handshakeFailure != nil {
|
||||
l.handshakeFailure(tlsConn, herr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
st := tlsConn.ConnectionState()
|
||||
if len(st.PeerCertificates) > 0 {
|
||||
cert := st.PeerCertificates[0]
|
||||
if len(cert.IPAddresses) > 0 || len(cert.DNSNames) > 0 {
|
||||
addr := tlsConn.RemoteAddr().String()
|
||||
h, _, herr := net.SplitHostPort(addr)
|
||||
if herr != nil || cert.VerifyHostname(h) != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
select {
|
||||
case l.connc <- tlsConn:
|
||||
conn = nil
|
||||
case <-stopc:
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (l *tlsListener) Close() error {
|
||||
err := l.Listener.Close()
|
||||
<-l.donec
|
||||
return err
|
||||
}
|
@ -15,7 +15,6 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
@ -23,7 +22,7 @@ import (
|
||||
// NewTimeoutListener returns a listener that listens on the given address.
|
||||
// If read/write on the accepted connection blocks longer than its time limit,
|
||||
// it will return timeout error.
|
||||
func NewTimeoutListener(addr string, scheme string, tlscfg *tls.Config, rdtimeoutd, wtimeoutd time.Duration) (net.Listener, error) {
|
||||
func NewTimeoutListener(addr string, scheme string, tlsinfo *TLSInfo, rdtimeoutd, wtimeoutd time.Duration) (net.Listener, error) {
|
||||
ln, err := newListener(addr, scheme)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -33,7 +32,7 @@ func NewTimeoutListener(addr string, scheme string, tlscfg *tls.Config, rdtimeou
|
||||
rdtimeoutd: rdtimeoutd,
|
||||
wtimeoutd: wtimeoutd,
|
||||
}
|
||||
if ln, err = wrapTLS(addr, scheme, tlscfg, ln); err != nil {
|
||||
if ln, err = wrapTLS(addr, scheme, tlsinfo, ln); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ln, nil
|
||||
|
@ -15,7 +15,6 @@
|
||||
package rafthttp
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@ -37,8 +36,8 @@ var (
|
||||
|
||||
// NewListener returns a listener for raft message transfer between peers.
|
||||
// It uses timeout listener to identify broken streams promptly.
|
||||
func NewListener(u url.URL, tlscfg *tls.Config) (net.Listener, error) {
|
||||
return transport.NewTimeoutListener(u.Host, u.Scheme, tlscfg, ConnReadTimeout, ConnWriteTimeout)
|
||||
func NewListener(u url.URL, tlsinfo *transport.TLSInfo) (net.Listener, error) {
|
||||
return transport.NewTimeoutListener(u.Host, u.Scheme, tlsinfo, ConnReadTimeout, ConnWriteTimeout)
|
||||
}
|
||||
|
||||
// NewRoundTripper returns a roundTripper used to send requests
|
||||
|
Loading…
x
Reference in New Issue
Block a user