Merge pull request #6253 from heyitsanthony/srv-arec

discovery: reject IP address records in SRVGetCluster
This commit is contained in:
Anthony Romano 2016-08-24 06:56:17 -07:00 committed by GitHub
commit 1c989edb47
2 changed files with 45 additions and 47 deletions

View File

@ -55,8 +55,8 @@ func SRVGetCluster(name, dns string, defaultToken string, apurls types.URLs) (st
return err
}
for _, srv := range addrs {
target := strings.TrimSuffix(srv.Target, ".")
host := net.JoinHostPort(target, fmt.Sprintf("%d", srv.Port))
port := fmt.Sprintf("%d", srv.Port)
host := net.JoinHostPort(srv.Target, port)
tcpAddr, err := resolveTCPAddr("tcp", host)
if err != nil {
plog.Warningf("couldn't resolve host %s during SRV discovery", host)
@ -72,8 +72,11 @@ func SRVGetCluster(name, dns string, defaultToken string, apurls types.URLs) (st
n = fmt.Sprintf("%d", tempName)
tempName++
}
stringParts = append(stringParts, fmt.Sprintf("%s=%s%s", n, prefix, host))
plog.Noticef("got bootstrap from DNS for %s at %s%s", service, prefix, host)
// SRV records have a trailing dot but URL shouldn't.
shortHost := strings.TrimSuffix(srv.Target, ".")
urlHost := net.JoinHostPort(shortHost, port)
stringParts = append(stringParts, fmt.Sprintf("%s=%s%s", n, prefix, urlHost))
plog.Noticef("got bootstrap from DNS for %s at %s%s", service, prefix, urlHost)
}
return nil
}

View File

@ -17,6 +17,7 @@ package discovery
import (
"errors"
"net"
"strings"
"testing"
"github.com/coreos/etcd/pkg/testutil"
@ -29,11 +30,22 @@ func TestSRVGetCluster(t *testing.T) {
}()
name := "dnsClusterTest"
dns := map[string]string{
"1.example.com.:2480": "10.0.0.1:2480",
"2.example.com.:2480": "10.0.0.2:2480",
"3.example.com.:2480": "10.0.0.3:2480",
"4.example.com.:2380": "10.0.0.3:2380",
}
srvAll := []*net.SRV{
{Target: "1.example.com.", Port: 2480},
{Target: "2.example.com.", Port: 2480},
{Target: "3.example.com.", Port: 2480},
}
tests := []struct {
withSSL []*net.SRV
withoutSSL []*net.SRV
urls []string
dns map[string]string
expected string
}{
@ -41,61 +53,50 @@ func TestSRVGetCluster(t *testing.T) {
[]*net.SRV{},
[]*net.SRV{},
nil,
nil,
"",
},
{
[]*net.SRV{
{Target: "10.0.0.1", Port: 2480},
{Target: "10.0.0.2", Port: 2480},
{Target: "10.0.0.3", Port: 2480},
},
srvAll,
[]*net.SRV{},
nil,
"0=https://1.example.com:2480,1=https://2.example.com:2480,2=https://3.example.com:2480",
},
{
srvAll,
[]*net.SRV{{Target: "4.example.com.", Port: 2380}},
nil,
"0=https://10.0.0.1:2480,1=https://10.0.0.2:2480,2=https://10.0.0.3:2480",
"0=https://1.example.com:2480,1=https://2.example.com:2480,2=https://3.example.com:2480,3=http://4.example.com:2380",
},
{
[]*net.SRV{
{Target: "10.0.0.1", Port: 2480},
{Target: "10.0.0.2", Port: 2480},
{Target: "10.0.0.3", Port: 2480},
},
[]*net.SRV{
{Target: "10.0.0.1", Port: 2380},
},
nil,
nil,
"0=https://10.0.0.1:2480,1=https://10.0.0.2:2480,2=https://10.0.0.3:2480,3=http://10.0.0.1:2380",
},
{
[]*net.SRV{
{Target: "10.0.0.1", Port: 2480},
{Target: "10.0.0.2", Port: 2480},
{Target: "10.0.0.3", Port: 2480},
},
[]*net.SRV{
{Target: "10.0.0.1", Port: 2380},
},
srvAll,
[]*net.SRV{{Target: "4.example.com.", Port: 2380}},
[]string{"https://10.0.0.1:2480"},
nil,
"dnsClusterTest=https://10.0.0.1:2480,0=https://10.0.0.2:2480,1=https://10.0.0.3:2480,2=http://10.0.0.1:2380",
"dnsClusterTest=https://1.example.com:2480,0=https://2.example.com:2480,1=https://3.example.com:2480,2=http://4.example.com:2380",
},
// matching local member with resolved addr and return unresolved hostnames
{
[]*net.SRV{
{Target: "1.example.com.", Port: 2480},
{Target: "2.example.com.", Port: 2480},
{Target: "3.example.com.", Port: 2480},
},
srvAll,
nil,
[]string{"https://10.0.0.1:2480"},
map[string]string{"1.example.com:2480": "10.0.0.1:2480", "2.example.com:2480": "10.0.0.2:2480", "3.example.com:2480": "10.0.0.3:2480"},
"dnsClusterTest=https://1.example.com:2480,0=https://2.example.com:2480,1=https://3.example.com:2480",
},
// invalid
}
resolveTCPAddr = func(network, addr string) (*net.TCPAddr, error) {
if strings.Contains(addr, "10.0.0.") {
// accept IP addresses when resolving apurls
return net.ResolveTCPAddr(network, addr)
}
if dns[addr] == "" {
return nil, errors.New("missing dns record")
}
return net.ResolveTCPAddr(network, dns[addr])
}
for i, tt := range tests {
@ -108,12 +109,6 @@ func TestSRVGetCluster(t *testing.T) {
}
return "", nil, errors.New("Unknown service in mock")
}
resolveTCPAddr = func(network, addr string) (*net.TCPAddr, error) {
if tt.dns == nil || tt.dns[addr] == "" {
return net.ResolveTCPAddr(network, addr)
}
return net.ResolveTCPAddr(network, tt.dns[addr])
}
urls := testutil.MustNewURLs(t, tt.urls)
str, token, err := SRVGetCluster(name, "example.com", "token", urls)
if err != nil {