diff --git a/clientv3/client.go b/clientv3/client.go index 4f3f5abb1..d8cb13808 100644 --- a/clientv3/client.go +++ b/clientv3/client.go @@ -86,9 +86,8 @@ type Client struct { // Username is a user name for authentication. Username string // Password is a password for authentication. - Password string - // tokenCred is an instance of WithPerRPCCredentials()'s argument - tokenCred *authTokenCredential + Password string + authTokenBundle credentials.Bundle callOpts []grpc.CallOption @@ -193,23 +192,6 @@ func (c *Client) autoSync() { } } -type authTokenCredential struct { - token string - tokenMu *sync.RWMutex -} - -func (cred authTokenCredential) RequireTransportSecurity() bool { - return false -} - -func (cred authTokenCredential) GetRequestMetadata(ctx context.Context, s ...string) (map[string]string, error) { - cred.tokenMu.RLock() - defer cred.tokenMu.RUnlock() - return map[string]string{ - rpctypes.TokenFieldNameGRPC: cred.token, - }, nil -} - func (c *Client) processCreds(scheme string) (creds grpccredentials.TransportCredentials) { creds = c.creds switch scheme { @@ -316,10 +298,7 @@ func (c *Client) getToken(ctx context.Context) error { continue } - c.tokenCred.tokenMu.Lock() - c.tokenCred.token = resp.Token - c.tokenCred.tokenMu.Unlock() - + c.authTokenBundle.UpdateAuthToken(resp.Token) return nil } @@ -343,9 +322,7 @@ func (c *Client) dial(target string, creds grpccredentials.TransportCredentials, } if c.Username != "" && c.Password != "" { - c.tokenCred = &authTokenCredential{ - tokenMu: &sync.RWMutex{}, - } + c.authTokenBundle = credentials.NewBundle(credentials.Config{}) ctx, cancel := c.ctx, func() {} if c.cfg.DialTimeout > 0 { @@ -362,7 +339,7 @@ func (c *Client) dial(target string, creds grpccredentials.TransportCredentials, return nil, err } } else { - opts = append(opts, grpc.WithPerRPCCredentials(c.tokenCred)) + opts = append(opts, grpc.WithPerRPCCredentials(c.authTokenBundle.PerRPCCredentials())) } cancel() } diff --git a/clientv3/credentials/credentials.go b/clientv3/credentials/credentials.go index b9f66bd4b..e6fd75cc3 100644 --- a/clientv3/credentials/credentials.go +++ b/clientv3/credentials/credentials.go @@ -29,14 +29,19 @@ import ( // Config defines gRPC credential configuration. type Config struct { TLSConfig *tls.Config - AuthToken string +} + +// Bundle defines gRPC credential interface. +type Bundle interface { + grpccredentials.Bundle + UpdateAuthToken(token string) } // NewBundle constructs a new gRPC credential bundle. -func NewBundle(cfg Config) grpccredentials.Bundle { +func NewBundle(cfg Config) Bundle { return &bundle{ tc: newTransportCredential(cfg.TLSConfig), - rc: newPerRPCCredential(cfg.AuthToken), + rc: newPerRPCCredential(), } } @@ -125,14 +130,7 @@ type perRPCCredential struct { authTokenMu sync.RWMutex } -func newPerRPCCredential(authToken string) *perRPCCredential { - if authToken == "" { - return nil - } - return &perRPCCredential{ - authToken: authToken, - } -} +func newPerRPCCredential() *perRPCCredential { return &perRPCCredential{} } func (rc *perRPCCredential) RequireTransportSecurity() bool { return false } @@ -142,3 +140,16 @@ func (rc *perRPCCredential) GetRequestMetadata(ctx context.Context, s ...string) rc.authTokenMu.RUnlock() return map[string]string{rpctypes.TokenFieldNameGRPC: authToken}, nil } + +func (b *bundle) UpdateAuthToken(token string) { + if b.rc == nil { + return + } + b.rc.UpdateAuthToken(token) +} + +func (rc *perRPCCredential) UpdateAuthToken(token string) { + rc.authTokenMu.Lock() + rc.authToken = token + rc.authTokenMu.Unlock() +}