diff --git a/clientv3/concurrency/mutex.go b/clientv3/concurrency/mutex.go index 8ea8acd52..c38b1568f 100644 --- a/clientv3/concurrency/mutex.go +++ b/clientv3/concurrency/mutex.go @@ -15,6 +15,7 @@ package concurrency import ( + "fmt" "sync" v3 "github.com/coreos/etcd/clientv3" @@ -37,12 +38,26 @@ func NewMutex(client *v3.Client, pfx string) *Mutex { // Lock locks the mutex with a cancellable context. If the context is cancelled // while trying to acquire the lock, the mutex tries to clean its stale lock entry. func (m *Mutex) Lock(ctx context.Context) error { - s, err := NewSession(m.client) + s, serr := NewSession(m.client) + if serr != nil { + return serr + } + + m.myKey = fmt.Sprintf("%s/%x", m.pfx, s.Lease()) + cmp := v3.Compare(v3.CreateRevision(m.myKey), "=", 0) + // put self in lock waiters via myKey; oldest waiter holds lock + put := v3.OpPut(m.myKey, "", v3.WithLease(s.Lease())) + // reuse key in case this session already holds the lock + get := v3.OpGet(m.myKey) + resp, err := m.client.Txn(ctx).If(cmp).Then(put).Else(get).Commit() if err != nil { return err } - // put self in lock waiters via myKey; oldest waiter holds lock - m.myKey, m.myRev, err = NewUniqueKey(ctx, m.client, m.pfx, v3.WithLease(s.Lease())) + m.myRev = resp.Header.Revision + if !resp.Succeeded { + m.myRev = resp.Responses[0].GetResponseRange().Kvs[0].CreateRevision + } + // wait for deletion revisions prior to myKey err = waitDeletes(ctx, m.client, m.pfx, v3.WithPrefix(), v3.WithRev(m.myRev-1)) // release lock key if cancelled diff --git a/integration/v3_lock_test.go b/integration/v3_lock_test.go index 7607a71c4..b51b444ae 100644 --- a/integration/v3_lock_test.go +++ b/integration/v3_lock_test.go @@ -16,6 +16,7 @@ package integration import ( "math/rand" + "sync" "testing" "time" @@ -28,18 +29,24 @@ import ( func TestMutexSingleNode(t *testing.T) { clus := NewClusterV3(t, &ClusterConfig{Size: 3}) defer clus.Terminate(t) - testMutex(t, 5, func() *clientv3.Client { return clus.clients[0] }) + + var clients []*clientv3.Client + testMutex(t, 5, makeSingleNodeClients(t, clus.cluster, &clients)) + closeClients(t, clients) } func TestMutexMultiNode(t *testing.T) { clus := NewClusterV3(t, &ClusterConfig{Size: 3}) defer clus.Terminate(t) - testMutex(t, 5, func() *clientv3.Client { return clus.RandClient() }) + + var clients []*clientv3.Client + testMutex(t, 5, makeMultiNodeClients(t, clus.cluster, &clients)) + closeClients(t, clients) } func testMutex(t *testing.T, waiters int, chooseClient func() *clientv3.Client) { // stream lock acquisitions - lockedC := make(chan *concurrency.Mutex, 1) + lockedC := make(chan *concurrency.Mutex) for i := 0; i < waiters; i++ { go func() { m := concurrency.NewMutex(chooseClient(), "test-mutex") @@ -69,6 +76,22 @@ func testMutex(t *testing.T, waiters int, chooseClient func() *clientv3.Client) } } +// TestMutexSessionRelock ensures that acquiring the same lock with the same +// session will not result in deadlock. +func TestMutexSessionRelock(t *testing.T) { + clus := NewClusterV3(t, &ClusterConfig{Size: 3}) + defer clus.Terminate(t) + cli := clus.RandClient() + m := concurrency.NewMutex(cli, "test-mutex") + if err := m.Lock(context.TODO()); err != nil { + t.Fatal(err) + } + m2 := concurrency.NewMutex(cli, "test-mutex") + if err := m2.Lock(context.TODO()); err != nil { + t.Fatal(err) + } +} + func BenchmarkMutex4Waiters(b *testing.B) { // XXX switch tests to use TB interface clus := NewClusterV3(nil, &ClusterConfig{Size: 3}) @@ -137,3 +160,38 @@ func testRWMutex(t *testing.T, waiters int, chooseClient func() *clientv3.Client } } } + +func makeClients(t *testing.T, clients *[]*clientv3.Client, choose func() *member) func() *clientv3.Client { + var mu sync.Mutex + *clients = nil + return func() *clientv3.Client { + cli, err := NewClientV3(choose()) + if err != nil { + t.Fatal(err) + } + mu.Lock() + *clients = append(*clients, cli) + mu.Unlock() + return cli + } +} + +func makeSingleNodeClients(t *testing.T, clus *cluster, clients *[]*clientv3.Client) func() *clientv3.Client { + return makeClients(t, clients, func() *member { + return clus.Members[0] + }) +} + +func makeMultiNodeClients(t *testing.T, clus *cluster, clients *[]*clientv3.Client) func() *clientv3.Client { + return makeClients(t, clients, func() *member { + return clus.Members[rand.Intn(len(clus.Members))] + }) +} + +func closeClients(t *testing.T, clients []*clientv3.Client) { + for _, cli := range clients { + if err := cli.Close(); err != nil { + t.Fatal(err) + } + } +}