diff --git a/client/v3/maintenance.go b/client/v3/maintenance.go index b0ef90db7..388106305 100644 --- a/client/v3/maintenance.go +++ b/client/v3/maintenance.go @@ -119,19 +119,7 @@ func NewMaintenance(c *Client) Maintenance { return nil, nil, fmt.Errorf("failed to dial endpoint %s with maintenance client: %v", endpoint, err) } - //get token with established connection - dctx := c.ctx - cancel := func() {} - if c.cfg.DialTimeout > 0 { - dctx, cancel = context.WithTimeout(c.ctx, c.cfg.DialTimeout) - } - err = c.getToken(dctx) - cancel() - if err != nil { - conn.Close() - return nil, nil, fmt.Errorf("failed to getToken from endpoint %s with maintenance client: %v", endpoint, err) - } - cancel = func() { conn.Close() } + cancel := func() { conn.Close() } return RetryMaintenanceClient(c, conn), cancel, nil }, remote: RetryMaintenanceClient(c, c.conn), diff --git a/etcdctl/ctlv3/command/defrag_command.go b/etcdctl/ctlv3/command/defrag_command.go index 5ebf4483d..dfccafb74 100644 --- a/etcdctl/ctlv3/command/defrag_command.go +++ b/etcdctl/ctlv3/command/defrag_command.go @@ -35,10 +35,11 @@ func NewDefragCommand() *cobra.Command { } func defragCommandFunc(cmd *cobra.Command, args []string) { - failures := 0 - c := mustClientFromCmd(cmd) + cfg := clientConfigFromCmd(cmd) for _, ep := range endpointsFromCluster(cmd) { + cfg.Endpoints = []string{ep} + c := mustClient(cfg) ctx, cancel := commandCtx(cmd) start := time.Now() _, err := c.Defragment(ctx, ep) @@ -50,6 +51,7 @@ func defragCommandFunc(cmd *cobra.Command, args []string) { } else { fmt.Printf("Finished defragmenting etcd member[%s]. took %s\n", ep, d.String()) } + c.Close() } if failures != 0 { diff --git a/etcdctl/ctlv3/command/ep_command.go b/etcdctl/ctlv3/command/ep_command.go index 6abfef397..0964f564c 100644 --- a/etcdctl/ctlv3/command/ep_command.go +++ b/etcdctl/ctlv3/command/ep_command.go @@ -191,14 +191,17 @@ type epStatus struct { } func epStatusCommandFunc(cmd *cobra.Command, args []string) { - c := mustClientFromCmd(cmd) + cfg := clientConfigFromCmd(cmd) var statusList []epStatus var err error for _, ep := range endpointsFromCluster(cmd) { + cfg.Endpoints = []string{ep} + c := mustClient(cfg) ctx, cancel := commandCtx(cmd) resp, serr := c.Status(ctx, ep) cancel() + c.Close() if serr != nil { err = serr fmt.Fprintf(os.Stderr, "Failed to get the status of endpoint %s (%v)\n", ep, serr) @@ -220,14 +223,17 @@ type epHashKV struct { } func epHashKVCommandFunc(cmd *cobra.Command, args []string) { - c := mustClientFromCmd(cmd) + cfg := clientConfigFromCmd(cmd) var hashList []epHashKV var err error for _, ep := range endpointsFromCluster(cmd) { + cfg.Endpoints = []string{ep} + c := mustClient(cfg) ctx, cancel := commandCtx(cmd) resp, serr := c.HashKV(ctx, ep, epHashKVRev) cancel() + c.Close() if serr != nil { err = serr fmt.Fprintf(os.Stderr, "Failed to get the hash of endpoint %s (%v)\n", ep, serr) diff --git a/server/etcdserver/api/v3rpc/auth.go b/server/etcdserver/api/v3rpc/auth.go index d986037a1..6c5db76cb 100644 --- a/server/etcdserver/api/v3rpc/auth.go +++ b/server/etcdserver/api/v3rpc/auth.go @@ -18,6 +18,7 @@ import ( "context" pb "go.etcd.io/etcd/api/v3/etcdserverpb" + "go.etcd.io/etcd/server/v3/auth" "go.etcd.io/etcd/server/v3/etcdserver" ) @@ -164,3 +165,23 @@ func (as *AuthServer) UserChangePassword(ctx context.Context, r *pb.AuthUserChan } return resp, nil } + +type AuthGetter interface { + AuthInfoFromCtx(ctx context.Context) (*auth.AuthInfo, error) + AuthStore() auth.AuthStore +} + +type AuthAdmin struct { + ag AuthGetter +} + +// isPermitted verifies the user has admin privilege. +// Only users with "root" role are permitted. +func (aa *AuthAdmin) isPermitted(ctx context.Context) error { + authInfo, err := aa.ag.AuthInfoFromCtx(ctx) + if err != nil { + return err + } + + return aa.ag.AuthStore().IsAdminPermitted(authInfo) +} diff --git a/server/etcdserver/api/v3rpc/maintenance.go b/server/etcdserver/api/v3rpc/maintenance.go index f8b61d3f9..af1f2acb1 100644 --- a/server/etcdserver/api/v3rpc/maintenance.go +++ b/server/etcdserver/api/v3rpc/maintenance.go @@ -25,7 +25,6 @@ import ( "go.etcd.io/etcd/api/v3/v3rpc/rpctypes" "go.etcd.io/etcd/api/v3/version" "go.etcd.io/etcd/raft/v3" - "go.etcd.io/etcd/server/v3/auth" "go.etcd.io/etcd/server/v3/etcdserver" "go.etcd.io/etcd/server/v3/etcdserver/apply" "go.etcd.io/etcd/server/v3/etcdserver/errors" @@ -60,11 +59,6 @@ type LeaderTransferrer interface { MoveLeader(ctx context.Context, lead, target uint64) error } -type AuthGetter interface { - AuthInfoFromCtx(ctx context.Context) (*auth.AuthInfo, error) - AuthStore() auth.AuthStore -} - type ClusterStatusGetter interface { IsLearner() bool } @@ -87,7 +81,7 @@ func NewMaintenanceServer(s *etcdserver.EtcdServer) pb.MaintenanceServer { if srv.lg == nil { srv.lg = zap.NewNop() } - return &authMaintenanceServer{srv, s} + return &authMaintenanceServer{srv, &AuthAdmin{s}} } func (ms *maintenanceServer) Defragment(ctx context.Context, sr *pb.DefragmentRequest) (*pb.DefragmentResponse, error) { @@ -274,20 +268,11 @@ func (ms *maintenanceServer) Downgrade(ctx context.Context, r *pb.DowngradeReque type authMaintenanceServer struct { *maintenanceServer - ag AuthGetter -} - -func (ams *authMaintenanceServer) isAuthenticated(ctx context.Context) error { - authInfo, err := ams.ag.AuthInfoFromCtx(ctx) - if err != nil { - return err - } - - return ams.ag.AuthStore().IsAdminPermitted(authInfo) + *AuthAdmin } func (ams *authMaintenanceServer) Defragment(ctx context.Context, sr *pb.DefragmentRequest) (*pb.DefragmentResponse, error) { - if err := ams.isAuthenticated(ctx); err != nil { + if err := ams.isPermitted(ctx); err != nil { return nil, err } @@ -295,7 +280,7 @@ func (ams *authMaintenanceServer) Defragment(ctx context.Context, sr *pb.Defragm } func (ams *authMaintenanceServer) Snapshot(sr *pb.SnapshotRequest, srv pb.Maintenance_SnapshotServer) error { - if err := ams.isAuthenticated(srv.Context()); err != nil { + if err := ams.isPermitted(srv.Context()); err != nil { return err } @@ -303,7 +288,7 @@ func (ams *authMaintenanceServer) Snapshot(sr *pb.SnapshotRequest, srv pb.Mainte } func (ams *authMaintenanceServer) Hash(ctx context.Context, r *pb.HashRequest) (*pb.HashResponse, error) { - if err := ams.isAuthenticated(ctx); err != nil { + if err := ams.isPermitted(ctx); err != nil { return nil, err } @@ -311,20 +296,32 @@ func (ams *authMaintenanceServer) Hash(ctx context.Context, r *pb.HashRequest) ( } func (ams *authMaintenanceServer) HashKV(ctx context.Context, r *pb.HashKVRequest) (*pb.HashKVResponse, error) { - if err := ams.isAuthenticated(ctx); err != nil { + if err := ams.isPermitted(ctx); err != nil { return nil, err } return ams.maintenanceServer.HashKV(ctx, r) } func (ams *authMaintenanceServer) Status(ctx context.Context, ar *pb.StatusRequest) (*pb.StatusResponse, error) { + if err := ams.isPermitted(ctx); err != nil { + return nil, err + } + return ams.maintenanceServer.Status(ctx, ar) } func (ams *authMaintenanceServer) MoveLeader(ctx context.Context, tr *pb.MoveLeaderRequest) (*pb.MoveLeaderResponse, error) { + if err := ams.isPermitted(ctx); err != nil { + return nil, err + } + return ams.maintenanceServer.MoveLeader(ctx, tr) } func (ams *authMaintenanceServer) Downgrade(ctx context.Context, r *pb.DowngradeRequest) (*pb.DowngradeResponse, error) { + if err := ams.isPermitted(ctx); err != nil { + return nil, err + } + return ams.maintenanceServer.Downgrade(ctx, r) } diff --git a/tests/common/auth_util.go b/tests/common/auth_util.go new file mode 100644 index 000000000..f130c0c95 --- /dev/null +++ b/tests/common/auth_util.go @@ -0,0 +1,87 @@ +// Copyright 2022 The etcd Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package common + +import ( + "context" + "fmt" + + clientv3 "go.etcd.io/etcd/client/v3" + "go.etcd.io/etcd/tests/v3/framework" + "go.etcd.io/etcd/tests/v3/framework/config" +) + +type authRole struct { + role string + permission clientv3.PermissionType + key string + keyEnd string +} + +type authUser struct { + user string + pass string + role string +} + +func createRoles(c framework.Client, roles []authRole) error { + for _, r := range roles { + // add role + if _, err := c.RoleAdd(context.TODO(), r.role); err != nil { + return fmt.Errorf("RoleAdd failed: %w", err) + } + + // grant permission to role + if _, err := c.RoleGrantPermission(context.TODO(), r.role, r.key, r.keyEnd, r.permission); err != nil { + return fmt.Errorf("RoleGrantPermission failed: %w", err) + } + } + + return nil +} + +func createUsers(c framework.Client, users []authUser) error { + for _, u := range users { + // add user + if _, err := c.UserAdd(context.TODO(), u.user, u.pass, config.UserAddOptions{}); err != nil { + return fmt.Errorf("UserAdd failed: %w", err) + } + + // grant role to user + if _, err := c.UserGrantRole(context.TODO(), u.user, u.role); err != nil { + return fmt.Errorf("UserGrantRole failed: %w", err) + } + } + + return nil +} + +func setupAuth(c framework.Client, roles []authRole, users []authUser) error { + // create roles + if err := createRoles(c, roles); err != nil { + return err + } + + if err := createUsers(c, users); err != nil { + return err + } + + // enable auth + if err := c.AuthEnable(context.TODO()); err != nil { + return err + } + + return nil +} diff --git a/tests/common/maintenance_auth_test.go b/tests/common/maintenance_auth_test.go new file mode 100644 index 000000000..61277f512 --- /dev/null +++ b/tests/common/maintenance_auth_test.go @@ -0,0 +1,247 @@ +// Copyright 2022 The etcd Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package common + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + + clientv3 "go.etcd.io/etcd/client/v3" + "go.etcd.io/etcd/tests/v3/framework" + "go.etcd.io/etcd/tests/v3/framework/config" + "go.etcd.io/etcd/tests/v3/framework/testutils" +) + +/* +Test Defragment +*/ +func TestDefragmentWithNoAuth(t *testing.T) { + testDefragmentWithAuth(t, false, true) +} + +func TestDefragmentWithInvalidAuth(t *testing.T) { + testDefragmentWithAuth(t, true, true, WithAuth("invalid", "invalid")) +} + +func TestDefragmentWithRootAuth(t *testing.T) { + testDefragmentWithAuth(t, false, false, WithAuth("root", "rootPass")) +} + +func TestDefragmentWithUserAuth(t *testing.T) { + testDefragmentWithAuth(t, false, true, WithAuth("user0", "user0Pass")) +} + +func testDefragmentWithAuth(t *testing.T, expectConnectionError, expectOperationError bool, opts ...config.ClientOption) { + testMaintenanceOperationWithAuth(t, expectConnectionError, expectOperationError, func(ctx context.Context, cc framework.Client) error { + return cc.Defragment(ctx, config.DefragOption{Timeout: 10 * time.Second}) + }, opts...) +} + +/* +Test Downgrade +*/ +func TestDowngradeWithNoAuth(t *testing.T) { + testDowngradeWithAuth(t, false, true) +} + +func TestDowngradeWithInvalidAuth(t *testing.T) { + testDowngradeWithAuth(t, true, true, WithAuth("invalid", "invalid")) +} + +func TestDowngradeWithRootAuth(t *testing.T) { + testDowngradeWithAuth(t, false, false, WithAuth("root", "rootPass")) +} + +func TestDowngradeWithUserAuth(t *testing.T) { + testDowngradeWithAuth(t, false, true, WithAuth("user0", "user0Pass")) +} + +func testDowngradeWithAuth(t *testing.T, expectConnectionError, expectOperationError bool, opts ...config.ClientOption) { + // TODO(ahrtr): finish this after we added interface methods `Downgrade` into `Client` + t.Skip() +} + +/* +Test HashKV +*/ +func TestHashKVWithNoAuth(t *testing.T) { + testHashKVWithAuth(t, false, true) +} + +func TestHashKVWithInvalidAuth(t *testing.T) { + testHashKVWithAuth(t, true, true, WithAuth("invalid", "invalid")) +} + +func TestHashKVWithRootAuth(t *testing.T) { + testHashKVWithAuth(t, false, false, WithAuth("root", "rootPass")) +} + +func TestHashKVWithUserAuth(t *testing.T) { + testHashKVWithAuth(t, false, true, WithAuth("user0", "user0Pass")) +} + +func testHashKVWithAuth(t *testing.T, expectConnectionError, expectOperationError bool, opts ...config.ClientOption) { + testMaintenanceOperationWithAuth(t, expectConnectionError, expectOperationError, func(ctx context.Context, cc framework.Client) error { + _, err := cc.HashKV(ctx, 0) + return err + }, opts...) +} + +/* +Test MoveLeader +*/ +func TestMoveLeaderWithNoAuth(t *testing.T) { + testMoveLeaderWithAuth(t, false, true) +} + +func TestMoveLeaderWithInvalidAuth(t *testing.T) { + testMoveLeaderWithAuth(t, true, true, WithAuth("invalid", "invalid")) +} + +func TestMoveLeaderWithRootAuth(t *testing.T) { + testMoveLeaderWithAuth(t, false, false, WithAuth("root", "rootPass")) +} + +func TestMoveLeaderWithUserAuth(t *testing.T) { + testMoveLeaderWithAuth(t, false, true, WithAuth("user0", "user0Pass")) +} + +func testMoveLeaderWithAuth(t *testing.T, expectConnectionError, expectOperationError bool, opts ...config.ClientOption) { + // TODO(ahrtr): finish this after we added interface methods `MoveLeader` into `Client` + t.Skip() +} + +/* +Test Snapshot +*/ +func TestSnapshotWithNoAuth(t *testing.T) { + testSnapshotWithAuth(t, false, true) +} + +func TestSnapshotWithInvalidAuth(t *testing.T) { + testSnapshotWithAuth(t, true, true, WithAuth("invalid", "invalid")) +} + +func TestSnapshotWithRootAuth(t *testing.T) { + testSnapshotWithAuth(t, false, false, WithAuth("root", "rootPass")) +} + +func TestSnapshotWithUserAuth(t *testing.T) { + testSnapshotWithAuth(t, false, true, WithAuth("user0", "user0Pass")) +} + +func testSnapshotWithAuth(t *testing.T, expectConnectionError, expectOperationError bool, opts ...config.ClientOption) { + // TODO(ahrtr): finish this after we added interface methods `Snapshot` into `Client` + t.Skip() +} + +/* +Test Status +*/ +func TestStatusWithNoAuth(t *testing.T) { + testStatusWithAuth(t, false, true) +} + +func TestStatusWithInvalidAuth(t *testing.T) { + testStatusWithAuth(t, true, true, WithAuth("invalid", "invalid")) +} + +func TestStatusWithRootAuth(t *testing.T) { + testStatusWithAuth(t, false, false, WithAuth("root", "rootPass")) +} + +func TestStatusWithUserAuth(t *testing.T) { + testStatusWithAuth(t, false, true, WithAuth("user0", "user0Pass")) +} + +func testStatusWithAuth(t *testing.T, expectConnectionError, expectOperationError bool, opts ...config.ClientOption) { + testMaintenanceOperationWithAuth(t, expectConnectionError, expectOperationError, func(ctx context.Context, cc framework.Client) error { + _, err := cc.Status(ctx) + return err + }, opts...) +} + +func setupAuthForMaintenanceTest(c framework.Client) error { + roles := []authRole{ + { + role: "role0", + permission: clientv3.PermissionType(clientv3.PermReadWrite), + key: "foo", + }, + } + + users := []authUser{ + { + user: "root", + pass: "rootPass", + role: "root", + }, + { + user: "user0", + pass: "user0Pass", + role: "role0", + }, + } + + return setupAuth(c, roles, users) +} + +func testMaintenanceOperationWithAuth(t *testing.T, expectConnectError, expectOperationError bool, f func(context.Context, framework.Client) error, opts ...config.ClientOption) { + testRunner.BeforeTest(t) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + clus := testRunner.NewCluster(ctx, t) + defer clus.Close() + + cc := framework.MustClient(clus.Client()) + err := setupAuthForMaintenanceTest(cc) + require.NoError(t, err) + + ccWithAuth, err := clus.Client(opts...) + if expectConnectError { + if err == nil { + t.Fatalf("%s: expected connection error, but got successful response", t.Name()) + } + t.Logf("%s: connection error: %v", t.Name(), err) + return + } + if err != nil { + t.Fatalf("%s: unexpected connection error (%v)", t.Name(), err) + return + } + + // sleep 1 second to wait for etcd cluster to finish the authentication process. + // TODO(ahrtr): find a better way to do it. + time.Sleep(1 * time.Second) + testutils.ExecuteUntil(ctx, t, func() { + err := f(ctx, ccWithAuth) + + if expectOperationError { + if err == nil { + t.Fatalf("%s: expected error, but got successful response", t.Name()) + } + t.Logf("%s: operation error: %v", t.Name(), err) + return + } + + if err != nil { + t.Fatalf("%s: unexpected operation error (%v)", t.Name(), err) + } + }) +}