discovery: add a test case for srv

During srv discovery, it should try to match local member with
resolved addr and return unresolved hostnames for the cluster.

Conflicts:
	discovery/srv_test.go
This commit is contained in:
Xiang Li 2015-03-31 10:39:46 -07:00 committed by Yicheng Qin
parent 21455d2f3b
commit 6a3bb93305
2 changed files with 36 additions and 5 deletions

View File

@ -25,7 +25,8 @@ import (
var ( var (
// indirection for testing // indirection for testing
lookupSRV = net.LookupSRV lookupSRV = net.LookupSRV
resolveTCPAddr = net.ResolveTCPAddr
) )
// TODO(barakmich): Currently ignores priority and weight (as they don't make as much sense for a bootstrap) // TODO(barakmich): Currently ignores priority and weight (as they don't make as much sense for a bootstrap)
@ -38,7 +39,7 @@ func SRVGetCluster(name, dns string, defaultToken string, apurls types.URLs) (st
// First, resolve the apurls // First, resolve the apurls
for _, url := range apurls { for _, url := range apurls {
tcpAddr, err := net.ResolveTCPAddr("tcp", url.Host) tcpAddr, err := resolveTCPAddr("tcp", url.Host)
if err != nil { if err != nil {
log.Printf("discovery: Couldn't resolve host %s during SRV discovery", url.Host) log.Printf("discovery: Couldn't resolve host %s during SRV discovery", url.Host)
return "", "", err return "", "", err
@ -53,7 +54,7 @@ func SRVGetCluster(name, dns string, defaultToken string, apurls types.URLs) (st
} }
for _, srv := range addrs { for _, srv := range addrs {
host := net.JoinHostPort(srv.Target, fmt.Sprintf("%d", srv.Port)) host := net.JoinHostPort(srv.Target, fmt.Sprintf("%d", srv.Port))
tcpAddr, err := net.ResolveTCPAddr("tcp", host) tcpAddr, err := resolveTCPAddr("tcp", host)
if err != nil { if err != nil {
log.Printf("discovery: Couldn't resolve host %s during SRV discovery", host) log.Printf("discovery: Couldn't resolve host %s during SRV discovery", host)
continue continue

View File

@ -23,19 +23,26 @@ import (
) )
func TestSRVGetCluster(t *testing.T) { func TestSRVGetCluster(t *testing.T) {
defer func() { lookupSRV = net.LookupSRV }() defer func() {
lookupSRV = net.LookupSRV
resolveTCPAddr = net.ResolveTCPAddr
}()
name := "dnsClusterTest" name := "dnsClusterTest"
tests := []struct { tests := []struct {
withSSL []*net.SRV withSSL []*net.SRV
withoutSSL []*net.SRV withoutSSL []*net.SRV
urls []string urls []string
expected string dns map[string]string
expected string
}{ }{
{ {
[]*net.SRV{}, []*net.SRV{},
[]*net.SRV{}, []*net.SRV{},
nil, nil,
nil,
"", "",
}, },
{ {
@ -46,6 +53,8 @@ func TestSRVGetCluster(t *testing.T) {
}, },
[]*net.SRV{}, []*net.SRV{},
nil, nil,
nil,
"0=https://10.0.0.1:2480,1=https://10.0.0.2:2480,2=https://10.0.0.3:2480", "0=https://10.0.0.1:2480,1=https://10.0.0.2:2480,2=https://10.0.0.3:2480",
}, },
{ {
@ -58,6 +67,7 @@ func TestSRVGetCluster(t *testing.T) {
&net.SRV{Target: "10.0.0.1", Port: 7001}, &net.SRV{Target: "10.0.0.1", Port: 7001},
}, },
nil, nil,
nil,
"0=https://10.0.0.1:2480,1=https://10.0.0.2:2480,2=https://10.0.0.3:2480,3=http://10.0.0.1:7001", "0=https://10.0.0.1:2480,1=https://10.0.0.2:2480,2=https://10.0.0.3:2480,3=http://10.0.0.1:7001",
}, },
{ {
@ -70,8 +80,22 @@ func TestSRVGetCluster(t *testing.T) {
&net.SRV{Target: "10.0.0.1", Port: 7001}, &net.SRV{Target: "10.0.0.1", Port: 7001},
}, },
[]string{"https://10.0.0.1:2480"}, []string{"https://10.0.0.1:2480"},
nil,
"dnsClusterTest=https://10.0.0.1:2480,0=https://10.0.0.2:2480,1=https://10.0.0.3:2480,2=http://10.0.0.1:7001", "dnsClusterTest=https://10.0.0.1:2480,0=https://10.0.0.2:2480,1=https://10.0.0.3:2480,2=http://10.0.0.1:7001",
}, },
// matching local member with resolved addr and return unresolved hostnames
{
[]*net.SRV{
&net.SRV{Target: "1.example.com.", Port: 2480},
&net.SRV{Target: "2.example.com.", Port: 2480},
&net.SRV{Target: "3.example.com.", Port: 2480},
},
nil,
[]string{"https://10.0.0.1:2480"},
map[string]string{"1.example.com:2480": "10.0.0.1:2480", "2.example.com:2480": "10.0.0.2:2480", "3.example.com:2480": "10.0.0.3:2480"},
"dnsClusterTest=https://1.example.com:2480,0=https://2.example.com:2480,1=https://3.example.com:2480",
},
} }
for i, tt := range tests { for i, tt := range tests {
@ -84,6 +108,12 @@ func TestSRVGetCluster(t *testing.T) {
} }
return "", nil, errors.New("Unkown service in mock") return "", nil, errors.New("Unkown service in mock")
} }
resolveTCPAddr = func(network, addr string) (*net.TCPAddr, error) {
if tt.dns == nil || tt.dns[addr] == "" {
return net.ResolveTCPAddr(network, addr)
}
return net.ResolveTCPAddr(network, tt.dns[addr])
}
urls := testutil.MustNewURLs(t, tt.urls) urls := testutil.MustNewURLs(t, tt.urls)
str, token, err := SRVGetCluster(name, "example.com", "token", urls) str, token, err := SRVGetCluster(name, "example.com", "token", urls)
if err != nil { if err != nil {