From 85301dbc4e65f7057dec4c37abdde805887e9d64 Mon Sep 17 00:00:00 2001 From: Ulexus Date: Mon, 24 Mar 2014 13:08:03 -0400 Subject: [PATCH] Add mod/lock timeout. Added timeout goroutine to stop waiting on lock after timeout expiration. This necessitated reworking the flow of the acquire handler. createNode now _only_ creates the node; it no longer waits on the lock itself. getLockIndex (perhaps this is poorly named) extracts out the lock checking routine so that it can be used by "get" and "watch", both. get() was added to instantaneously attempt to acquire a lock with no waiting. If a lock fails to acquire, for whatever reason, an error is returned, resulting in a code 500 to the client. --- Documentation/modules.md | 7 ++ mod/lock/v2/acquire_handler.go | 158 +++++++++++++++++++---------- mod/lock/v2/tests/mod_lock_test.go | 139 +++++++++++++++++++++++++ 3 files changed, 253 insertions(+), 51 deletions(-) diff --git a/Documentation/modules.md b/Documentation/modules.md index cc25ea6c2..9507c3235 100644 --- a/Documentation/modules.md +++ b/Documentation/modules.md @@ -15,6 +15,7 @@ Use the `-cors='*'` flag to allow your browser to request information from the c The Lock module implements a fair lock that can be used when lots of clients want access to a single resource. A lock can be associated with a value. The value is unique so if a lock tries to request a value that is already queued for a lock then it will find it and watch until that value obtains the lock. +You may supply a `timeout` which will cancel the lock request if it is not obtained within `timeout` seconds. If `timeout` is not supplied, it is presumed to be infinite. If `timeout` is `0`, the lock request will fail if it is not immediately acquired. If you lock the same value on a key from two separate curl sessions they'll both return at the same time. Here's the API: @@ -31,6 +32,12 @@ curl -X POST http://127.0.0.1:4001/mod/v2/lock/customer1?ttl=60 curl -X POST http://127.0.0.1:4001/mod/v2/lock/customer1?ttl=60 -d value=bar ``` +**Acquire a lock for "customer1" that is associated with the value "bar" only if it is done within 2 seconds** + +```sh +curl -X POST http://127.0.0.1:4001/mod/v2/lock/customer1?ttl=60 -d value=bar -d timeout=2 +``` + **Renew the TTL on the "customer1" lock for index 2** ```sh diff --git a/mod/lock/v2/acquire_handler.go b/mod/lock/v2/acquire_handler.go index 58fc1e8aa..00e2a489d 100644 --- a/mod/lock/v2/acquire_handler.go +++ b/mod/lock/v2/acquire_handler.go @@ -24,7 +24,22 @@ func (h *handler) acquireHandler(w http.ResponseWriter, req *http.Request) error // Setup connection watcher. closeNotifier, _ := w.(http.CloseNotifier) closeChan := closeNotifier.CloseNotify() + + // Wrap closeChan so we can pass it to subsequent components + timeoutChan := make(chan bool) stopChan := make(chan bool) + go func() { + select { + case <-closeChan: + // Client closed connection + stopChan <- true + case <-timeoutChan: + // Timeout expired + stopChan <- true + case <-stopChan: + } + close(stopChan) + }() // Parse the lock "key". vars := mux.Vars(req) @@ -39,7 +54,6 @@ func (h *handler) acquireHandler(w http.ResponseWriter, req *http.Request) error } else if timeout, err = strconv.Atoi(req.FormValue("timeout")); err != nil { return etcdErr.NewError(etcdErr.EcodeTimeoutNaN, "Acquire", 0) } - timeout = timeout + 1 // Parse TTL. ttl, err := strconv.Atoi(req.FormValue("ttl")) @@ -47,35 +61,65 @@ func (h *handler) acquireHandler(w http.ResponseWriter, req *http.Request) error return etcdErr.NewError(etcdErr.EcodeTTLNaN, "Acquire", 0) } - // If node exists then just watch it. Otherwise create the node and watch it. - node, index, pos := h.findExistingNode(keypath, value) - if index > 0 { - if pos == 0 { - // If lock is already acquired then update the TTL. - h.client.Update(node.Key, node.Value, uint64(ttl)) - } else { - // Otherwise watch until it becomes acquired (or errors). - err = h.watch(keypath, index, nil) + // Search for the node + _, index, pos := h.findExistingNode(keypath, value) + if index == 0 { + // Node doesn't exist; Create it + pos = -1 // Invalidate previous position + index, err = h.createNode(keypath, value, ttl) + if err != nil { + return err } - } else { - index, err = h.createNode(keypath, value, ttl, closeChan, stopChan) } - // Stop all goroutines. - close(stopChan) + indexpath := path.Join(keypath, strconv.Itoa(index)) - // Check for an error. + // If pos != 0, we do not already have the lock + if pos != 0 { + if timeout == 0 { + // Attempt to get lock once, no waiting + err = h.get(keypath, index) + } else { + // Keep updating TTL while we wait + go h.ttlKeepAlive(keypath, value, ttl, stopChan) + + // Start timeout + go h.timeoutExpire(timeout, timeoutChan, stopChan) + + // wait for lock + err = h.watch(keypath, index, stopChan) + } + } + + // Return on error, deleting our lock request on the way if err != nil { + if index > 0 { + h.client.Delete(indexpath, false) + } return err } + // Check for connection disconnect before we write the lock index. + select { + case <-stopChan: + err = errors.New("user interrupted") + default: + } + + // Update TTL one last time if lock was acquired. Otherwise delete. + if err == nil { + h.client.Update(indexpath, value, uint64(ttl)) + } else { + h.client.Delete(indexpath, false) + } + // Write response. w.Write([]byte(strconv.Itoa(index))) return nil } // createNode creates a new lock node and watches it until it is acquired or acquisition fails. -func (h *handler) createNode(keypath string, value string, ttl int, closeChan <-chan bool, stopChan chan bool) (int, error) { +func (h *handler) createNode(keypath string, value string, ttl int) (int, error) { // Default the value to "-" if it is blank. if len(value) == 0 { value = "-" @@ -87,30 +131,7 @@ func (h *handler) createNode(keypath string, value string, ttl int, closeChan <- return 0, err } indexpath := resp.Node.Key - index, _ := strconv.Atoi(path.Base(indexpath)) - - // Keep updating TTL to make sure lock request is not expired before acquisition. - go h.ttlKeepAlive(indexpath, value, ttl, stopChan) - - // Watch until we acquire or fail. - err = h.watch(keypath, index, closeChan) - - // Check for connection disconnect before we write the lock index. - if err != nil { - select { - case <-closeChan: - err = errors.New("user interrupted") - default: - } - } - - // Update TTL one last time if acquired. Otherwise delete. - if err == nil { - h.client.Update(indexpath, value, uint64(ttl)) - } else { - h.client.Delete(indexpath, false) - } - + index, err := strconv.Atoi(path.Base(indexpath)) return index, err } @@ -141,6 +162,47 @@ func (h *handler) ttlKeepAlive(k string, value string, ttl int, stopChan chan bo } } +// timeoutExpire sets the countdown timer is a positive integer +// cancels on stopChan, sends true on timeoutChan after timer expires +func (h *handler) timeoutExpire(timeout int, timeoutChan chan bool, stopChan chan bool) { + // Set expiration timer if timeout is 1 or higher + if timeout < 1 { + timeoutChan = nil + return + } + select { + case <-stopChan: + return + case <-time.After(time.Duration(timeout) * time.Second): + timeoutChan <- true + return + } +} + +func (h *handler) getLockIndex(keypath string, index int) (int, int, error) { + // Read all nodes for the lock. + resp, err := h.client.Get(keypath, true, true) + if err != nil { + return 0, 0, fmt.Errorf("lock watch lookup error: %s", err.Error()) + } + nodes := lockNodes{resp.Node.Nodes} + prevIndex, modifiedIndex := nodes.PrevIndex(index) + return prevIndex, modifiedIndex, nil +} + +// get tries once to get the lock; no waiting +func (h *handler) get(keypath string, index int) error { + prevIndex, _, err := h.getLockIndex(keypath, index) + if err != nil { + return err + } + if prevIndex == 0 { + // Lock acquired + return nil + } + return fmt.Errorf("failed to acquire lock") +} + // watch continuously waits for a given lock index to be acquired or until lock fails. // Returns a boolean indicating success. func (h *handler) watch(keypath string, index int, closeChan <-chan bool) error { @@ -151,22 +213,15 @@ func (h *handler) watch(keypath string, index int, closeChan <-chan bool) error select { case <-closeChan: stopWatchChan <- true - case <- stopWrapChan: + case <-stopWrapChan: stopWatchChan <- true - case <- stopWatchChan: + case <-stopWatchChan: } }() defer close(stopWrapChan) for { - // Read all nodes for the lock. - resp, err := h.client.Get(keypath, true, true) - if err != nil { - return fmt.Errorf("lock watch lookup error: %s", err.Error()) - } - nodes := lockNodes{resp.Node.Nodes} - prevIndex, modifiedIndex := nodes.PrevIndex(index) - + prevIndex, modifiedIndex, err := h.getLockIndex(keypath, index) // If there is no previous index then we have the lock. if prevIndex == 0 { return nil @@ -175,11 +230,12 @@ func (h *handler) watch(keypath string, index int, closeChan <-chan bool) error // Wait from the last modification of the node. waitIndex := modifiedIndex + 1 - resp, err = h.client.Watch(path.Join(keypath, strconv.Itoa(prevIndex)), uint64(waitIndex), false, nil, stopWatchChan) + _, err = h.client.Watch(path.Join(keypath, strconv.Itoa(prevIndex)), uint64(waitIndex), false, nil, stopWatchChan) if err == etcd.ErrWatchStoppedByUser { return fmt.Errorf("lock watch closed") } else if err != nil { return fmt.Errorf("lock watch error: %s", err.Error()) } + return nil } } diff --git a/mod/lock/v2/tests/mod_lock_test.go b/mod/lock/v2/tests/mod_lock_test.go index 083bb6c72..e33842052 100644 --- a/mod/lock/v2/tests/mod_lock_test.go +++ b/mod/lock/v2/tests/mod_lock_test.go @@ -215,12 +215,151 @@ func TestModLockAcquireAndReleaseByValue(t *testing.T) { }) } +// Ensure that a lock honours the timeout option +func TestModLockAcquireTimeout(t *testing.T) { + tests.RunServer(func(s *server.Server) { + c := make(chan bool) + + // Acquire lock #1. + go func() { + body, status, err := testAcquireLock(s, "foo", "first", 10) + assert.NoError(t, err) + assert.Equal(t, status, 200) + assert.Equal(t, body, "2") + c <- true + }() + <-c + + // Attempt to acquire lock #2, timing out after 1s. + waiting := true + go func() { + c <- true + _, status, err := testAcquireLockWithTimeout(s, "foo", "second", 10, 1) + assert.NoError(t, err) + assert.Equal(t, status, 500) + waiting = false + }() + <-c + + time.Sleep(5 * time.Second) + + // Check that we have the lock #1. + body, status, err := testGetLockIndex(s, "foo") + assert.NoError(t, err) + assert.Equal(t, status, 200) + assert.Equal(t, body, "2") + + // Check that we are not still waiting for lock #2. + assert.Equal(t, waiting, false) + + // Release lock #1. + _, status, err = testReleaseLock(s, "foo", "2", "") + assert.NoError(t, err) + assert.Equal(t, status, 200) + + // Check that we have no lock. + body, status, err = testGetLockIndex(s, "foo") + assert.NoError(t, err) + assert.Equal(t, status, 200) + }) +} + +// Ensure that a lock succeeds when timeout=0 (nowait) +func TestModLockAcquireNoWait(t *testing.T) { + tests.RunServer(func(s *server.Server) { + c := make(chan bool) + + // Acquire lock with no waiting. + go func() { + body, status, err := testAcquireLockWithTimeout(s, "foo", "first", 10, 0) + assert.NoError(t, err) + assert.Equal(t, status, 200) + assert.Equal(t, body, "2") + c <- true + }() + <-c + + time.Sleep(1 * time.Second) + + // Check that we have the lock #1. + body, status, err := testGetLockIndex(s, "foo") + assert.NoError(t, err) + assert.Equal(t, status, 200) + assert.Equal(t, body, "2") + + // Release lock #1. + _, status, err = testReleaseLock(s, "foo", "2", "") + assert.NoError(t, err) + assert.Equal(t, status, 200) + + // Check that we have no lock. + body, status, err = testGetLockIndex(s, "foo") + assert.NoError(t, err) + assert.Equal(t, status, 200) + }) +} + +// Ensure that a lock honours the timeout=0 (nowait) option when lock is already held +func TestModLockAcquireNoWaitWhileLocked(t *testing.T) { + tests.RunServer(func(s *server.Server) { + c := make(chan bool) + + // Acquire lock #1. + go func() { + body, status, err := testAcquireLock(s, "foo", "first", 10) + assert.NoError(t, err) + assert.Equal(t, status, 200) + assert.Equal(t, body, "2") + c <- true + }() + <-c + + // Attempt to acquire lock #2; fail if no lock immediately acquired + waiting := true + go func() { + c <- true + _, status, err := testAcquireLockWithTimeout(s, "foo", "second", 10, 0) + assert.NoError(t, err) + assert.Equal(t, status, 500) + waiting = false + }() + <-c + + time.Sleep(1 * time.Second) + + // Check that we have the lock #1. + body, status, err := testGetLockIndex(s, "foo") + assert.NoError(t, err) + assert.Equal(t, status, 200) + assert.Equal(t, body, "2") + + // Check that we are not still waiting for lock #2. + assert.Equal(t, waiting, false) + + // Release lock #1. + _, status, err = testReleaseLock(s, "foo", "2", "") + assert.NoError(t, err) + assert.Equal(t, status, 200) + + // Check that we have no lock. + body, status, err = testGetLockIndex(s, "foo") + assert.NoError(t, err) + assert.Equal(t, status, 200) + }) +} + func testAcquireLock(s *server.Server, key string, value string, ttl int) (string, int, error) { resp, err := tests.PostForm(fmt.Sprintf("%s/mod/v2/lock/%s?value=%s&ttl=%d", s.URL(), key, value, ttl), nil) ret := tests.ReadBody(resp) return string(ret), resp.StatusCode, err } +func testAcquireLockWithTimeout(s *server.Server, key string, value string, ttl int, timeout int) (string, int, error) { + resp, err := tests.PostForm(fmt.Sprintf("%s/mod/v2/lock/%s?value=%s&ttl=%d&timeout=%d", s.URL(), key, value, ttl, timeout), nil) + ret := tests.ReadBody(resp) + return string(ret), resp.StatusCode, err +} + func testGetLockIndex(s *server.Server, key string) (string, int, error) { resp, err := tests.Get(fmt.Sprintf("%s/mod/v2/lock/%s?field=index", s.URL(), key)) ret := tests.ReadBody(resp)