From 7ac995cdde6ae56beba93b040fe231dfca03c38d Mon Sep 17 00:00:00 2001 From: ahrtr Date: Sat, 2 Apr 2022 06:02:22 +0800 Subject: [PATCH] enhanced authBackend to support authReadTx --- server/auth/range_perm_cache.go | 4 +- server/auth/store.go | 12 ++++-- server/auth/store_mock_test.go | 4 ++ server/etcdserver/cindex/cindex.go | 2 +- server/etcdserver/server.go | 5 ++- server/storage/schema/auth.go | 54 +++++++++++++++++++------ server/storage/schema/auth_roles.go | 62 +++++++++++++++++------------ server/storage/schema/auth_users.go | 50 +++++++++++++---------- 8 files changed, 127 insertions(+), 66 deletions(-) diff --git a/server/auth/range_perm_cache.go b/server/auth/range_perm_cache.go index bae07ef52..2ebe5439b 100644 --- a/server/auth/range_perm_cache.go +++ b/server/auth/range_perm_cache.go @@ -20,7 +20,7 @@ import ( "go.uber.org/zap" ) -func getMergedPerms(tx AuthBatchTx, userName string) *unifiedRangePermissions { +func getMergedPerms(tx AuthReadTx, userName string) *unifiedRangePermissions { user := tx.UnsafeGetUser(userName) if user == nil { return nil @@ -103,7 +103,7 @@ func checkKeyPoint(lg *zap.Logger, cachedPerms *unifiedRangePermissions, key []b return false } -func (as *authStore) isRangeOpPermitted(tx AuthBatchTx, userName string, key, rangeEnd []byte, permtyp authpb.Permission_Type) bool { +func (as *authStore) isRangeOpPermitted(tx AuthReadTx, userName string, key, rangeEnd []byte, permtyp authpb.Permission_Type) bool { // assumption: tx is Lock()ed _, ok := as.rangePermCache[userName] if !ok { diff --git a/server/auth/store.go b/server/auth/store.go index 408b235ba..762caecd7 100644 --- a/server/auth/store.go +++ b/server/auth/store.go @@ -196,6 +196,7 @@ type TokenProvider interface { type AuthBackend interface { CreateAuthBuckets() ForceCommit() + ReadTx() AuthReadTx BatchTx() AuthBatchTx GetUser(string) *authpb.User @@ -345,7 +346,7 @@ func (as *authStore) CheckPassword(username, password string) (uint64, error) { // CompareHashAndPassword is very expensive, so we use closures // to avoid putting it in the critical section of the tx lock. revision, err := func() (uint64, error) { - tx := as.be.BatchTx() + tx := as.be.ReadTx() tx.Lock() defer tx.Unlock() @@ -855,7 +856,7 @@ func (as *authStore) isOpPermitted(userName string, revision uint64, key, rangeE return ErrAuthOldRevision } - tx := as.be.BatchTx() + tx := as.be.ReadTx() tx.Lock() defer tx.Unlock() @@ -897,7 +898,10 @@ func (as *authStore) IsAdminPermitted(authInfo *AuthInfo) error { return ErrUserEmpty } - u := as.be.GetUser(authInfo.Username) + tx := as.be.ReadTx() + tx.Lock() + defer tx.Unlock() + u := tx.UnsafeGetUser(authInfo.Username) if u == nil { return ErrUserNotFound @@ -935,6 +939,8 @@ func NewAuthStore(lg *zap.Logger, be AuthBackend, tp TokenProvider, bcryptCost i be.CreateAuthBuckets() tx := be.BatchTx() + // We should call LockWithoutHook here, but the txPostLockHoos isn't set + // to EtcdServer yet, so it's OK. tx.Lock() enabled := tx.UnsafeReadAuthEnabled() as := &authStore{ diff --git a/server/auth/store_mock_test.go b/server/auth/store_mock_test.go index d49f8dd33..39c3f6d13 100644 --- a/server/auth/store_mock_test.go +++ b/server/auth/store_mock_test.go @@ -36,6 +36,10 @@ func (b *backendMock) CreateAuthBuckets() { func (b *backendMock) ForceCommit() { } +func (b *backendMock) ReadTx() AuthReadTx { + return &txMock{be: b} +} + func (b *backendMock) BatchTx() AuthBatchTx { return &txMock{be: b} } diff --git a/server/etcdserver/cindex/cindex.go b/server/etcdserver/cindex/cindex.go index 7ec1b1212..6367967f8 100644 --- a/server/etcdserver/cindex/cindex.go +++ b/server/etcdserver/cindex/cindex.go @@ -89,7 +89,7 @@ func (ci *consistentIndex) UnsafeConsistentIndex() uint64 { return index } - v, term := schema.UnsafeReadConsistentIndex(ci.be.BatchTx()) + v, term := schema.UnsafeReadConsistentIndex(ci.be.ReadTx()) ci.SetConsistentIndex(v, term) return v } diff --git a/server/etcdserver/server.go b/server/etcdserver/server.go index b22f680bb..6a89f4592 100644 --- a/server/etcdserver/server.go +++ b/server/etcdserver/server.go @@ -343,7 +343,6 @@ func NewServer(cfg config.ServerConfig) (srv *EtcdServer, err error) { srv.applyV2 = NewApplierV2(cfg.Logger, srv.v2store, srv.cluster) srv.be = b.storage.backend.be - srv.be.SetTxPostLockHook(srv.getTxPostLockHook()) srv.beHooks = b.storage.backend.beHooks minTTL := time.Duration((3*cfg.ElectionTicks)/2) * heartbeat @@ -404,6 +403,10 @@ func NewServer(cfg config.ServerConfig) (srv *EtcdServer, err error) { }) } + // Set the hook after EtcdServer finishes the initialization to avoid + // the hook being called during the initialization process. + srv.be.SetTxPostLockHook(srv.getTxPostLockHook()) + // TODO: move transport initialization near the definition of remote tr := &rafthttp.Transport{ Logger: cfg.Logger, diff --git a/server/storage/schema/auth.go b/server/storage/schema/auth.go index fc334a8bc..3956ca782 100644 --- a/server/storage/schema/auth.go +++ b/server/storage/schema/auth.go @@ -60,15 +60,25 @@ func (abe *authBackend) ForceCommit() { abe.be.ForceCommit() } +func (abe *authBackend) ReadTx() auth.AuthReadTx { + return &authReadTx{tx: abe.be.ReadTx(), lg: abe.lg} +} + func (abe *authBackend) BatchTx() auth.AuthBatchTx { return &authBatchTx{tx: abe.be.BatchTx(), lg: abe.lg} } +type authReadTx struct { + tx backend.ReadTx + lg *zap.Logger +} + type authBatchTx struct { tx backend.BatchTx lg *zap.Logger } +var _ auth.AuthReadTx = (*authReadTx)(nil) var _ auth.AuthBatchTx = (*authBatchTx)(nil) func (atx *authBatchTx) UnsafeSaveAuthEnabled(enabled bool) { @@ -86,22 +96,13 @@ func (atx *authBatchTx) UnsafeSaveAuthRevision(rev uint64) { } func (atx *authBatchTx) UnsafeReadAuthEnabled() bool { - _, vs := atx.tx.UnsafeRange(Auth, AuthEnabledKeyName, nil, 0) - if len(vs) == 1 { - if bytes.Equal(vs[0], authEnabled) { - return true - } - } - return false + arx := &authReadTx{tx: atx.tx, lg: atx.lg} + return arx.UnsafeReadAuthEnabled() } func (atx *authBatchTx) UnsafeReadAuthRevision() uint64 { - _, vs := atx.tx.UnsafeRange(Auth, AuthRevisionKeyName, nil, 0) - if len(vs) != 1 { - // this can happen in the initialization phase - return 0 - } - return binary.BigEndian.Uint64(vs[0]) + arx := &authReadTx{tx: atx.tx, lg: atx.lg} + return arx.UnsafeReadAuthRevision() } func (atx *authBatchTx) Lock() { @@ -111,3 +112,30 @@ func (atx *authBatchTx) Lock() { func (atx *authBatchTx) Unlock() { atx.tx.Unlock() } + +func (atx *authReadTx) UnsafeReadAuthEnabled() bool { + _, vs := atx.tx.UnsafeRange(Auth, AuthEnabledKeyName, nil, 0) + if len(vs) == 1 { + if bytes.Equal(vs[0], authEnabled) { + return true + } + } + return false +} + +func (atx *authReadTx) UnsafeReadAuthRevision() uint64 { + _, vs := atx.tx.UnsafeRange(Auth, AuthRevisionKeyName, nil, 0) + if len(vs) != 1 { + // this can happen in the initialization phase + return 0 + } + return binary.BigEndian.Uint64(vs[0]) +} + +func (atx *authReadTx) Lock() { + atx.tx.RLock() +} + +func (atx *authReadTx) Unlock() { + atx.tx.RUnlock() +} diff --git a/server/storage/schema/auth_roles.go b/server/storage/schema/auth_roles.go index 541e37b71..dfda7ce5b 100644 --- a/server/storage/schema/auth_roles.go +++ b/server/storage/schema/auth_roles.go @@ -32,17 +32,8 @@ func (abe *authBackend) GetRole(roleName string) *authpb.Role { } func (atx *authBatchTx) UnsafeGetRole(roleName string) *authpb.Role { - _, vs := atx.tx.UnsafeRange(AuthRoles, []byte(roleName), nil, 0) - if len(vs) == 0 { - return nil - } - - role := &authpb.Role{} - err := role.Unmarshal(vs[0]) - if err != nil { - atx.lg.Panic("failed to unmarshal 'authpb.Role'", zap.Error(err)) - } - return role + arx := &authReadTx{tx: atx.tx, lg: atx.lg} + return arx.UnsafeGetRole(roleName) } func (abe *authBackend) GetAllRoles() []*authpb.Role { @@ -53,21 +44,8 @@ func (abe *authBackend) GetAllRoles() []*authpb.Role { } func (atx *authBatchTx) UnsafeGetAllRoles() []*authpb.Role { - _, vs := atx.tx.UnsafeRange(AuthRoles, []byte{0}, []byte{0xff}, -1) - if len(vs) == 0 { - return nil - } - - roles := make([]*authpb.Role, len(vs)) - for i := range vs { - role := &authpb.Role{} - err := role.Unmarshal(vs[i]) - if err != nil { - atx.lg.Panic("failed to unmarshal 'authpb.Role'", zap.Error(err)) - } - roles[i] = role - } - return roles + arx := &authReadTx{tx: atx.tx, lg: atx.lg} + return arx.UnsafeGetAllRoles() } func (atx *authBatchTx) UnsafePutRole(role *authpb.Role) { @@ -86,3 +64,35 @@ func (atx *authBatchTx) UnsafePutRole(role *authpb.Role) { func (atx *authBatchTx) UnsafeDeleteRole(rolename string) { atx.tx.UnsafeDelete(AuthRoles, []byte(rolename)) } + +func (atx *authReadTx) UnsafeGetRole(roleName string) *authpb.Role { + _, vs := atx.tx.UnsafeRange(AuthRoles, []byte(roleName), nil, 0) + if len(vs) == 0 { + return nil + } + + role := &authpb.Role{} + err := role.Unmarshal(vs[0]) + if err != nil { + atx.lg.Panic("failed to unmarshal 'authpb.Role'", zap.Error(err)) + } + return role +} + +func (atx *authReadTx) UnsafeGetAllRoles() []*authpb.Role { + _, vs := atx.tx.UnsafeRange(AuthRoles, []byte{0}, []byte{0xff}, -1) + if len(vs) == 0 { + return nil + } + + roles := make([]*authpb.Role, len(vs)) + for i := range vs { + role := &authpb.Role{} + err := role.Unmarshal(vs[i]) + if err != nil { + atx.lg.Panic("failed to unmarshal 'authpb.Role'", zap.Error(err)) + } + roles[i] = role + } + return roles +} diff --git a/server/storage/schema/auth_users.go b/server/storage/schema/auth_users.go index f385afa51..c3e7a92ff 100644 --- a/server/storage/schema/auth_users.go +++ b/server/storage/schema/auth_users.go @@ -27,6 +27,35 @@ func (abe *authBackend) GetUser(username string) *authpb.User { } func (atx *authBatchTx) UnsafeGetUser(username string) *authpb.User { + arx := &authReadTx{tx: atx.tx, lg: atx.lg} + return arx.UnsafeGetUser(username) +} + +func (abe *authBackend) GetAllUsers() []*authpb.User { + tx := abe.BatchTx() + tx.Lock() + defer tx.Unlock() + return tx.UnsafeGetAllUsers() +} + +func (atx *authBatchTx) UnsafeGetAllUsers() []*authpb.User { + arx := &authReadTx{tx: atx.tx, lg: atx.lg} + return arx.UnsafeGetAllUsers() +} + +func (atx *authBatchTx) UnsafePutUser(user *authpb.User) { + b, err := user.Marshal() + if err != nil { + atx.lg.Panic("failed to unmarshal 'authpb.User'", zap.Error(err)) + } + atx.tx.UnsafePut(AuthUsers, user.Name, b) +} + +func (atx *authBatchTx) UnsafeDeleteUser(username string) { + atx.tx.UnsafeDelete(AuthUsers, []byte(username)) +} + +func (atx *authReadTx) UnsafeGetUser(username string) *authpb.User { _, vs := atx.tx.UnsafeRange(AuthUsers, []byte(username), nil, 0) if len(vs) == 0 { return nil @@ -44,14 +73,7 @@ func (atx *authBatchTx) UnsafeGetUser(username string) *authpb.User { return user } -func (abe *authBackend) GetAllUsers() []*authpb.User { - tx := abe.BatchTx() - tx.Lock() - defer tx.Unlock() - return tx.UnsafeGetAllUsers() -} - -func (atx *authBatchTx) UnsafeGetAllUsers() []*authpb.User { +func (atx *authReadTx) UnsafeGetAllUsers() []*authpb.User { _, vs := atx.tx.UnsafeRange(AuthUsers, []byte{0}, []byte{0xff}, -1) if len(vs) == 0 { return nil @@ -68,15 +90,3 @@ func (atx *authBatchTx) UnsafeGetAllUsers() []*authpb.User { } return users } - -func (atx *authBatchTx) UnsafePutUser(user *authpb.User) { - b, err := user.Marshal() - if err != nil { - atx.lg.Panic("failed to unmarshal 'authpb.User'", zap.Error(err)) - } - atx.tx.UnsafePut(AuthUsers, user.Name, b) -} - -func (atx *authBatchTx) UnsafeDeleteUser(username string) { - atx.tx.UnsafeDelete(AuthUsers, []byte(username)) -}