From 10629c40e136bad441ccdae423fbdab13742de68 Mon Sep 17 00:00:00 2001 From: Yicheng Qin Date: Wed, 18 Feb 2015 17:25:02 -0800 Subject: [PATCH] migrate/starter: fix flag parsing --- migrate/starter/starter.go | 13 ++++-- migrate/starter/starter_test.go | 70 +++++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 3 deletions(-) create mode 100644 migrate/starter/starter_test.go diff --git a/migrate/starter/starter.go b/migrate/starter/starter.go index 8fb7601a1..151efe2b8 100644 --- a/migrate/starter/starter.go +++ b/migrate/starter/starter.go @@ -353,7 +353,8 @@ func newDefaultClient(tls *TLSInfo) (*http.Client, error) { } type value struct { - s string + isBoolFlag bool + s string } func (v *value) String() string { return v.s } @@ -363,14 +364,20 @@ func (v *value) Set(s string) error { return nil } -func (v *value) IsBoolFlag() bool { return true } +func (v *value) IsBoolFlag() bool { return v.isBoolFlag } + +type boolFlag interface { + flag.Value + IsBoolFlag() bool +} // parseConfig parses out the input config from cmdline arguments and // environment variables. func parseConfig(args []string) (*flag.FlagSet, error) { fs := flag.NewFlagSet("full flagset", flag.ContinueOnError) etcdmain.NewConfig().VisitAll(func(f *flag.Flag) { - fs.Var(&value{}, f.Name, "") + _, isBoolFlag := f.Value.(boolFlag) + fs.Var(&value{isBoolFlag: isBoolFlag}, f.Name, "") }) if err := fs.Parse(args); err != nil { return nil, err diff --git a/migrate/starter/starter_test.go b/migrate/starter/starter_test.go new file mode 100644 index 000000000..10817a530 --- /dev/null +++ b/migrate/starter/starter_test.go @@ -0,0 +1,70 @@ +// Copyright 2015 CoreOS, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package starter + +import ( + "flag" + "reflect" + "testing" +) + +func TestParseConfig(t *testing.T) { + tests := []struct { + args []string + wvals map[string]string + }{ + { + []string{"--name", "etcd", "--data-dir", "dir"}, + map[string]string{ + "name": "etcd", + "data-dir": "dir", + }, + }, + { + []string{"--name=etcd", "--data-dir=dir"}, + map[string]string{ + "name": "etcd", + "data-dir": "dir", + }, + }, + { + []string{"--version", "--name", "etcd"}, + map[string]string{ + "version": "true", + "name": "etcd", + }, + }, + { + []string{"--version=true", "--name", "etcd"}, + map[string]string{ + "version": "true", + "name": "etcd", + }, + }, + } + for i, tt := range tests { + fs, err := parseConfig(tt.args) + if err != nil { + t.Fatalf("#%d: unexpected parseConfig error: %v", i, err) + } + vals := make(map[string]string) + fs.Visit(func(f *flag.Flag) { + vals[f.Name] = f.Value.String() + }) + if !reflect.DeepEqual(vals, tt.wvals) { + t.Errorf("#%d: vals = %+v, want %+v", i, vals, tt.wvals) + } + } +}