etcdclient: check mutex state in Unlock method of concurrency.Mutex

Check the values of myKey and myRev first in Unlock method to prevent calling Unlock without Lock. Because this may cause the value of pfx to be deleted by mistake.

Signed-off-by: chenyahui <cyhone@qq.com>
This commit is contained in:
chenyahui 2022-11-08 17:43:00 +08:00
parent 554b1bd0b0
commit 5b8c6b548f
2 changed files with 50 additions and 0 deletions

View File

@ -18,6 +18,7 @@ import (
"context"
"errors"
"fmt"
"strings"
"sync"
pb "go.etcd.io/etcd/api/v3/etcdserverpb"
@ -27,6 +28,7 @@ import (
// ErrLocked is returned by TryLock when Mutex is already locked by another session.
var ErrLocked = errors.New("mutex: Locked by another session")
var ErrSessionExpired = errors.New("mutex: session is expired")
var ErrLockReleased = errors.New("mutex: lock has already been released")
// Mutex implements the sync Locker interface with etcd
type Mutex struct {
@ -128,6 +130,14 @@ func (m *Mutex) tryAcquire(ctx context.Context) (*v3.TxnResponse, error) {
}
func (m *Mutex) Unlock(ctx context.Context) error {
if m.myKey == "" || m.myRev <= 0 || m.myKey == "\x00" {
return ErrLockReleased
}
if !strings.HasPrefix(m.myKey, m.pfx) {
return fmt.Errorf("invalid key %q, it should have prefix %q", m.myKey, m.pfx)
}
client := m.s.Client()
if _, err := client.Delete(ctx, m.myKey); err != nil {
return err

View File

@ -16,6 +16,7 @@ package concurrency_test
import (
"context"
"errors"
"testing"
"go.etcd.io/etcd/client/v3"
@ -70,3 +71,42 @@ func TestMutexLockSessionExpired(t *testing.T) {
<-m2Locked
}
func TestMutexUnlock(t *testing.T) {
cli, err := integration2.NewClient(t, clientv3.Config{Endpoints: exampleEndpoints()})
if err != nil {
t.Fatal(err)
}
defer cli.Close()
s1, err := concurrency.NewSession(cli)
if err != nil {
t.Fatal(err)
}
defer s1.Close()
m1 := concurrency.NewMutex(s1, "/my-lock/")
err = m1.Unlock(context.TODO())
if err == nil {
t.Fatal("expect lock released error")
}
if !errors.Is(err, concurrency.ErrLockReleased) {
t.Fatal(err)
}
if err := m1.Lock(context.TODO()); err != nil {
t.Fatal(err)
}
if err := m1.Unlock(context.TODO()); err != nil {
t.Fatal(err)
}
err = m1.Unlock(context.TODO())
if err == nil {
t.Fatal("expect lock released error")
}
if !errors.Is(err, concurrency.ErrLockReleased) {
t.Fatal(err)
}
}