From 3dc12e33f1aa6b464158a0b0931c30693e0e9a04 Mon Sep 17 00:00:00 2001 From: Anthony Romano Date: Tue, 23 Aug 2016 18:03:03 -0700 Subject: [PATCH] discovery: reject IP address records in SRVGetCluster Was incorrectly trimming the trailing '.' from the target; this in turn caused the etcd server to accept any SRV record with an IP target instead of only targets with A records. --- discovery/srv.go | 11 +++--- discovery/srv_test.go | 81 ++++++++++++++++++++----------------------- 2 files changed, 45 insertions(+), 47 deletions(-) diff --git a/discovery/srv.go b/discovery/srv.go index 34884ddcb..bac43ebb6 100644 --- a/discovery/srv.go +++ b/discovery/srv.go @@ -55,8 +55,8 @@ func SRVGetCluster(name, dns string, defaultToken string, apurls types.URLs) (st return err } for _, srv := range addrs { - target := strings.TrimSuffix(srv.Target, ".") - host := net.JoinHostPort(target, fmt.Sprintf("%d", srv.Port)) + port := fmt.Sprintf("%d", srv.Port) + host := net.JoinHostPort(srv.Target, port) tcpAddr, err := resolveTCPAddr("tcp", host) if err != nil { plog.Warningf("couldn't resolve host %s during SRV discovery", host) @@ -72,8 +72,11 @@ func SRVGetCluster(name, dns string, defaultToken string, apurls types.URLs) (st n = fmt.Sprintf("%d", tempName) tempName++ } - stringParts = append(stringParts, fmt.Sprintf("%s=%s%s", n, prefix, host)) - plog.Noticef("got bootstrap from DNS for %s at %s%s", service, prefix, host) + // SRV records have a trailing dot but URL shouldn't. + shortHost := strings.TrimSuffix(srv.Target, ".") + urlHost := net.JoinHostPort(shortHost, port) + stringParts = append(stringParts, fmt.Sprintf("%s=%s%s", n, prefix, urlHost)) + plog.Noticef("got bootstrap from DNS for %s at %s%s", service, prefix, urlHost) } return nil } diff --git a/discovery/srv_test.go b/discovery/srv_test.go index 4b8e2ed1e..c90f9b682 100644 --- a/discovery/srv_test.go +++ b/discovery/srv_test.go @@ -17,6 +17,7 @@ package discovery import ( "errors" "net" + "strings" "testing" "github.com/coreos/etcd/pkg/testutil" @@ -29,11 +30,22 @@ func TestSRVGetCluster(t *testing.T) { }() name := "dnsClusterTest" + dns := map[string]string{ + "1.example.com.:2480": "10.0.0.1:2480", + "2.example.com.:2480": "10.0.0.2:2480", + "3.example.com.:2480": "10.0.0.3:2480", + "4.example.com.:2380": "10.0.0.3:2380", + } + srvAll := []*net.SRV{ + {Target: "1.example.com.", Port: 2480}, + {Target: "2.example.com.", Port: 2480}, + {Target: "3.example.com.", Port: 2480}, + } + tests := []struct { withSSL []*net.SRV withoutSSL []*net.SRV urls []string - dns map[string]string expected string }{ @@ -41,61 +53,50 @@ func TestSRVGetCluster(t *testing.T) { []*net.SRV{}, []*net.SRV{}, nil, - nil, "", }, { - []*net.SRV{ - {Target: "10.0.0.1", Port: 2480}, - {Target: "10.0.0.2", Port: 2480}, - {Target: "10.0.0.3", Port: 2480}, - }, + srvAll, []*net.SRV{}, nil, + + "0=https://1.example.com:2480,1=https://2.example.com:2480,2=https://3.example.com:2480", + }, + { + srvAll, + []*net.SRV{{Target: "4.example.com.", Port: 2380}}, nil, - "0=https://10.0.0.1:2480,1=https://10.0.0.2:2480,2=https://10.0.0.3:2480", + "0=https://1.example.com:2480,1=https://2.example.com:2480,2=https://3.example.com:2480,3=http://4.example.com:2380", }, { - []*net.SRV{ - {Target: "10.0.0.1", Port: 2480}, - {Target: "10.0.0.2", Port: 2480}, - {Target: "10.0.0.3", Port: 2480}, - }, - []*net.SRV{ - {Target: "10.0.0.1", Port: 2380}, - }, - nil, - nil, - "0=https://10.0.0.1:2480,1=https://10.0.0.2:2480,2=https://10.0.0.3:2480,3=http://10.0.0.1:2380", - }, - { - []*net.SRV{ - {Target: "10.0.0.1", Port: 2480}, - {Target: "10.0.0.2", Port: 2480}, - {Target: "10.0.0.3", Port: 2480}, - }, - []*net.SRV{ - {Target: "10.0.0.1", Port: 2380}, - }, + srvAll, + []*net.SRV{{Target: "4.example.com.", Port: 2380}}, []string{"https://10.0.0.1:2480"}, - nil, - "dnsClusterTest=https://10.0.0.1:2480,0=https://10.0.0.2:2480,1=https://10.0.0.3:2480,2=http://10.0.0.1:2380", + + "dnsClusterTest=https://1.example.com:2480,0=https://2.example.com:2480,1=https://3.example.com:2480,2=http://4.example.com:2380", }, // matching local member with resolved addr and return unresolved hostnames { - []*net.SRV{ - {Target: "1.example.com.", Port: 2480}, - {Target: "2.example.com.", Port: 2480}, - {Target: "3.example.com.", Port: 2480}, - }, + srvAll, nil, []string{"https://10.0.0.1:2480"}, - map[string]string{"1.example.com:2480": "10.0.0.1:2480", "2.example.com:2480": "10.0.0.2:2480", "3.example.com:2480": "10.0.0.3:2480"}, "dnsClusterTest=https://1.example.com:2480,0=https://2.example.com:2480,1=https://3.example.com:2480", }, + // invalid + } + + resolveTCPAddr = func(network, addr string) (*net.TCPAddr, error) { + if strings.Contains(addr, "10.0.0.") { + // accept IP addresses when resolving apurls + return net.ResolveTCPAddr(network, addr) + } + if dns[addr] == "" { + return nil, errors.New("missing dns record") + } + return net.ResolveTCPAddr(network, dns[addr]) } for i, tt := range tests { @@ -108,12 +109,6 @@ func TestSRVGetCluster(t *testing.T) { } return "", nil, errors.New("Unknown service in mock") } - resolveTCPAddr = func(network, addr string) (*net.TCPAddr, error) { - if tt.dns == nil || tt.dns[addr] == "" { - return net.ResolveTCPAddr(network, addr) - } - return net.ResolveTCPAddr(network, tt.dns[addr]) - } urls := testutil.MustNewURLs(t, tt.urls) str, token, err := SRVGetCluster(name, "example.com", "token", urls) if err != nil {