diff --git a/discovery/srv.go b/discovery/srv.go index 86365f065..63c7deabc 100644 --- a/discovery/srv.go +++ b/discovery/srv.go @@ -25,7 +25,8 @@ import ( var ( // indirection for testing - lookupSRV = net.LookupSRV + lookupSRV = net.LookupSRV + resolveTCPAddr = net.ResolveTCPAddr ) // TODO(barakmich): Currently ignores priority and weight (as they don't make as much sense for a bootstrap) @@ -38,7 +39,7 @@ func SRVGetCluster(name, dns string, defaultToken string, apurls types.URLs) (st // First, resolve the apurls for _, url := range apurls { - tcpAddr, err := net.ResolveTCPAddr("tcp", url.Host) + tcpAddr, err := resolveTCPAddr("tcp", url.Host) if err != nil { log.Printf("discovery: Couldn't resolve host %s during SRV discovery", url.Host) return "", "", err @@ -53,7 +54,7 @@ func SRVGetCluster(name, dns string, defaultToken string, apurls types.URLs) (st } for _, srv := range addrs { host := net.JoinHostPort(srv.Target, fmt.Sprintf("%d", srv.Port)) - tcpAddr, err := net.ResolveTCPAddr("tcp", host) + tcpAddr, err := resolveTCPAddr("tcp", host) if err != nil { log.Printf("discovery: Couldn't resolve host %s during SRV discovery", host) continue diff --git a/discovery/srv_test.go b/discovery/srv_test.go index f523adb7b..de44682fd 100644 --- a/discovery/srv_test.go +++ b/discovery/srv_test.go @@ -23,19 +23,26 @@ import ( ) func TestSRVGetCluster(t *testing.T) { - defer func() { lookupSRV = net.LookupSRV }() + defer func() { + lookupSRV = net.LookupSRV + resolveTCPAddr = net.ResolveTCPAddr + }() name := "dnsClusterTest" tests := []struct { withSSL []*net.SRV withoutSSL []*net.SRV urls []string - expected string + dns map[string]string + + expected string }{ { []*net.SRV{}, []*net.SRV{}, nil, + nil, + "", }, { @@ -46,6 +53,8 @@ func TestSRVGetCluster(t *testing.T) { }, []*net.SRV{}, nil, + nil, + "0=https://10.0.0.1:2480,1=https://10.0.0.2:2480,2=https://10.0.0.3:2480", }, { @@ -58,6 +67,7 @@ func TestSRVGetCluster(t *testing.T) { &net.SRV{Target: "10.0.0.1", Port: 2380}, }, nil, + nil, "0=https://10.0.0.1:2480,1=https://10.0.0.2:2480,2=https://10.0.0.3:2480,3=http://10.0.0.1:2380", }, { @@ -70,8 +80,22 @@ func TestSRVGetCluster(t *testing.T) { &net.SRV{Target: "10.0.0.1", Port: 2380}, }, []string{"https://10.0.0.1:2480"}, + nil, "dnsClusterTest=https://10.0.0.1:2480,0=https://10.0.0.2:2480,1=https://10.0.0.3:2480,2=http://10.0.0.1:2380", }, + // matching local member with resolved addr and return unresolved hostnames + { + []*net.SRV{ + &net.SRV{Target: "1.example.com.", Port: 2480}, + &net.SRV{Target: "2.example.com.", Port: 2480}, + &net.SRV{Target: "3.example.com.", Port: 2480}, + }, + nil, + []string{"https://10.0.0.1:2480"}, + map[string]string{"1.example.com:2480": "10.0.0.1:2480", "2.example.com:2480": "10.0.0.2:2480", "3.example.com:2480": "10.0.0.3:2480"}, + + "dnsClusterTest=https://1.example.com:2480,0=https://2.example.com:2480,1=https://3.example.com:2480", + }, } for i, tt := range tests { @@ -84,6 +108,12 @@ func TestSRVGetCluster(t *testing.T) { } return "", nil, errors.New("Unkown service in mock") } + resolveTCPAddr = func(network, addr string) (*net.TCPAddr, error) { + if tt.dns == nil || tt.dns[addr] == "" { + return net.ResolveTCPAddr(network, addr) + } + return net.ResolveTCPAddr(network, tt.dns[addr]) + } urls := testutil.MustNewURLs(t, tt.urls) str, token, err := SRVGetCluster(name, "example.com", "token", urls) if err != nil {