diff --git a/clientv3/concurrency/example_mutex_test.go b/clientv3/concurrency/example_mutex_test.go index a0463d5b5..6b55340cf 100644 --- a/clientv3/concurrency/example_mutex_test.go +++ b/clientv3/concurrency/example_mutex_test.go @@ -23,6 +23,57 @@ import ( "go.etcd.io/etcd/clientv3/concurrency" ) +func ExampleMutex_TryLock() { + cli, err := clientv3.New(clientv3.Config{Endpoints: endpoints}) + if err != nil { + log.Fatal(err) + } + defer cli.Close() + + // create two separate sessions for lock competition + s1, err := concurrency.NewSession(cli) + if err != nil { + log.Fatal(err) + } + defer s1.Close() + m1 := concurrency.NewMutex(s1, "/my-lock/") + + s2, err := concurrency.NewSession(cli) + if err != nil { + log.Fatal(err) + } + defer s2.Close() + m2 := concurrency.NewMutex(s2, "/my-lock/") + + // acquire lock for s1 + if err = m1.Lock(context.TODO()); err != nil { + log.Fatal(err) + } + fmt.Println("acquired lock for s1") + + if err = m2.TryLock(context.TODO()); err == nil { + log.Fatal("should not acquire lock") + } + if err == concurrency.ErrLocked { + fmt.Println("cannot acquire lock for s2, as already locked in another session") + } + + if err = m1.Unlock(context.TODO()); err != nil { + log.Fatal(err) + } + fmt.Println("released lock for s1") + if err = m2.TryLock(context.TODO()); err != nil { + log.Fatal(err) + } + fmt.Println("acquired lock for s2") + + // Output: + // acquired lock for s1 + // cannot acquire lock for s2, as already locked in another session + // released lock for s1 + // acquired lock for s2 +} + func ExampleMutex_Lock() { cli, err := clientv3.New(clientv3.Config{Endpoints: endpoints}) if err != nil { diff --git a/clientv3/concurrency/mutex.go b/clientv3/concurrency/mutex.go index 013534193..306470b88 100644 --- a/clientv3/concurrency/mutex.go +++ b/clientv3/concurrency/mutex.go @@ -16,6 +16,7 @@ package concurrency import ( "context" + "errors" "fmt" "sync" @@ -23,6 +24,9 @@ import ( pb "go.etcd.io/etcd/etcdserver/etcdserverpb" ) +// ErrLocked is returned by TryLock when Mutex is already locked by another session. +var ErrLocked = errors.New("mutex: Locked by another session") + // Mutex implements the sync Locker interface with etcd type Mutex struct { s *Session @@ -37,9 +41,56 @@ func NewMutex(s *Session, pfx string) *Mutex { return &Mutex{s, pfx + "/", "", -1, nil} } +// TryLock locks the mutex if not already locked by another session. +// If lock is held by another session, return immediately after attempting necessary cleanup +// The ctx argument is used for the sending/receiving Txn RPC. +func (m *Mutex) TryLock(ctx context.Context) error { + resp, err := m.tryAcquire(ctx) + if err != nil { + return err + } + // if no key on prefix / the minimum rev is key, already hold the lock + ownerKey := resp.Responses[1].GetResponseRange().Kvs + if len(ownerKey) == 0 || ownerKey[0].CreateRevision == m.myRev { + m.hdr = resp.Header + return nil + } + client := m.s.Client() + // Cannot lock, so delete the key + if _, err := client.Delete(ctx, m.myKey); err != nil { + return err + } + m.myKey = "\x00" + m.myRev = -1 + return ErrLocked +} + // Lock locks the mutex with a cancelable context. If the context is canceled // while trying to acquire the lock, the mutex tries to clean its stale lock entry. func (m *Mutex) Lock(ctx context.Context) error { + resp, err := m.tryAcquire(ctx) + if err != nil { + return err + } + // if no key on prefix / the minimum rev is key, already hold the lock + ownerKey := resp.Responses[1].GetResponseRange().Kvs + if len(ownerKey) == 0 || ownerKey[0].CreateRevision == m.myRev { + m.hdr = resp.Header + return nil + } + client := m.s.Client() + // wait for deletion revisions prior to myKey + hdr, werr := waitDeletes(ctx, client, m.pfx, m.myRev-1) + // release lock key if wait failed + if werr != nil { + m.Unlock(client.Ctx()) + } else { + m.hdr = hdr + } + return werr +} + +func (m *Mutex) tryAcquire(ctx context.Context) (*v3.TxnResponse, error) { s := m.s client := m.s.Client() @@ -53,28 +104,13 @@ func (m *Mutex) Lock(ctx context.Context) error { getOwner := v3.OpGet(m.pfx, v3.WithFirstCreate()...) resp, err := client.Txn(ctx).If(cmp).Then(put, getOwner).Else(get, getOwner).Commit() if err != nil { - return err + return nil, err } m.myRev = resp.Header.Revision if !resp.Succeeded { m.myRev = resp.Responses[0].GetResponseRange().Kvs[0].CreateRevision } - // if no key on prefix / the minimum rev is key, already hold the lock - ownerKey := resp.Responses[1].GetResponseRange().Kvs - if len(ownerKey) == 0 || ownerKey[0].CreateRevision == m.myRev { - m.hdr = resp.Header - return nil - } - - // wait for deletion revisions prior to myKey - hdr, werr := waitDeletes(ctx, client, m.pfx, m.myRev-1) - // release lock key if wait failed - if werr != nil { - m.Unlock(client.Ctx()) - } else { - m.hdr = hdr - } - return werr + return resp, nil } func (m *Mutex) Unlock(ctx context.Context) error { diff --git a/integration/v3_lock_test.go b/integration/v3_lock_test.go index 88e032796..e36af2d43 100644 --- a/integration/v3_lock_test.go +++ b/integration/v3_lock_test.go @@ -23,30 +23,30 @@ import ( "go.etcd.io/etcd/clientv3" "go.etcd.io/etcd/clientv3/concurrency" - "go.etcd.io/etcd/contrib/recipes" + recipe "go.etcd.io/etcd/contrib/recipes" "go.etcd.io/etcd/mvcc/mvccpb" "go.etcd.io/etcd/pkg/testutil" ) -func TestMutexSingleNode(t *testing.T) { +func TestMutexLockSingleNode(t *testing.T) { clus := NewClusterV3(t, &ClusterConfig{Size: 3}) defer clus.Terminate(t) var clients []*clientv3.Client - testMutex(t, 5, makeSingleNodeClients(t, clus.cluster, &clients)) + testMutexLock(t, 5, makeSingleNodeClients(t, clus.cluster, &clients)) closeClients(t, clients) } -func TestMutexMultiNode(t *testing.T) { +func TestMutexLockMultiNode(t *testing.T) { clus := NewClusterV3(t, &ClusterConfig{Size: 3}) defer clus.Terminate(t) var clients []*clientv3.Client - testMutex(t, 5, makeMultiNodeClients(t, clus.cluster, &clients)) + testMutexLock(t, 5, makeMultiNodeClients(t, clus.cluster, &clients)) closeClients(t, clients) } -func testMutex(t *testing.T, waiters int, chooseClient func() *clientv3.Client) { +func testMutexLock(t *testing.T, waiters int, chooseClient func() *clientv3.Client) { // stream lock acquisitions lockedC := make(chan *concurrency.Mutex) for i := 0; i < waiters; i++ { @@ -82,6 +82,62 @@ func testMutex(t *testing.T, waiters int, chooseClient func() *clientv3.Client) } } +func TestMutexTryLockSingleNode(t *testing.T) { + clus := NewClusterV3(t, &ClusterConfig{Size: 3}) + defer clus.Terminate(t) + + var clients []*clientv3.Client + testMutexTryLock(t, 5, makeSingleNodeClients(t, clus.cluster, &clients)) + closeClients(t, clients) +} + +func TestMutexTryLockMultiNode(t *testing.T) { + clus := NewClusterV3(t, &ClusterConfig{Size: 3}) + defer clus.Terminate(t) + + var clients []*clientv3.Client + testMutexTryLock(t, 5, makeMultiNodeClients(t, clus.cluster, &clients)) + closeClients(t, clients) +} + +func testMutexTryLock(t *testing.T, lockers int, chooseClient func() *clientv3.Client) { + lockedC := make(chan *concurrency.Mutex) + notlockedC := make(chan *concurrency.Mutex) + for i := 0; i < lockers; i++ { + go func() { + session, err := concurrency.NewSession(chooseClient()) + if err != nil { + t.Error(err) + } + m := concurrency.NewMutex(session, "test-mutex-try-lock") + err = m.TryLock(context.TODO()) + if err == nil { + lockedC <- m + } else if err == concurrency.ErrLocked { + notlockedC <- m + } else { + t.Errorf("Unexpected Error %v", err) + } + }() + } + + timerC := time.After(time.Second) + select { + case <-lockedC: + for i := 0; i < lockers-1; i++ { + select { + case <-lockedC: + t.Fatalf("Multiple Mutes locked on same key") + case <-notlockedC: + case <-timerC: + t.Errorf("timed out waiting for lock") + } + } + case <-timerC: + t.Errorf("timed out waiting for lock") + } +} + // TestMutexSessionRelock ensures that acquiring the same lock with the same // session will not result in deadlock. func TestMutexSessionRelock(t *testing.T) { @@ -219,7 +275,7 @@ func BenchmarkMutex4Waiters(b *testing.B) { clus := NewClusterV3(nil, &ClusterConfig{Size: 3}) defer clus.Terminate(nil) for i := 0; i < b.N; i++ { - testMutex(nil, 4, func() *clientv3.Client { return clus.RandClient() }) + testMutexLock(nil, 4, func() *clientv3.Client { return clus.RandClient() }) } }