diff --git a/pkg/netutil/netutil.go b/pkg/netutil/netutil.go index bb5f392b3..5e38dc98d 100644 --- a/pkg/netutil/netutil.go +++ b/pkg/netutil/netutil.go @@ -16,14 +16,13 @@ package netutil import ( + "context" "net" "net/url" "reflect" "sort" "time" - "golang.org/x/net/context" - "github.com/coreos/etcd/pkg/types" "github.com/coreos/pkg/capnslog" ) @@ -32,11 +31,38 @@ var ( plog = capnslog.NewPackageLogger("github.com/coreos/etcd", "pkg/netutil") // indirection for testing - resolveTCPAddr = net.ResolveTCPAddr + resolveTCPAddr = resolveTCPAddrDefault ) const retryInterval = time.Second +// taken from go's ResolveTCP code but uses configurable ctx +func resolveTCPAddrDefault(ctx context.Context, addr string) (*net.TCPAddr, error) { + host, port, serr := net.SplitHostPort(addr) + if serr != nil { + return nil, serr + } + portnum, perr := net.DefaultResolver.LookupPort(ctx, "tcp", port) + if perr != nil { + return nil, perr + } + + var ips []net.IPAddr + if ip := net.ParseIP(host); ip != nil { + ips = []net.IPAddr{{IP: ip}} + } else { + // Try as a DNS name. + ipss, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil { + return nil, err + } + ips = ipss + } + // randomize? + ip := ips[0] + return &net.TCPAddr{IP: ip.IP, Port: portnum, Zone: ip.Zone}, nil +} + // resolveTCPAddrs is a convenience wrapper for net.ResolveTCPAddr. // resolveTCPAddrs return a new set of url.URLs, in which all DNS hostnames // are resolved. @@ -75,7 +101,7 @@ func resolveURL(ctx context.Context, u url.URL) (string, error) { if host == "localhost" || net.ParseIP(host) != nil { return "", nil } - tcpAddr, err := resolveTCPAddr("tcp", u.Host) + tcpAddr, err := resolveTCPAddr(ctx, u.Host) if err == nil { plog.Infof("resolving %s to %s", u.Host, tcpAddr.String()) return tcpAddr.String(), nil diff --git a/pkg/netutil/netutil_test.go b/pkg/netutil/netutil_test.go index c8d9f7994..82abe6d12 100644 --- a/pkg/netutil/netutil_test.go +++ b/pkg/netutil/netutil_test.go @@ -15,6 +15,7 @@ package netutil import ( + "context" "errors" "net" "net/url" @@ -22,12 +23,10 @@ import ( "strconv" "testing" "time" - - "golang.org/x/net/context" ) func TestResolveTCPAddrs(t *testing.T) { - defer func() { resolveTCPAddr = net.ResolveTCPAddr }() + defer func() { resolveTCPAddr = resolveTCPAddrDefault }() tests := []struct { urls [][]url.URL expected [][]url.URL @@ -113,7 +112,7 @@ func TestResolveTCPAddrs(t *testing.T) { }, } for _, tt := range tests { - resolveTCPAddr = func(network, addr string) (*net.TCPAddr, error) { + resolveTCPAddr = func(ctx context.Context, addr string) (*net.TCPAddr, error) { host, port, err := net.SplitHostPort(addr) if err != nil { return nil, err @@ -143,13 +142,13 @@ func TestResolveTCPAddrs(t *testing.T) { } func TestURLsEqual(t *testing.T) { - defer func() { resolveTCPAddr = net.ResolveTCPAddr }() + defer func() { resolveTCPAddr = resolveTCPAddrDefault }() hostm := map[string]string{ "example.com": "10.0.10.1", "first.com": "10.0.11.1", "second.com": "10.0.11.2", } - resolveTCPAddr = func(network, addr string) (*net.TCPAddr, error) { + resolveTCPAddr = func(ctx context.Context, addr string) (*net.TCPAddr, error) { host, port, herr := net.SplitHostPort(addr) if herr != nil { return nil, herr