diff --git a/pkg/flags/int_test.go b/pkg/flags/int_test.go deleted file mode 100644 index 2ebabd01f..000000000 --- a/pkg/flags/int_test.go +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2018 The etcd Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package flags - -import ( - "reflect" - "testing" -) - -func TestInvalidUint32(t *testing.T) { - tests := []string{ - // string - "invalid", - // negative number - "-1", - // float number - "0.1", - "-0.2", - // larger than math.MaxUint32 - "4294967296", - } - for i, in := range tests { - var u uint32Value - if err := u.Set(in); err == nil { - t.Errorf(`#%d: unexpected nil error for in=%q`, i, in) - } - } -} - -func TestUint32Value(t *testing.T) { - tests := []struct { - s string - exp uint32 - }{ - {s: "0", exp: 0}, - {s: "1", exp: 1}, - {s: "", exp: 0}, - } - for i := range tests { - ss := uint32(*NewUint32Value(tests[i].s)) - if !reflect.DeepEqual(tests[i].exp, ss) { - t.Fatalf("#%d: expected %q, got %q", i, tests[i].exp, ss) - } - } -} diff --git a/pkg/flags/int.go b/pkg/flags/uint32.go similarity index 76% rename from pkg/flags/int.go rename to pkg/flags/uint32.go index c3d763da4..bbef7df6a 100644 --- a/pkg/flags/int.go +++ b/pkg/flags/uint32.go @@ -1,4 +1,4 @@ -// Copyright 2018 The etcd Authors +// Copyright 2022 The etcd Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,31 +16,26 @@ package flags import ( "flag" - "fmt" "strconv" ) type uint32Value uint32 -func NewUint32Value(s string) *uint32Value { - var u uint32Value - if s == "" || s == "0" { - return &u - } - if err := u.Set(s); err != nil { - panic(fmt.Sprintf("new uint32Value should never fail: %v", err)) - } - return &u +// NewUint32Value creates a uint32Value instance with the default value `v`. +func NewUint32Value(v uint32) *uint32Value { + val := new(uint32Value) + *val = uint32Value(v) + return val } +// Set parses a command line uint32 value. +// Implements "flag.Value" interface. func (i *uint32Value) Set(s string) error { v, err := strconv.ParseUint(s, 0, 32) *i = uint32Value(v) return err } -func (i *uint32Value) Type() string { - return "uint32" -} + func (i *uint32Value) String() string { return strconv.FormatUint(uint64(*i), 10) } // Uint32FromFlag return the uint32 value of a flag with the given name diff --git a/pkg/flags/uint32_test.go b/pkg/flags/uint32_test.go new file mode 100644 index 000000000..6e7d38df2 --- /dev/null +++ b/pkg/flags/uint32_test.go @@ -0,0 +1,68 @@ +// Copyright 2022 The etcd Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package flags + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestUint32Value(t *testing.T) { + cases := []struct { + name string + s string + expectedVal uint32 + expectError bool + }{ + { + name: "normal uint32 value", + s: "200", + expectedVal: 200, + }, + { + name: "zero value", + s: "0", + expectedVal: 0, + }, + { + name: "negative int value", + s: "-200", + expectError: true, + }, + { + name: "invalid integer value", + s: "invalid", + expectError: true, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + var val uint32Value + err := val.Set(tc.s) + + if tc.expectError { + if err == nil { + t.Errorf("Expected failure on parsing uint32 value from %s", tc.s) + } + } else { + if err != nil { + t.Errorf("Unexpected error when parsing %s: %v", tc.s, err) + } + assert.Equal(t, uint32(val), tc.expectedVal) + } + }) + } +}