diff --git a/pkg/flag.go b/pkg/flag.go index b09374268..3eb66bef5 100644 --- a/pkg/flag.go +++ b/pkg/flag.go @@ -4,8 +4,12 @@ import ( "flag" "fmt" "log" + "net/url" "os" "strings" + + "github.com/coreos/etcd/pkg/flags" + "github.com/coreos/etcd/pkg/transport" ) type DeprecatedFlag struct { @@ -64,3 +68,38 @@ func SetFlagsFromEnv(fs *flag.FlagSet) { } }) } + +// URLsFromFlags decides what URLs should be using two different flags +// as datasources. The first flag's Value must be of type URLs, while +// the second must be of type IPAddressPort. If both of these flags +// are set, an error will be returned. If only the first flag is set, +// the underlying url.URL objects will be returned unmodified. If the +// second flag happens to be set, the underlying IPAddressPort will be +// converted to a url.URL and returned. The Scheme of the returned +// url.URL will be http unless the provided TLSInfo object is non-empty. +// If neither of the flags have been explicitly set, the default value +// of the first flag will be returned unmodified. +func URLsFromFlags(fs *flag.FlagSet, urlsFlagName string, addrFlagName string, tlsInfo transport.TLSInfo) ([]url.URL, error) { + visited := make(map[string]struct{}) + fs.Visit(func(f *flag.Flag) { + visited[f.Name] = struct{}{} + }) + + _, urlsFlagIsSet := visited[urlsFlagName] + _, addrFlagIsSet := visited[addrFlagName] + + if addrFlagIsSet { + if urlsFlagIsSet { + return nil, fmt.Errorf("Set only one of flags -%s and -%s", urlsFlagName, addrFlagName) + } + + addr := *fs.Lookup(addrFlagName).Value.(*flags.IPAddressPort) + addrURL := url.URL{Scheme: "http", Host: addr.String()} + if !tlsInfo.Empty() { + addrURL.Scheme = "https" + } + return []url.URL{addrURL}, nil + } + + return []url.URL(*fs.Lookup(urlsFlagName).Value.(*flags.URLs)), nil +} diff --git a/pkg/flag_test.go b/pkg/flag_test.go index 0ae424e6d..32fd84a67 100644 --- a/pkg/flag_test.go +++ b/pkg/flag_test.go @@ -2,8 +2,13 @@ package pkg import ( "flag" + "net/url" "os" + "reflect" "testing" + + "github.com/coreos/etcd/pkg/flags" + "github.com/coreos/etcd/pkg/transport" ) func TestSetFlagsFromEnv(t *testing.T) { @@ -49,3 +54,85 @@ func TestSetFlagsFromEnv(t *testing.T) { } } } + +func TestURLsFromFlags(t *testing.T) { + tests := []struct { + args []string + tlsInfo transport.TLSInfo + wantURLs []url.URL + wantFail bool + }{ + // use -urls default when no flags defined + { + args: []string{}, + tlsInfo: transport.TLSInfo{}, + wantURLs: []url.URL{ + url.URL{Scheme: "http", Host: "127.0.0.1:2379"}, + }, + wantFail: false, + }, + + // explicitly setting -urls should carry through + { + args: []string{"-urls=https://192.0.3.17:2930,http://127.0.0.1:1024"}, + tlsInfo: transport.TLSInfo{}, + wantURLs: []url.URL{ + url.URL{Scheme: "https", Host: "192.0.3.17:2930"}, + url.URL{Scheme: "http", Host: "127.0.0.1:1024"}, + }, + wantFail: false, + }, + + // explicitly setting -addr should carry through + { + args: []string{"-addr=192.0.2.3:1024"}, + tlsInfo: transport.TLSInfo{}, + wantURLs: []url.URL{ + url.URL{Scheme: "http", Host: "192.0.2.3:1024"}, + }, + wantFail: false, + }, + + // scheme prepended to -addr should be https if TLSInfo non-empty + { + args: []string{"-addr=192.0.2.3:1024"}, + tlsInfo: transport.TLSInfo{ + CertFile: "/tmp/foo", + KeyFile: "/tmp/bar", + }, + wantURLs: []url.URL{ + url.URL{Scheme: "https", Host: "192.0.2.3:1024"}, + }, + wantFail: false, + }, + + // explicitly setting both -urls and -addr should fail + { + args: []string{"-urls=https://127.0.0.1:1024", "-addr=192.0.2.3:1024"}, + tlsInfo: transport.TLSInfo{}, + wantURLs: nil, + wantFail: true, + }, + } + + for i, tt := range tests { + fs := flag.NewFlagSet("test", flag.PanicOnError) + fs.Var(flags.NewURLs("http://127.0.0.1:2379"), "urls", "") + fs.Var(&flags.IPAddressPort{}, "addr", "") + + if err := fs.Parse(tt.args); err != nil { + t.Errorf("#%d: failed to parse flags: %v", i, err) + continue + } + + gotURLs, err := URLsFromFlags(fs, "urls", "addr", tt.tlsInfo) + if tt.wantFail != (err != nil) { + t.Errorf("#%d: wantFail=%t, got err=%v", i, tt.wantFail, err) + continue + } + + if !reflect.DeepEqual(tt.wantURLs, gotURLs) { + t.Errorf("#%d: incorrect URLs\nwant=%#v\ngot=%#v", i, tt.wantURLs, gotURLs) + } + } +}