diff --git a/client/http.go b/client/http.go index 7dfecc8c2..73649ef59 100644 --- a/client/http.go +++ b/client/http.go @@ -40,6 +40,7 @@ var ( type SyncableHTTPClient interface { HTTPClient Sync(context.Context) error + Endpoints() []string } type HTTPClient interface { @@ -65,7 +66,8 @@ func NewHTTPClient(tr CancelableTransport, eps []string) (SyncableHTTPClient, er func newHTTPClusterClient(tr CancelableTransport, eps []string) (*httpClusterClient, error) { c := httpClusterClient{ transport: tr, - endpoints: make([]HTTPClient, len(eps)), + endpoints: eps, + clients: make([]HTTPClient, len(eps)), } for i, ep := range eps { @@ -74,7 +76,7 @@ func newHTTPClusterClient(tr CancelableTransport, eps []string) (*httpClusterCli return nil, err } - c.endpoints[i] = &redirectFollowingHTTPClient{ + c.clients[i] = &redirectFollowingHTTPClient{ max: DefaultMaxRedirects, client: &httpClient{ transport: tr, @@ -88,14 +90,15 @@ func newHTTPClusterClient(tr CancelableTransport, eps []string) (*httpClusterCli type httpClusterClient struct { transport CancelableTransport - endpoints []HTTPClient + endpoints []string + clients []HTTPClient } func (c *httpClusterClient) Do(ctx context.Context, act HTTPAction) (resp *http.Response, body []byte, err error) { - if len(c.endpoints) == 0 { + if len(c.clients) == 0 { return nil, nil, ErrNoEndpoints } - for _, hc := range c.endpoints { + for _, hc := range c.clients { resp, body, err = hc.Do(ctx, act) if err != nil { if err == ErrTimeout || err == ErrCanceled { @@ -111,6 +114,10 @@ func (c *httpClusterClient) Do(ctx context.Context, act HTTPAction) (resp *http. return } +func (c *httpClusterClient) Endpoints() []string { + return c.endpoints +} + func (c *httpClusterClient) Sync(ctx context.Context) error { mAPI := NewMembersAPI(c) ms, err := mAPI.List(ctx) diff --git a/client/http_test.go b/client/http_test.go index c3cc1c9d5..a1ebca4c1 100644 --- a/client/http_test.go +++ b/client/http_test.go @@ -193,7 +193,7 @@ func TestHTTPClusterClientDo(t *testing.T) { // first good response short-circuits Do { client: &httpClusterClient{ - endpoints: []HTTPClient{ + clients: []HTTPClient{ &staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}}, &staticHTTPClient{err: fakeErr}, }, @@ -204,7 +204,7 @@ func TestHTTPClusterClientDo(t *testing.T) { // fall through to good endpoint if err is arbitrary { client: &httpClusterClient{ - endpoints: []HTTPClient{ + clients: []HTTPClient{ &staticHTTPClient{err: fakeErr}, &staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}}, }, @@ -215,7 +215,7 @@ func TestHTTPClusterClientDo(t *testing.T) { // ErrTimeout short-circuits Do { client: &httpClusterClient{ - endpoints: []HTTPClient{ + clients: []HTTPClient{ &staticHTTPClient{err: ErrTimeout}, &staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}}, }, @@ -226,7 +226,7 @@ func TestHTTPClusterClientDo(t *testing.T) { // ErrCanceled short-circuits Do { client: &httpClusterClient{ - endpoints: []HTTPClient{ + clients: []HTTPClient{ &staticHTTPClient{err: ErrCanceled}, &staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}}, }, @@ -237,7 +237,7 @@ func TestHTTPClusterClientDo(t *testing.T) { // return err if there are no endpoints { client: &httpClusterClient{ - endpoints: []HTTPClient{}, + clients: []HTTPClient{}, }, wantErr: ErrNoEndpoints, }, @@ -245,7 +245,7 @@ func TestHTTPClusterClientDo(t *testing.T) { // return err if all endpoints return arbitrary errors { client: &httpClusterClient{ - endpoints: []HTTPClient{ + clients: []HTTPClient{ &staticHTTPClient{err: fakeErr}, &staticHTTPClient{err: fakeErr}, }, @@ -256,7 +256,7 @@ func TestHTTPClusterClientDo(t *testing.T) { // 500-level errors cause Do to fallthrough to next endpoint { client: &httpClusterClient{ - endpoints: []HTTPClient{ + clients: []HTTPClient{ &staticHTTPClient{resp: http.Response{StatusCode: http.StatusBadGateway}}, &staticHTTPClient{resp: http.Response{StatusCode: http.StatusTeapot}}, }, diff --git a/etcdctl/command/handle.go b/etcdctl/command/handle.go index 54f2aaf59..20eb9086d 100644 --- a/etcdctl/command/handle.go +++ b/etcdctl/command/handle.go @@ -20,7 +20,6 @@ import ( "encoding/json" "errors" "fmt" - "net/url" "os" "strings" @@ -40,72 +39,35 @@ func dumpCURL(client *etcd.Client) { } } -// createHttpPath attaches http scheme to the given address if needed -func createHttpPath(addr string) (string, error) { - u, err := url.Parse(addr) - if err != nil { - return "", err - } - - if u.Scheme == "" { - u.Scheme = "http" - } - return u.String(), nil -} - -func getPeersFlagValue(c *cli.Context) []string { - peerstr := c.GlobalString("peers") - - // Use an environment variable if nothing was supplied on the - // command line - if peerstr == "" { - peerstr = os.Getenv("ETCDCTL_PEERS") - } - - // If we still don't have peers, use a default - if peerstr == "" { - peerstr = "127.0.0.1:4001" - } - - return strings.Split(peerstr, ",") -} - // rawhandle wraps the command function handlers and sets up the // environment but performs no output formatting. func rawhandle(c *cli.Context, fn handlerFunc) (*etcd.Response, error) { - sync := !c.GlobalBool("no-sync") - - peers := getPeersFlagValue(c) - - // If no sync, create http path for each peer address - if !sync { - revisedPeers := make([]string, 0) - for _, peer := range peers { - if revisedPeer, err := createHttpPath(peer); err != nil { - fmt.Fprintf(os.Stderr, "Unsupported url %v: %v\n", peer, err) - } else { - revisedPeers = append(revisedPeers, revisedPeer) - } - } - peers = revisedPeers + endpoints, err := getEndpoints(c) + if err != nil { + return nil, err } - client := etcd.NewClient(peers) + tr, err := getTransport(c) + if err != nil { + return nil, err + } + + client := etcd.NewClient(endpoints) + client.SetTransport(tr) if c.GlobalBool("debug") { go dumpCURL(client) } // Sync cluster. - if sync { + if !c.GlobalBool("no-sync") { if ok := client.SyncCluster(); !ok { - handleError(FailedToConnectToHost, errors.New("Cannot sync with the cluster using peers "+strings.Join(peers, ", "))) + handleError(FailedToConnectToHost, errors.New("cannot sync with the cluster using endpoints "+strings.Join(endpoints, ", "))) } } if c.GlobalBool("debug") { - fmt.Fprintf(os.Stderr, "Cluster-Peers: %s\n", - strings.Join(client.GetCluster(), " ")) + fmt.Fprintf(os.Stderr, "Cluster-Endpoints: %s\n", strings.Join(client.GetCluster(), ", ")) } // Execute handler function. diff --git a/etcdctl/command/member_commands.go b/etcdctl/command/member_commands.go index f79549a4f..7432b2b65 100644 --- a/etcdctl/command/member_commands.go +++ b/etcdctl/command/member_commands.go @@ -18,7 +18,6 @@ package command import ( "fmt" - "net/http" "os" "strings" @@ -52,14 +51,19 @@ func NewMemberCommand() cli.Command { } func mustNewMembersAPI(c *cli.Context) client.MembersAPI { - peers := getPeersFlagValue(c) - for i, p := range peers { - if !strings.HasPrefix(p, "http") && !strings.HasPrefix(p, "https") { - peers[i] = fmt.Sprintf("http://%s", p) - } + eps, err := getEndpoints(c) + if err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + os.Exit(1) } - hc, err := client.NewHTTPClient(&http.Transport{}, peers) + tr, err := getTransport(c) + if err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + os.Exit(1) + } + + hc, err := client.NewHTTPClient(tr, eps) if err != nil { fmt.Fprintln(os.Stderr, err.Error()) os.Exit(1) @@ -75,6 +79,10 @@ func mustNewMembersAPI(c *cli.Context) client.MembersAPI { } } + if c.GlobalBool("debug") { + fmt.Fprintf(os.Stderr, "Cluster-Endpoints: %s\n", strings.Join(hc.Endpoints(), ", ")) + } + return client.NewMembersAPI(hc) } diff --git a/etcdctl/command/util.go b/etcdctl/command/util.go index 7dd29227c..14ca7320e 100644 --- a/etcdctl/command/util.go +++ b/etcdctl/command/util.go @@ -20,7 +20,13 @@ import ( "errors" "io" "io/ioutil" + "net/http" + "net/url" + "os" "strings" + + "github.com/coreos/etcd/Godeps/_workspace/src/github.com/codegangsta/cli" + "github.com/coreos/etcd/pkg/transport" ) var ( @@ -49,3 +55,47 @@ func argOrStdin(args []string, stdin io.Reader, i int) (string, error) { } return string(bytes), nil } + +func getPeersFlagValue(c *cli.Context) []string { + peerstr := c.GlobalString("peers") + + // Use an environment variable if nothing was supplied on the + // command line + if peerstr == "" { + peerstr = os.Getenv("ETCDCTL_PEERS") + } + + // If we still don't have peers, use a default + if peerstr == "" { + peerstr = "127.0.0.1:4001" + } + + return strings.Split(peerstr, ",") +} + +func getEndpoints(c *cli.Context) ([]string, error) { + eps := getPeersFlagValue(c) + for i, ep := range eps { + u, err := url.Parse(ep) + if err != nil { + return nil, err + } + + if u.Scheme == "" { + u.Scheme = "http" + } + + eps[i] = u.String() + } + return eps, nil +} + +func getTransport(c *cli.Context) (*http.Transport, error) { + tls := transport.TLSInfo{ + CAFile: c.GlobalString("ca-file"), + CertFile: c.GlobalString("cert-file"), + KeyFile: c.GlobalString("key-file"), + } + return transport.NewTransport(tls) + +} diff --git a/etcdctl/main.go b/etcdctl/main.go index e2b67625b..aeb2d54d1 100644 --- a/etcdctl/main.go +++ b/etcdctl/main.go @@ -35,6 +35,9 @@ func main() { cli.BoolFlag{Name: "no-sync", Usage: "don't synchronize cluster information before sending request"}, cli.StringFlag{Name: "output, o", Value: "simple", Usage: "output response in the given format (`simple` or `json`)"}, cli.StringFlag{Name: "peers, C", Value: "", Usage: "a comma-delimited list of machine addresses in the cluster (default: \"127.0.0.1:4001\")"}, + cli.StringFlag{Name: "cert-file", Value: "", Usage: "identify HTTPS client using this SSL certificate file"}, + cli.StringFlag{Name: "key-file", Value: "", Usage: "identify HTTPS client using this SSL key file"}, + cli.StringFlag{Name: "ca-file", Value: "", Usage: "verify certificates of HTTPS-enabled servers using this CA bundle"}, } app.Commands = []cli.Command{ command.NewMakeCommand(), diff --git a/pkg/transport/listener.go b/pkg/transport/listener.go index 8a23abab6..b4172dde7 100644 --- a/pkg/transport/listener.go +++ b/pkg/transport/listener.go @@ -46,6 +46,11 @@ func NewListener(addr string, info TLSInfo) (net.Listener, error) { } func NewTransport(info TLSInfo) (*http.Transport, error) { + cfg, err := info.ClientConfig() + if err != nil { + return nil, err + } + t := &http.Transport{ // timeouts taken from http.DefaultTransport Dial: (&net.Dialer{ @@ -53,14 +58,7 @@ func NewTransport(info TLSInfo) (*http.Transport, error) { KeepAlive: 30 * time.Second, }).Dial, TLSHandshakeTimeout: 10 * time.Second, - } - - if !info.Empty() { - tlsCfg, err := info.ClientConfig() - if err != nil { - return nil, err - } - t.TLSClientConfig = tlsCfg + TLSClientConfig: cfg, } return t, nil @@ -134,22 +132,24 @@ func (info TLSInfo) ServerConfig() (*tls.Config, error) { } // ClientConfig generates a tls.Config object for use by an HTTP client -func (info TLSInfo) ClientConfig() (*tls.Config, error) { - cfg, err := info.baseConfig() - if err != nil { - return nil, err - } - - if info.CAFile != "" { - cp, err := newCertPool(info.CAFile) +func (info TLSInfo) ClientConfig() (cfg *tls.Config, err error) { + if !info.Empty() { + cfg, err = info.baseConfig() if err != nil { return nil, err } - - cfg.RootCAs = cp + } else { + cfg = &tls.Config{} } - return cfg, nil + if info.CAFile != "" { + cfg.RootCAs, err = newCertPool(info.CAFile) + if err != nil { + return + } + } + + return } // newCertPool creates x509 certPool with provided CA file diff --git a/pkg/transport/listener_test.go b/pkg/transport/listener_test.go index 9745b0900..8d18460b1 100644 --- a/pkg/transport/listener_test.go +++ b/pkg/transport/listener_test.go @@ -51,41 +51,31 @@ func TestNewTransportTLSInfo(t *testing.T) { } defer os.Remove(tmp) - tests := []struct { - info TLSInfo - wantTLSClientConfig bool - }{ - { - info: TLSInfo{}, - wantTLSClientConfig: false, + tests := []TLSInfo{ + TLSInfo{}, + TLSInfo{ + CertFile: tmp, + KeyFile: tmp, }, - { - info: TLSInfo{ - CertFile: tmp, - KeyFile: tmp, - }, - wantTLSClientConfig: true, + TLSInfo{ + CertFile: tmp, + KeyFile: tmp, + CAFile: tmp, }, - { - info: TLSInfo{ - CertFile: tmp, - KeyFile: tmp, - CAFile: tmp, - }, - wantTLSClientConfig: true, + TLSInfo{ + CAFile: tmp, }, } for i, tt := range tests { - tt.info.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil) - trans, err := NewTransport(tt.info) + tt.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil) + trans, err := NewTransport(tt) if err != nil { t.Fatalf("Received unexpected error from NewTransport: %v", err) } - gotTLSClientConfig := trans.TLSClientConfig != nil - if tt.wantTLSClientConfig != gotTLSClientConfig { - t.Fatalf("#%d: wantTLSClientConfig=%t but gotTLSClientConfig=%t", i, tt.wantTLSClientConfig, gotTLSClientConfig) + if trans.TLSClientConfig == nil { + t.Fatalf("#%d: want non-nil TLSClientConfig", i) } } } @@ -121,8 +111,6 @@ func TestTLSInfoMissingFields(t *testing.T) { defer os.Remove(tmp) tests := []TLSInfo{ - TLSInfo{}, - TLSInfo{CAFile: tmp}, TLSInfo{CertFile: tmp}, TLSInfo{KeyFile: tmp}, TLSInfo{CertFile: tmp, CAFile: tmp},