diff --git a/discovery/discovery.go b/discovery/discovery.go index 3dac28b92..d53c46114 100644 --- a/discovery/discovery.go +++ b/discovery/discovery.go @@ -20,7 +20,6 @@ import ( "errors" "fmt" "math" - "net" "net/http" "net/url" "path" @@ -30,6 +29,7 @@ import ( "time" "github.com/coreos/etcd/client" + "github.com/coreos/etcd/pkg/transport" "github.com/coreos/etcd/pkg/types" "github.com/coreos/pkg/capnslog" "github.com/jonboulle/clockwork" @@ -124,16 +124,15 @@ func newDiscovery(durl, dproxyurl string, id types.ID) (*discovery, error) { if err != nil { return nil, err } + + // TODO: add ResponseHeaderTimeout back when watch on discovery service writes header early + tr, err := transport.NewTransport(transport.TLSInfo{}, 30*time.Second) + if err != nil { + return nil, err + } + tr.Proxy = pf cfg := client.Config{ - Transport: &http.Transport{ - Proxy: pf, - Dial: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).Dial, - TLSHandshakeTimeout: 10 * time.Second, - // TODO: add ResponseHeaderTimeout back when watch on discovery service writes header early - }, + Transport: tr, Endpoints: []string{u.String()}, } c, err := client.New(cfg) diff --git a/integration/bridge.go b/integration/bridge.go index 4c596dc90..e46397ca8 100644 --- a/integration/bridge.go +++ b/integration/bridge.go @@ -18,8 +18,9 @@ import ( "fmt" "io" "net" - "os" "sync" + + "github.com/coreos/etcd/pkg/transport" ) // bridge creates a unix socket bridge to another unix socket, making it possible @@ -43,10 +44,7 @@ func newBridge(addr string) (*bridge, error) { conns: make(map[*bridgeConn]struct{}), stopc: make(chan struct{}, 1), } - if err := os.RemoveAll(b.inaddr); err != nil { - return nil, err - } - l, err := net.Listen("unix", b.inaddr) + l, err := transport.NewUnixListener(b.inaddr) if err != nil { return nil, fmt.Errorf("listen failed on socket %s (%v)", addr, err) } @@ -79,7 +77,6 @@ func (b *bridge) Reset() { func (b *bridge) serveListen() { defer func() { b.l.Close() - os.RemoveAll(b.inaddr) b.mu.Lock() for bc := range b.conns { bc.Close() diff --git a/integration/cluster.go b/integration/cluster.go index 17f5e5dca..8f15f42cb 100644 --- a/integration/cluster.go +++ b/integration/cluster.go @@ -25,7 +25,6 @@ import ( "os" "reflect" "sort" - "strconv" "strings" "sync" "sync/atomic" @@ -53,14 +52,18 @@ const ( tickDuration = 10 * time.Millisecond clusterName = "etcd" requestTimeout = 20 * time.Second + + basePort = 21000 + urlScheme = "unix" + urlSchemeTLS = "unixs" ) var ( electionTicks = 10 - // integration test uses well-known ports to listen for each running member, - // which ensures restarted member could listen on specific port again. - nextListenPort int64 = 21000 + // integration test uses unique ports, counting up, to listen for each + // member, ensuring restarted members can listen on the same port again. + localListenCount int64 = 0 testTLSInfo = transport.TLSInfo{ KeyFile: "./fixtures/server.key.insecure", @@ -91,6 +94,13 @@ func init() { api.EnableCapability(api.V3rpcCapability) } +func schemeFromTLSInfo(tls *transport.TLSInfo) string { + if tls == nil { + return urlScheme + } + return urlSchemeTLS +} + func (c *cluster) fillClusterForMembers() error { if c.cfg.DiscoveryURL != "" { // cluster will be discovered @@ -99,10 +109,7 @@ func (c *cluster) fillClusterForMembers() error { addrs := make([]string, 0) for _, m := range c.Members { - scheme := "http" - if m.PeerTLSInfo != nil { - scheme = "https" - } + scheme := schemeFromTLSInfo(m.PeerTLSInfo) for _, l := range m.PeerListeners { addrs = append(addrs, fmt.Sprintf("%s=%s://%s", m.Name, scheme, l.Addr().String())) } @@ -186,13 +193,8 @@ func (c *cluster) URLs() []string { func (c *cluster) HTTPMembers() []client.Member { ms := []client.Member{} for _, m := range c.Members { - pScheme, cScheme := "http", "http" - if m.PeerTLSInfo != nil { - pScheme = "https" - } - if m.ClientTLSInfo != nil { - cScheme = "https" - } + pScheme := schemeFromTLSInfo(m.PeerTLSInfo) + cScheme := schemeFromTLSInfo(m.ClientTLSInfo) cm := client.Member{Name: m.Name} for _, ln := range m.PeerListeners { cm.PeerURLs = append(cm.PeerURLs, pScheme+"://"+ln.Addr().String()) @@ -225,10 +227,7 @@ func (c *cluster) mustNewMember(t *testing.T) *member { func (c *cluster) addMember(t *testing.T) { m := c.mustNewMember(t) - scheme := "http" - if c.cfg.PeerTLS != nil { - scheme = "https" - } + scheme := schemeFromTLSInfo(c.cfg.PeerTLS) // send add request to the cluster var err error @@ -390,26 +389,13 @@ func isMembersEqual(membs []client.Member, wmembs []client.Member) bool { } func newLocalListener(t *testing.T) net.Listener { - port := atomic.AddInt64(&nextListenPort, 1) - l, err := net.Listen("tcp", "127.0.0.1:"+strconv.FormatInt(port, 10)) - if err != nil { - t.Fatal(err) - } - return l + c := atomic.AddInt64(&localListenCount, 1) + addr := fmt.Sprintf("127.0.0.1:%d.%d.sock", c+basePort, os.Getpid()) + return newListenerWithAddr(t, addr) } func newListenerWithAddr(t *testing.T, addr string) net.Listener { - var err error - var l net.Listener - // TODO: we want to reuse a previous closed port immediately. - // a better way is to set SO_REUSExx instead of doing retry. - for i := 0; i < 5; i++ { - l, err = net.Listen("tcp", addr) - if err == nil { - break - } - time.Sleep(500 * time.Millisecond) - } + l, err := transport.NewUnixListener(addr) if err != nil { t.Fatal(err) } @@ -449,13 +435,8 @@ func mustNewMember(t *testing.T, mcfg memberConfig) *member { var err error m := &member{} - peerScheme, clientScheme := "http", "http" - if mcfg.peerTLS != nil { - peerScheme = "https" - } - if mcfg.clientTLS != nil { - clientScheme = "https" - } + peerScheme := schemeFromTLSInfo(mcfg.peerTLS) + clientScheme := schemeFromTLSInfo(mcfg.clientTLS) pln := newLocalListener(t) m.PeerListeners = []net.Listener{pln} @@ -500,10 +481,7 @@ func mustNewMember(t *testing.T, mcfg memberConfig) *member { func (m *member) listenGRPC() error { // prefix with localhost so cert has right domain m.grpcAddr = "localhost:" + m.Name + ".sock" - if err := os.RemoveAll(m.grpcAddr); err != nil { - return err - } - l, err := net.Listen("unix", m.grpcAddr) + l, err := transport.NewUnixListener(m.grpcAddr) if err != nil { return fmt.Errorf("listen failed on grpc socket %s (%v)", m.grpcAddr, err) } diff --git a/integration/v2_http_kv_test.go b/integration/v2_http_kv_test.go index c50a02857..03e34f29c 100644 --- a/integration/v2_http_kv_test.go +++ b/integration/v2_http_kv_test.go @@ -19,7 +19,6 @@ import ( "fmt" "io" "io/ioutil" - "net" "net/http" "net/url" "reflect" @@ -28,6 +27,7 @@ import ( "time" "github.com/coreos/etcd/pkg/testutil" + "github.com/coreos/etcd/pkg/transport" "github.com/coreos/pkg/capnslog" ) @@ -1038,10 +1038,8 @@ type testHttpClient struct { // Creates a new HTTP client with KeepAlive disabled. func NewTestClient() *testHttpClient { - tr := &http.Transport{ - Dial: (&net.Dialer{Timeout: time.Second}).Dial, - DisableKeepAlives: true, - } + tr, _ := transport.NewTransport(transport.TLSInfo{}, time.Second) + tr.DisableKeepAlives = true return &testHttpClient{&http.Client{Transport: tr}} } diff --git a/pkg/transport/listener.go b/pkg/transport/listener.go index 4e38bf95f..d94757b2c 100644 --- a/pkg/transport/listener.go +++ b/pkg/transport/listener.go @@ -25,7 +25,6 @@ import ( "fmt" "math/big" "net" - "net/http" "os" "path" "strings" @@ -35,19 +34,19 @@ import ( "github.com/coreos/etcd/pkg/tlsutil" ) -func NewListener(addr string, scheme string, tlscfg *tls.Config) (net.Listener, error) { - nettype := "tcp" - if scheme == "unix" { +func NewListener(addr string, scheme string, tlscfg *tls.Config) (l net.Listener, err error) { + if scheme == "unix" || scheme == "unixs" { // unix sockets via unix://laddr - nettype = scheme + l, err = NewUnixListener(addr) + } else { + l, err = net.Listen("tcp", addr) } - l, err := net.Listen(nettype, addr) if err != nil { return nil, err } - if scheme == "https" { + if scheme == "https" || scheme == "unixs" { if tlscfg == nil { return nil, fmt.Errorf("cannot listen on TLS for %s: KeyFile and CertFile are not presented", scheme+"://"+addr) } @@ -58,27 +57,6 @@ func NewListener(addr string, scheme string, tlscfg *tls.Config) (net.Listener, return l, nil } -func NewTransport(info TLSInfo, dialtimeoutd time.Duration) (*http.Transport, error) { - cfg, err := info.ClientConfig() - if err != nil { - return nil, err - } - - t := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - Dial: (&net.Dialer{ - Timeout: dialtimeoutd, - // value taken from http.DefaultTransport - KeepAlive: 30 * time.Second, - }).Dial, - // value taken from http.DefaultTransport - TLSHandshakeTimeout: 10 * time.Second, - TLSClientConfig: cfg, - } - - return t, nil -} - type TLSInfo struct { CertFile string KeyFile string diff --git a/pkg/transport/transport.go b/pkg/transport/transport.go new file mode 100644 index 000000000..ca9ccfd80 --- /dev/null +++ b/pkg/transport/transport.go @@ -0,0 +1,70 @@ +// Copyright 2016 The etcd Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package transport + +import ( + "net" + "net/http" + "strings" + "time" +) + +type unixTransport struct{ *http.Transport } + +func NewTransport(info TLSInfo, dialtimeoutd time.Duration) (*http.Transport, error) { + cfg, err := info.ClientConfig() + if err != nil { + return nil, err + } + + t := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + Dial: (&net.Dialer{ + Timeout: dialtimeoutd, + // value taken from http.DefaultTransport + KeepAlive: 30 * time.Second, + }).Dial, + // value taken from http.DefaultTransport + TLSHandshakeTimeout: 10 * time.Second, + TLSClientConfig: cfg, + } + + dialer := (&net.Dialer{ + Timeout: dialtimeoutd, + KeepAlive: 30 * time.Second, + }) + dial := func(net, addr string) (net.Conn, error) { + return dialer.Dial("unix", addr) + } + + tu := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + Dial: dial, + TLSHandshakeTimeout: 10 * time.Second, + TLSClientConfig: cfg, + } + ut := &unixTransport{tu} + + t.RegisterProtocol("unix", ut) + t.RegisterProtocol("unixs", ut) + + return t, nil +} + +func (urt *unixTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req2 := *req + req2.URL.Scheme = strings.Replace(req.URL.Scheme, "unix", "http", 1) + return urt.Transport.RoundTrip(req) +} diff --git a/pkg/transport/unix_listener.go b/pkg/transport/unix_listener.go new file mode 100644 index 000000000..c126b6f7f --- /dev/null +++ b/pkg/transport/unix_listener.go @@ -0,0 +1,40 @@ +// Copyright 2016 The etcd Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package transport + +import ( + "net" + "os" +) + +type unixListener struct{ net.Listener } + +func NewUnixListener(addr string) (net.Listener, error) { + if err := os.RemoveAll(addr); err != nil { + return nil, err + } + l, err := net.Listen("unix", addr) + if err != nil { + return nil, err + } + return &unixListener{l}, nil +} + +func (ul *unixListener) Close() error { + if err := os.RemoveAll(ul.Addr().String()); err != nil { + return err + } + return ul.Listener.Close() +} diff --git a/pkg/types/urls.go b/pkg/types/urls.go index e532722ac..9e5d03ff6 100644 --- a/pkg/types/urls.go +++ b/pkg/types/urls.go @@ -36,8 +36,8 @@ func NewURLs(strs []string) (URLs, error) { if err != nil { return nil, err } - if u.Scheme != "http" && u.Scheme != "https" { - return nil, fmt.Errorf("URL scheme must be http or https: %s", in) + if u.Scheme != "http" && u.Scheme != "https" && u.Scheme != "unix" && u.Scheme != "unixs" { + return nil, fmt.Errorf("URL scheme must be http, https, unix, or unixs: %s", in) } if _, _, err := net.SplitHostPort(u.Host); err != nil { return nil, fmt.Errorf(`URL address does not have the form "host:port": %s`, in)