diff --git a/mod/lock/v2/acquire_handler.go b/mod/lock/v2/acquire_handler.go index d6fa2aacb..6da62f6cc 100644 --- a/mod/lock/v2/acquire_handler.go +++ b/mod/lock/v2/acquire_handler.go @@ -1,6 +1,8 @@ package v2 import ( + "errors" + "fmt" "net/http" "path" "strconv" @@ -12,6 +14,7 @@ import ( // acquireHandler attempts to acquire a lock on the given key. // The "key" parameter specifies the resource to lock. +// The "value" parameter specifies a value to associate with the lock. // The "ttl" parameter specifies how long the lock will persist for. // The "timeout" parameter specifies how long the request should wait for the lock. func (h *handler) acquireHandler(w http.ResponseWriter, req *http.Request) { @@ -20,109 +23,152 @@ func (h *handler) acquireHandler(w http.ResponseWriter, req *http.Request) { // Setup connection watcher. closeNotifier, _ := w.(http.CloseNotifier) closeChan := closeNotifier.CloseNotify() + stopChan := make(chan bool) - // Parse "key" and "ttl" query parameters. + // Parse the lock "key". vars := mux.Vars(req) keypath := path.Join(prefix, vars["key"]) - ttl, err := strconv.Atoi(req.FormValue("ttl")) - if err != nil { - http.Error(w, "invalid ttl: " + err.Error(), http.StatusInternalServerError) - return - } - + value := req.FormValue("value") + // Parse "timeout" parameter. var timeout int - if len(req.FormValue("timeout")) == 0 { + var err error + if req.FormValue("timeout") == "" { timeout = -1 } else if timeout, err = strconv.Atoi(req.FormValue("timeout")); err != nil { - http.Error(w, "invalid timeout: " + err.Error(), http.StatusInternalServerError) + http.Error(w, "invalid timeout: " + req.FormValue("timeout"), http.StatusInternalServerError) return } timeout = timeout + 1 - // Create an incrementing id for the lock. - resp, err := h.client.AddChild(keypath, "-", uint64(ttl)) + // Parse TTL. + ttl, err := strconv.Atoi(req.FormValue("ttl")) if err != nil { - http.Error(w, "add lock index error: " + err.Error(), http.StatusInternalServerError) + http.Error(w, "invalid ttl: " + req.FormValue("ttl"), http.StatusInternalServerError) return } - indexpath := resp.Node.Key - // Keep updating TTL to make sure lock request is not expired before acquisition. - stop := make(chan bool) - go h.ttlKeepAlive(indexpath, ttl, stop) - - // Monitor for broken connection. - stopWatchChan := make(chan bool) - go func() { - select { - case <-closeChan: - stopWatchChan <- true - case <-stop: - // Stop watching for connection disconnect. - } - }() - - // Extract the lock index. - index, _ := strconv.Atoi(path.Base(resp.Node.Key)) - - // Wait until we successfully get a lock or we get a failure. - var success bool - for { - // Read all indices. - resp, err = h.client.Get(keypath, true, true) - if err != nil { - http.Error(w, "lock children lookup error: " + err.Error(), http.StatusInternalServerError) - break - } - indices := extractResponseIndices(resp) - waitIndex := resp.Node.ModifiedIndex - prevIndex := findPrevIndex(indices, index) - - // If there is no previous index then we have the lock. - if prevIndex == 0 { - success = true - break - } - - // Otherwise watch previous index until it's gone. - _, err = h.client.Watch(path.Join(keypath, strconv.Itoa(prevIndex)), waitIndex, false, nil, stopWatchChan) - if err == etcd.ErrWatchStoppedByUser { - break - } else if err != nil { - http.Error(w, "lock watch error: " + err.Error(), http.StatusInternalServerError) - break - } - } - - // Check for connection disconnect before we write the lock index. - select { - case <-stopWatchChan: - success = false - default: - } - - // Stop the ttl keep-alive. - close(stop) - - if success { - // Write lock index to response body if we acquire the lock. - h.client.Update(indexpath, "-", uint64(ttl)) - w.Write([]byte(strconv.Itoa(index))) + // If node exists then just watch it. Otherwise create the node and watch it. + index := h.findExistingNode(keypath, value) + if index > 0 { + err = h.watch(keypath, index, nil) } else { - // Make sure key is deleted if we couldn't acquire. - h.client.Delete(indexpath, false) + index, err = h.createNode(keypath, value, ttl, closeChan, stopChan) + } + + // Stop all goroutines. + close(stopChan) + + // Write response. + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } else { + w.Write([]byte(strconv.Itoa(index))) } } +// createNode creates a new lock node and watches it until it is acquired or acquisition fails. +func (h *handler) createNode(keypath string, value string, ttl int, closeChan <- chan bool, stopChan chan bool) (int, error) { + // Default the value to "-" if it is blank. + if len(value) == 0 { + value = "-" + } + + // Create an incrementing id for the lock. + resp, err := h.client.AddChild(keypath, value, uint64(ttl)) + if err != nil { + return 0, errors.New("acquire lock index error: " + err.Error()) + } + indexpath := resp.Node.Key + index, _ := strconv.Atoi(path.Base(indexpath)) + + // Keep updating TTL to make sure lock request is not expired before acquisition. + go h.ttlKeepAlive(indexpath, value, ttl, stopChan) + + // Watch until we acquire or fail. + err = h.watch(keypath, index, closeChan) + + // Check for connection disconnect before we write the lock index. + if err != nil { + select { + case <-closeChan: + err = errors.New("acquire lock error: user interrupted") + default: + } + } + + // Update TTL one last time if acquired. Otherwise delete. + if err == nil { + h.client.Update(indexpath, value, uint64(ttl)) + } else { + h.client.Delete(indexpath, false) + } + + return index, err +} + +// findExistingNode search for a node on the lock with the given value. +func (h *handler) findExistingNode(keypath string, value string) int { + if len(value) > 0 { + resp, err := h.client.Get(keypath, true, true) + if err == nil { + nodes := lockNodes{resp.Node.Nodes} + if node := nodes.FindByValue(value); node != nil { + index, _ := strconv.Atoi(path.Base(node.Key)) + return index + } + } + } + return 0 +} + // ttlKeepAlive continues to update a key's TTL until the stop channel is closed. -func (h *handler) ttlKeepAlive(k string, ttl int, stop chan bool) { +func (h *handler) ttlKeepAlive(k string, value string, ttl int, stopChan chan bool) { for { select { case <-time.After(time.Duration(ttl / 2) * time.Second): - h.client.Update(k, "-", uint64(ttl)) - case <-stop: + h.client.Update(k, value, uint64(ttl)) + case <-stopChan: return } } } + +// watch continuously waits for a given lock index to be acquired or until lock fails. +// Returns a boolean indicating success. +func (h *handler) watch(keypath string, index int, closeChan <- chan bool) error { + // Wrap close chan so we can pass it to Client.Watch(). + stopWatchChan := make(chan bool) + go func() { + select { + case <- closeChan: + stopWatchChan <- true + case <- stopWatchChan: + } + }() + defer close(stopWatchChan) + + for { + // Read all nodes for the lock. + resp, err := h.client.Get(keypath, true, true) + if err != nil { + return fmt.Errorf("lock watch lookup error: %s", err.Error()) + } + waitIndex := resp.Node.ModifiedIndex + nodes := lockNodes{resp.Node.Nodes} + prevIndex := nodes.PrevIndex(index) + + // If there is no previous index then we have the lock. + if prevIndex == 0 { + return nil + } + + // Watch previous index until it's gone. + _, err = h.client.Watch(path.Join(keypath, strconv.Itoa(prevIndex)), waitIndex, false, nil, stopWatchChan) + if err == etcd.ErrWatchStoppedByUser { + return fmt.Errorf("lock watch closed") + } else if err != nil { + return fmt.Errorf("lock watch error:%s", err.Error()) + } + } +} diff --git a/mod/lock/v2/get_index_handler.go b/mod/lock/v2/get_index_handler.go index 73ea663ff..3473defe4 100644 --- a/mod/lock/v2/get_index_handler.go +++ b/mod/lock/v2/get_index_handler.go @@ -3,28 +3,41 @@ package v2 import ( "net/http" "path" - "strconv" "github.com/gorilla/mux" ) // getIndexHandler retrieves the current lock index. +// The "field" parameter specifies to read either the lock "index" or lock "value". func (h *handler) getIndexHandler(w http.ResponseWriter, req *http.Request) { h.client.SyncCluster() vars := mux.Vars(req) keypath := path.Join(prefix, vars["key"]) + field := req.FormValue("field") + if len(field) == 0 { + field = "value" + } // Read all indices. resp, err := h.client.Get(keypath, true, true) if err != nil { - http.Error(w, "lock children lookup error: " + err.Error(), http.StatusInternalServerError) + http.Error(w, "read lock error: " + err.Error(), http.StatusInternalServerError) return } + nodes := lockNodes{resp.Node.Nodes} - // Write out the index of the last one to the response body. - indices := extractResponseIndices(resp) - if len(indices) > 0 { - w.Write([]byte(strconv.Itoa(indices[0]))) + // Write out the requested field. + if node := nodes.First(); node != nil { + switch field { + case "index": + w.Write([]byte(path.Base(node.Key))) + + case "value": + w.Write([]byte(node.Value)) + + default: + http.Error(w, "read lock error: invalid field: " + field, http.StatusInternalServerError) + } } } diff --git a/mod/lock/v2/handler.go b/mod/lock/v2/handler.go index 33d25242d..3a84e1b68 100644 --- a/mod/lock/v2/handler.go +++ b/mod/lock/v2/handler.go @@ -2,9 +2,6 @@ package v2 import ( "net/http" - "path" - "strconv" - "sort" "github.com/gorilla/mux" "github.com/coreos/go-etcd/etcd" @@ -27,32 +24,7 @@ func NewHandler(addr string) (http.Handler) { h.StrictSlash(false) h.HandleFunc("/{key:.*}", h.getIndexHandler).Methods("GET") h.HandleFunc("/{key:.*}", h.acquireHandler).Methods("POST") - h.HandleFunc("/{key_with_index:.*}", h.renewLockHandler).Methods("PUT") - h.HandleFunc("/{key_with_index:.*}", h.releaseLockHandler).Methods("DELETE") + h.HandleFunc("/{key:.*}", h.renewLockHandler).Methods("PUT") + h.HandleFunc("/{key:.*}", h.releaseLockHandler).Methods("DELETE") return h } - - -// extractResponseIndices extracts a sorted list of indicies from a response. -func extractResponseIndices(resp *etcd.Response) []int { - var indices []int - for _, node := range resp.Node.Nodes { - if index, _ := strconv.Atoi(path.Base(node.Key)); index > 0 { - indices = append(indices, index) - } - } - sort.Ints(indices) - return indices -} - -// findPrevIndex retrieves the previous index before the given index. -func findPrevIndex(indices []int, idx int) int { - var prevIndex int - for _, index := range indices { - if index == idx { - break - } - prevIndex = index - } - return prevIndex -} diff --git a/mod/lock/v2/lock_nodes.go b/mod/lock/v2/lock_nodes.go new file mode 100644 index 000000000..92446ee3b --- /dev/null +++ b/mod/lock/v2/lock_nodes.go @@ -0,0 +1,57 @@ +package v2 + +import ( + "path" + "sort" + "strconv" + + "github.com/coreos/go-etcd/etcd" +) + +// lockNodes is a wrapper for go-etcd's Nodes to allow for sorting by numeric key. +type lockNodes struct { + etcd.Nodes +} + +// Less sorts the nodes by key (numerically). +func (s lockNodes) Less(i, j int) bool { + a, _ := strconv.Atoi(path.Base(s.Nodes[i].Key)) + b, _ := strconv.Atoi(path.Base(s.Nodes[j].Key)) + return a < b +} + +// Retrieves the first node in the set of lock nodes. +func (s lockNodes) First() *etcd.Node { + sort.Sort(s) + if len(s.Nodes) > 0 { + return &s.Nodes[0] + } + return nil +} + +// Retrieves the first node with a given value. +func (s lockNodes) FindByValue(value string) *etcd.Node { + sort.Sort(s) + + for _, node := range s.Nodes { + if node.Value == value { + return &node + } + } + return nil +} + +// Retrieves the index that occurs before a given index. +func (s lockNodes) PrevIndex(index int) int { + sort.Sort(s) + + var prevIndex int + for _, node := range s.Nodes { + idx, _ := strconv.Atoi(path.Base(node.Key)) + if index == idx { + return prevIndex + } + prevIndex = idx + } + return 0 +} diff --git a/mod/lock/v2/release_handler.go b/mod/lock/v2/release_handler.go index 998fdc51e..f67a769d1 100644 --- a/mod/lock/v2/release_handler.go +++ b/mod/lock/v2/release_handler.go @@ -12,12 +12,39 @@ func (h *handler) releaseLockHandler(w http.ResponseWriter, req *http.Request) { h.client.SyncCluster() vars := mux.Vars(req) - keypath := path.Join(prefix, vars["key_with_index"]) + keypath := path.Join(prefix, vars["key"]) + + // Read index and value parameters. + index := req.FormValue("index") + value := req.FormValue("value") + if len(index) == 0 && len(value) == 0 { + http.Error(w, "release lock error: index or value required", http.StatusInternalServerError) + return + } else if len(index) != 0 && len(value) != 0 { + http.Error(w, "release lock error: index and value cannot both be specified", http.StatusInternalServerError) + return + } + + // Look up index by value if index is missing. + if len(index) == 0 { + resp, err := h.client.Get(keypath, true, true) + if err != nil { + http.Error(w, "release lock index error: " + err.Error(), http.StatusInternalServerError) + return + } + nodes := lockNodes{resp.Node.Nodes} + node := nodes.FindByValue(value) + if node == nil { + http.Error(w, "release lock error: cannot find: " + value, http.StatusInternalServerError) + return + } + index = path.Base(node.Key) + } // Delete the lock. - _, err := h.client.Delete(keypath, false) + _, err := h.client.Delete(path.Join(keypath, index), false) if err != nil { - http.Error(w, "delete lock index error: " + err.Error(), http.StatusInternalServerError) + http.Error(w, "release lock error: " + err.Error(), http.StatusInternalServerError) return } } diff --git a/mod/lock/v2/renew_handler.go b/mod/lock/v2/renew_handler.go index cdd65b3aa..951b52c32 100644 --- a/mod/lock/v2/renew_handler.go +++ b/mod/lock/v2/renew_handler.go @@ -13,18 +13,55 @@ import ( func (h *handler) renewLockHandler(w http.ResponseWriter, req *http.Request) { h.client.SyncCluster() + // Read the lock path. vars := mux.Vars(req) - keypath := path.Join(prefix, vars["key_with_index"]) + keypath := path.Join(prefix, vars["key"]) + + // Parse new TTL parameter. ttl, err := strconv.Atoi(req.FormValue("ttl")) if err != nil { http.Error(w, "invalid ttl: " + err.Error(), http.StatusInternalServerError) return } + // Read and set defaults for index and value. + index := req.FormValue("index") + value := req.FormValue("value") + if len(index) == 0 && len(value) == 0 { + // The index or value is required. + http.Error(w, "renew lock error: index or value required", http.StatusInternalServerError) + return + } + + if len(index) == 0 { + // If index is not specified then look it up by value. + resp, err := h.client.Get(keypath, true, true) + if err != nil { + http.Error(w, "renew lock index error: " + err.Error(), http.StatusInternalServerError) + return + } + nodes := lockNodes{resp.Node.Nodes} + node := nodes.FindByValue(value) + if node == nil { + http.Error(w, "renew lock error: cannot find: " + value, http.StatusInternalServerError) + return + } + index = path.Base(node.Key) + + } else if len(value) == 0 { + // If value is not specified then default it to the previous value. + resp, err := h.client.Get(path.Join(keypath, index), true, false) + if err != nil { + http.Error(w, "renew lock value error: " + err.Error(), http.StatusInternalServerError) + return + } + value = resp.Node.Value + } + // Renew the lock, if it exists. - _, err = h.client.Update(keypath, "-", uint64(ttl)) + _, err = h.client.Update(path.Join(keypath, index), value, uint64(ttl)) if err != nil { - http.Error(w, "renew lock index error: " + err.Error(), http.StatusInternalServerError) + http.Error(w, "renew lock error: " + err.Error(), http.StatusInternalServerError) return } } diff --git a/mod/lock/v2/tests/handler_test.go b/mod/lock/v2/tests/mod_lock_test.go similarity index 62% rename from mod/lock/v2/tests/handler_test.go rename to mod/lock/v2/tests/mod_lock_test.go index b07572bbe..d135290f2 100644 --- a/mod/lock/v2/tests/handler_test.go +++ b/mod/lock/v2/tests/mod_lock_test.go @@ -14,7 +14,7 @@ import ( func TestModLockAcquireAndRelease(t *testing.T) { tests.RunServer(func(s *server.Server) { // Acquire lock. - body, err := testAcquireLock(s, "foo", 10) + body, err := testAcquireLock(s, "foo", "", 10) assert.NoError(t, err) assert.Equal(t, body, "2") @@ -24,7 +24,7 @@ func TestModLockAcquireAndRelease(t *testing.T) { assert.Equal(t, body, "2") // Release lock. - body, err = testReleaseLock(s, "foo", 2) + body, err = testReleaseLock(s, "foo", "2", "") assert.NoError(t, err) assert.Equal(t, body, "") @@ -42,7 +42,7 @@ func TestModLockBlockUntilAcquire(t *testing.T) { // Acquire lock #1. go func() { - body, err := testAcquireLock(s, "foo", 10) + body, err := testAcquireLock(s, "foo", "", 10) assert.NoError(t, err) assert.Equal(t, body, "2") c <- true @@ -50,11 +50,13 @@ func TestModLockBlockUntilAcquire(t *testing.T) { <- c // Acquire lock #2. + waiting := true go func() { c <- true - body, err := testAcquireLock(s, "foo", 10) + body, err := testAcquireLock(s, "foo", "", 10) assert.NoError(t, err) assert.Equal(t, body, "4") + waiting = false }() <- c @@ -65,8 +67,11 @@ func TestModLockBlockUntilAcquire(t *testing.T) { assert.NoError(t, err) assert.Equal(t, body, "2") + // Check that we are still waiting for lock #2. + assert.Equal(t, waiting, true) + // Release lock #1. - body, err = testReleaseLock(s, "foo", 2) + body, err = testReleaseLock(s, "foo", "2", "") assert.NoError(t, err) // Check that we have lock #2. @@ -75,7 +80,7 @@ func TestModLockBlockUntilAcquire(t *testing.T) { assert.Equal(t, body, "4") // Release lock #2. - body, err = testReleaseLock(s, "foo", 4) + body, err = testReleaseLock(s, "foo", "4", "") assert.NoError(t, err) // Check that we have no lock. @@ -92,7 +97,7 @@ func TestModLockExpireAndRelease(t *testing.T) { // Acquire lock #1. go func() { - body, err := testAcquireLock(s, "foo", 2) + body, err := testAcquireLock(s, "foo", "", 2) assert.NoError(t, err) assert.Equal(t, body, "2") c <- true @@ -102,7 +107,7 @@ func TestModLockExpireAndRelease(t *testing.T) { // Acquire lock #2. go func() { c <- true - body, err := testAcquireLock(s, "foo", 10) + body, err := testAcquireLock(s, "foo", "", 10) assert.NoError(t, err) assert.Equal(t, body, "4") }() @@ -129,7 +134,7 @@ func TestModLockExpireAndRelease(t *testing.T) { func TestModLockRenew(t *testing.T) { tests.RunServer(func(s *server.Server) { // Acquire lock. - body, err := testAcquireLock(s, "foo", 3) + body, err := testAcquireLock(s, "foo", "", 3) assert.NoError(t, err) assert.Equal(t, body, "2") @@ -141,7 +146,7 @@ func TestModLockRenew(t *testing.T) { assert.Equal(t, body, "2") // Renew lock. - body, err = testRenewLock(s, "foo", 2, 3) + body, err = testRenewLock(s, "foo", "2", "", 3) assert.NoError(t, err) assert.Equal(t, body, "") @@ -161,28 +166,59 @@ func TestModLockRenew(t *testing.T) { }) } +// Ensure that a lock can be acquired with a value and released by value. +func TestModLockAcquireAndReleaseByValue(t *testing.T) { + tests.RunServer(func(s *server.Server) { + // Acquire lock. + body, err := testAcquireLock(s, "foo", "XXX", 10) + assert.NoError(t, err) + assert.Equal(t, body, "2") + + // Check that we have the lock. + body, err = testGetLockValue(s, "foo") + assert.NoError(t, err) + assert.Equal(t, body, "XXX") + + // Release lock. + body, err = testReleaseLock(s, "foo", "", "XXX") + assert.NoError(t, err) + assert.Equal(t, body, "") + + // Check that we released the lock. + body, err = testGetLockValue(s, "foo") + assert.NoError(t, err) + assert.Equal(t, body, "") + }) +} -func testAcquireLock(s *server.Server, key string, ttl int) (string, error) { - resp, err := tests.PostForm(fmt.Sprintf("%s/mod/v2/lock/%s?ttl=%d", s.URL(), key, ttl), nil) + +func testAcquireLock(s *server.Server, key string, value string, ttl int) (string, error) { + resp, err := tests.PostForm(fmt.Sprintf("%s/mod/v2/lock/%s?value=%s&ttl=%d", s.URL(), key, value, ttl), nil) ret := tests.ReadBody(resp) return string(ret), err } func testGetLockIndex(s *server.Server, key string) (string, error) { + resp, err := tests.Get(fmt.Sprintf("%s/mod/v2/lock/%s?field=index", s.URL(), key)) + ret := tests.ReadBody(resp) + return string(ret), err +} + +func testGetLockValue(s *server.Server, key string) (string, error) { resp, err := tests.Get(fmt.Sprintf("%s/mod/v2/lock/%s", s.URL(), key)) ret := tests.ReadBody(resp) return string(ret), err } -func testReleaseLock(s *server.Server, key string, index int) (string, error) { - resp, err := tests.DeleteForm(fmt.Sprintf("%s/mod/v2/lock/%s/%d", s.URL(), key, index), nil) +func testReleaseLock(s *server.Server, key string, index string, value string) (string, error) { + resp, err := tests.DeleteForm(fmt.Sprintf("%s/mod/v2/lock/%s?index=%s&value=%s", s.URL(), key, index, value), nil) ret := tests.ReadBody(resp) return string(ret), err } -func testRenewLock(s *server.Server, key string, index int, ttl int) (string, error) { - resp, err := tests.PutForm(fmt.Sprintf("%s/mod/v2/lock/%s/%d?ttl=%d", s.URL(), key, index, ttl), nil) +func testRenewLock(s *server.Server, key string, index string, value string, ttl int) (string, error) { + resp, err := tests.PutForm(fmt.Sprintf("%s/mod/v2/lock/%s?index=%s&value=%s&ttl=%d", s.URL(), key, index, value, ttl), nil) ret := tests.ReadBody(resp) return string(ret), err }