diff --git a/Documentation/modules.md b/Documentation/modules.md index cc25ea6c2..9507c3235 100644 --- a/Documentation/modules.md +++ b/Documentation/modules.md @@ -15,6 +15,7 @@ Use the `-cors='*'` flag to allow your browser to request information from the c The Lock module implements a fair lock that can be used when lots of clients want access to a single resource. A lock can be associated with a value. The value is unique so if a lock tries to request a value that is already queued for a lock then it will find it and watch until that value obtains the lock. +You may supply a `timeout` which will cancel the lock request if it is not obtained within `timeout` seconds. If `timeout` is not supplied, it is presumed to be infinite. If `timeout` is `0`, the lock request will fail if it is not immediately acquired. If you lock the same value on a key from two separate curl sessions they'll both return at the same time. Here's the API: @@ -31,6 +32,12 @@ curl -X POST http://127.0.0.1:4001/mod/v2/lock/customer1?ttl=60 curl -X POST http://127.0.0.1:4001/mod/v2/lock/customer1?ttl=60 -d value=bar ``` +**Acquire a lock for "customer1" that is associated with the value "bar" only if it is done within 2 seconds** + +```sh +curl -X POST http://127.0.0.1:4001/mod/v2/lock/customer1?ttl=60 -d value=bar -d timeout=2 +``` + **Renew the TTL on the "customer1" lock for index 2** ```sh diff --git a/mod/lock/v2/acquire_handler.go b/mod/lock/v2/acquire_handler.go index 58fc1e8aa..00e2a489d 100644 --- a/mod/lock/v2/acquire_handler.go +++ b/mod/lock/v2/acquire_handler.go @@ -24,7 +24,22 @@ func (h *handler) acquireHandler(w http.ResponseWriter, req *http.Request) error // Setup connection watcher. closeNotifier, _ := w.(http.CloseNotifier) closeChan := closeNotifier.CloseNotify() + + // Wrap closeChan so we can pass it to subsequent components + timeoutChan := make(chan bool) stopChan := make(chan bool) + go func() { + select { + case <-closeChan: + // Client closed connection + stopChan <- true + case <-timeoutChan: + // Timeout expired + stopChan <- true + case <-stopChan: + } + close(stopChan) + }() // Parse the lock "key". vars := mux.Vars(req) @@ -39,7 +54,6 @@ func (h *handler) acquireHandler(w http.ResponseWriter, req *http.Request) error } else if timeout, err = strconv.Atoi(req.FormValue("timeout")); err != nil { return etcdErr.NewError(etcdErr.EcodeTimeoutNaN, "Acquire", 0) } - timeout = timeout + 1 // Parse TTL. ttl, err := strconv.Atoi(req.FormValue("ttl")) @@ -47,35 +61,65 @@ func (h *handler) acquireHandler(w http.ResponseWriter, req *http.Request) error return etcdErr.NewError(etcdErr.EcodeTTLNaN, "Acquire", 0) } - // If node exists then just watch it. Otherwise create the node and watch it. - node, index, pos := h.findExistingNode(keypath, value) - if index > 0 { - if pos == 0 { - // If lock is already acquired then update the TTL. - h.client.Update(node.Key, node.Value, uint64(ttl)) - } else { - // Otherwise watch until it becomes acquired (or errors). - err = h.watch(keypath, index, nil) + // Search for the node + _, index, pos := h.findExistingNode(keypath, value) + if index == 0 { + // Node doesn't exist; Create it + pos = -1 // Invalidate previous position + index, err = h.createNode(keypath, value, ttl) + if err != nil { + return err } - } else { - index, err = h.createNode(keypath, value, ttl, closeChan, stopChan) } - // Stop all goroutines. - close(stopChan) + indexpath := path.Join(keypath, strconv.Itoa(index)) - // Check for an error. + // If pos != 0, we do not already have the lock + if pos != 0 { + if timeout == 0 { + // Attempt to get lock once, no waiting + err = h.get(keypath, index) + } else { + // Keep updating TTL while we wait + go h.ttlKeepAlive(keypath, value, ttl, stopChan) + + // Start timeout + go h.timeoutExpire(timeout, timeoutChan, stopChan) + + // wait for lock + err = h.watch(keypath, index, stopChan) + } + } + + // Return on error, deleting our lock request on the way if err != nil { + if index > 0 { + h.client.Delete(indexpath, false) + } return err } + // Check for connection disconnect before we write the lock index. + select { + case <-stopChan: + err = errors.New("user interrupted") + default: + } + + // Update TTL one last time if lock was acquired. Otherwise delete. + if err == nil { + h.client.Update(indexpath, value, uint64(ttl)) + } else { + h.client.Delete(indexpath, false) + } + // Write response. w.Write([]byte(strconv.Itoa(index))) return nil } // createNode creates a new lock node and watches it until it is acquired or acquisition fails. -func (h *handler) createNode(keypath string, value string, ttl int, closeChan <-chan bool, stopChan chan bool) (int, error) { +func (h *handler) createNode(keypath string, value string, ttl int) (int, error) { // Default the value to "-" if it is blank. if len(value) == 0 { value = "-" @@ -87,30 +131,7 @@ func (h *handler) createNode(keypath string, value string, ttl int, closeChan <- return 0, err } indexpath := resp.Node.Key - index, _ := strconv.Atoi(path.Base(indexpath)) - - // Keep updating TTL to make sure lock request is not expired before acquisition. - go h.ttlKeepAlive(indexpath, value, ttl, stopChan) - - // Watch until we acquire or fail. - err = h.watch(keypath, index, closeChan) - - // Check for connection disconnect before we write the lock index. - if err != nil { - select { - case <-closeChan: - err = errors.New("user interrupted") - default: - } - } - - // Update TTL one last time if acquired. Otherwise delete. - if err == nil { - h.client.Update(indexpath, value, uint64(ttl)) - } else { - h.client.Delete(indexpath, false) - } - + index, err := strconv.Atoi(path.Base(indexpath)) return index, err } @@ -141,6 +162,47 @@ func (h *handler) ttlKeepAlive(k string, value string, ttl int, stopChan chan bo } } +// timeoutExpire sets the countdown timer is a positive integer +// cancels on stopChan, sends true on timeoutChan after timer expires +func (h *handler) timeoutExpire(timeout int, timeoutChan chan bool, stopChan chan bool) { + // Set expiration timer if timeout is 1 or higher + if timeout < 1 { + timeoutChan = nil + return + } + select { + case <-stopChan: + return + case <-time.After(time.Duration(timeout) * time.Second): + timeoutChan <- true + return + } +} + +func (h *handler) getLockIndex(keypath string, index int) (int, int, error) { + // Read all nodes for the lock. + resp, err := h.client.Get(keypath, true, true) + if err != nil { + return 0, 0, fmt.Errorf("lock watch lookup error: %s", err.Error()) + } + nodes := lockNodes{resp.Node.Nodes} + prevIndex, modifiedIndex := nodes.PrevIndex(index) + return prevIndex, modifiedIndex, nil +} + +// get tries once to get the lock; no waiting +func (h *handler) get(keypath string, index int) error { + prevIndex, _, err := h.getLockIndex(keypath, index) + if err != nil { + return err + } + if prevIndex == 0 { + // Lock acquired + return nil + } + return fmt.Errorf("failed to acquire lock") +} + // watch continuously waits for a given lock index to be acquired or until lock fails. // Returns a boolean indicating success. func (h *handler) watch(keypath string, index int, closeChan <-chan bool) error { @@ -151,22 +213,15 @@ func (h *handler) watch(keypath string, index int, closeChan <-chan bool) error select { case <-closeChan: stopWatchChan <- true - case <- stopWrapChan: + case <-stopWrapChan: stopWatchChan <- true - case <- stopWatchChan: + case <-stopWatchChan: } }() defer close(stopWrapChan) for { - // Read all nodes for the lock. - resp, err := h.client.Get(keypath, true, true) - if err != nil { - return fmt.Errorf("lock watch lookup error: %s", err.Error()) - } - nodes := lockNodes{resp.Node.Nodes} - prevIndex, modifiedIndex := nodes.PrevIndex(index) - + prevIndex, modifiedIndex, err := h.getLockIndex(keypath, index) // If there is no previous index then we have the lock. if prevIndex == 0 { return nil @@ -175,11 +230,12 @@ func (h *handler) watch(keypath string, index int, closeChan <-chan bool) error // Wait from the last modification of the node. waitIndex := modifiedIndex + 1 - resp, err = h.client.Watch(path.Join(keypath, strconv.Itoa(prevIndex)), uint64(waitIndex), false, nil, stopWatchChan) + _, err = h.client.Watch(path.Join(keypath, strconv.Itoa(prevIndex)), uint64(waitIndex), false, nil, stopWatchChan) if err == etcd.ErrWatchStoppedByUser { return fmt.Errorf("lock watch closed") } else if err != nil { return fmt.Errorf("lock watch error: %s", err.Error()) } + return nil } } diff --git a/mod/lock/v2/tests/mod_lock_test.go b/mod/lock/v2/tests/mod_lock_test.go index 083bb6c72..e33842052 100644 --- a/mod/lock/v2/tests/mod_lock_test.go +++ b/mod/lock/v2/tests/mod_lock_test.go @@ -215,12 +215,151 @@ func TestModLockAcquireAndReleaseByValue(t *testing.T) { }) } +// Ensure that a lock honours the timeout option +func TestModLockAcquireTimeout(t *testing.T) { + tests.RunServer(func(s *server.Server) { + c := make(chan bool) + + // Acquire lock #1. + go func() { + body, status, err := testAcquireLock(s, "foo", "first", 10) + assert.NoError(t, err) + assert.Equal(t, status, 200) + assert.Equal(t, body, "2") + c <- true + }() + <-c + + // Attempt to acquire lock #2, timing out after 1s. + waiting := true + go func() { + c <- true + _, status, err := testAcquireLockWithTimeout(s, "foo", "second", 10, 1) + assert.NoError(t, err) + assert.Equal(t, status, 500) + waiting = false + }() + <-c + + time.Sleep(5 * time.Second) + + // Check that we have the lock #1. + body, status, err := testGetLockIndex(s, "foo") + assert.NoError(t, err) + assert.Equal(t, status, 200) + assert.Equal(t, body, "2") + + // Check that we are not still waiting for lock #2. + assert.Equal(t, waiting, false) + + // Release lock #1. + _, status, err = testReleaseLock(s, "foo", "2", "") + assert.NoError(t, err) + assert.Equal(t, status, 200) + + // Check that we have no lock. + body, status, err = testGetLockIndex(s, "foo") + assert.NoError(t, err) + assert.Equal(t, status, 200) + }) +} + +// Ensure that a lock succeeds when timeout=0 (nowait) +func TestModLockAcquireNoWait(t *testing.T) { + tests.RunServer(func(s *server.Server) { + c := make(chan bool) + + // Acquire lock with no waiting. + go func() { + body, status, err := testAcquireLockWithTimeout(s, "foo", "first", 10, 0) + assert.NoError(t, err) + assert.Equal(t, status, 200) + assert.Equal(t, body, "2") + c <- true + }() + <-c + + time.Sleep(1 * time.Second) + + // Check that we have the lock #1. + body, status, err := testGetLockIndex(s, "foo") + assert.NoError(t, err) + assert.Equal(t, status, 200) + assert.Equal(t, body, "2") + + // Release lock #1. + _, status, err = testReleaseLock(s, "foo", "2", "") + assert.NoError(t, err) + assert.Equal(t, status, 200) + + // Check that we have no lock. + body, status, err = testGetLockIndex(s, "foo") + assert.NoError(t, err) + assert.Equal(t, status, 200) + }) +} + +// Ensure that a lock honours the timeout=0 (nowait) option when lock is already held +func TestModLockAcquireNoWaitWhileLocked(t *testing.T) { + tests.RunServer(func(s *server.Server) { + c := make(chan bool) + + // Acquire lock #1. + go func() { + body, status, err := testAcquireLock(s, "foo", "first", 10) + assert.NoError(t, err) + assert.Equal(t, status, 200) + assert.Equal(t, body, "2") + c <- true + }() + <-c + + // Attempt to acquire lock #2; fail if no lock immediately acquired + waiting := true + go func() { + c <- true + _, status, err := testAcquireLockWithTimeout(s, "foo", "second", 10, 0) + assert.NoError(t, err) + assert.Equal(t, status, 500) + waiting = false + }() + <-c + + time.Sleep(1 * time.Second) + + // Check that we have the lock #1. + body, status, err := testGetLockIndex(s, "foo") + assert.NoError(t, err) + assert.Equal(t, status, 200) + assert.Equal(t, body, "2") + + // Check that we are not still waiting for lock #2. + assert.Equal(t, waiting, false) + + // Release lock #1. + _, status, err = testReleaseLock(s, "foo", "2", "") + assert.NoError(t, err) + assert.Equal(t, status, 200) + + // Check that we have no lock. + body, status, err = testGetLockIndex(s, "foo") + assert.NoError(t, err) + assert.Equal(t, status, 200) + }) +} + func testAcquireLock(s *server.Server, key string, value string, ttl int) (string, int, error) { resp, err := tests.PostForm(fmt.Sprintf("%s/mod/v2/lock/%s?value=%s&ttl=%d", s.URL(), key, value, ttl), nil) ret := tests.ReadBody(resp) return string(ret), resp.StatusCode, err } +func testAcquireLockWithTimeout(s *server.Server, key string, value string, ttl int, timeout int) (string, int, error) { + resp, err := tests.PostForm(fmt.Sprintf("%s/mod/v2/lock/%s?value=%s&ttl=%d&timeout=%d", s.URL(), key, value, ttl, timeout), nil) + ret := tests.ReadBody(resp) + return string(ret), resp.StatusCode, err +} + func testGetLockIndex(s *server.Server, key string) (string, int, error) { resp, err := tests.Get(fmt.Sprintf("%s/mod/v2/lock/%s?field=index", s.URL(), key)) ret := tests.ReadBody(resp)