mirror of
https://github.com/etcd-io/etcd.git
synced 2024-09-27 06:25:44 +00:00
Refactor mod/lock.
This commit is contained in:
parent
8442e7a0dc
commit
4bec461db1
@ -1,6 +1,8 @@
|
|||||||
package v2
|
package v2
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"path"
|
"path"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -12,6 +14,7 @@ import (
|
|||||||
|
|
||||||
// acquireHandler attempts to acquire a lock on the given key.
|
// acquireHandler attempts to acquire a lock on the given key.
|
||||||
// The "key" parameter specifies the resource to lock.
|
// The "key" parameter specifies the resource to lock.
|
||||||
|
// The "value" parameter specifies a value to associate with the lock.
|
||||||
// The "ttl" parameter specifies how long the lock will persist for.
|
// The "ttl" parameter specifies how long the lock will persist for.
|
||||||
// The "timeout" parameter specifies how long the request should wait for the lock.
|
// The "timeout" parameter specifies how long the request should wait for the lock.
|
||||||
func (h *handler) acquireHandler(w http.ResponseWriter, req *http.Request) {
|
func (h *handler) acquireHandler(w http.ResponseWriter, req *http.Request) {
|
||||||
@ -20,109 +23,152 @@ func (h *handler) acquireHandler(w http.ResponseWriter, req *http.Request) {
|
|||||||
// Setup connection watcher.
|
// Setup connection watcher.
|
||||||
closeNotifier, _ := w.(http.CloseNotifier)
|
closeNotifier, _ := w.(http.CloseNotifier)
|
||||||
closeChan := closeNotifier.CloseNotify()
|
closeChan := closeNotifier.CloseNotify()
|
||||||
|
stopChan := make(chan bool)
|
||||||
|
|
||||||
// Parse "key" and "ttl" query parameters.
|
// Parse the lock "key".
|
||||||
vars := mux.Vars(req)
|
vars := mux.Vars(req)
|
||||||
keypath := path.Join(prefix, vars["key"])
|
keypath := path.Join(prefix, vars["key"])
|
||||||
ttl, err := strconv.Atoi(req.FormValue("ttl"))
|
value := req.FormValue("value")
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "invalid ttl: " + err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse "timeout" parameter.
|
// Parse "timeout" parameter.
|
||||||
var timeout int
|
var timeout int
|
||||||
if len(req.FormValue("timeout")) == 0 {
|
var err error
|
||||||
|
if req.FormValue("timeout") == "" {
|
||||||
timeout = -1
|
timeout = -1
|
||||||
} else if timeout, err = strconv.Atoi(req.FormValue("timeout")); err != nil {
|
} else if timeout, err = strconv.Atoi(req.FormValue("timeout")); err != nil {
|
||||||
http.Error(w, "invalid timeout: " + err.Error(), http.StatusInternalServerError)
|
http.Error(w, "invalid timeout: " + req.FormValue("timeout"), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
timeout = timeout + 1
|
timeout = timeout + 1
|
||||||
|
|
||||||
// Create an incrementing id for the lock.
|
// Parse TTL.
|
||||||
resp, err := h.client.AddChild(keypath, "-", uint64(ttl))
|
ttl, err := strconv.Atoi(req.FormValue("ttl"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "add lock index error: " + err.Error(), http.StatusInternalServerError)
|
http.Error(w, "invalid ttl: " + req.FormValue("ttl"), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
indexpath := resp.Node.Key
|
|
||||||
|
|
||||||
// Keep updating TTL to make sure lock request is not expired before acquisition.
|
// If node exists then just watch it. Otherwise create the node and watch it.
|
||||||
stop := make(chan bool)
|
index := h.findExistingNode(keypath, value)
|
||||||
go h.ttlKeepAlive(indexpath, ttl, stop)
|
if index > 0 {
|
||||||
|
err = h.watch(keypath, index, nil)
|
||||||
// Monitor for broken connection.
|
|
||||||
stopWatchChan := make(chan bool)
|
|
||||||
go func() {
|
|
||||||
select {
|
|
||||||
case <-closeChan:
|
|
||||||
stopWatchChan <- true
|
|
||||||
case <-stop:
|
|
||||||
// Stop watching for connection disconnect.
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Extract the lock index.
|
|
||||||
index, _ := strconv.Atoi(path.Base(resp.Node.Key))
|
|
||||||
|
|
||||||
// Wait until we successfully get a lock or we get a failure.
|
|
||||||
var success bool
|
|
||||||
for {
|
|
||||||
// Read all indices.
|
|
||||||
resp, err = h.client.Get(keypath, true, true)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "lock children lookup error: " + err.Error(), http.StatusInternalServerError)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
indices := extractResponseIndices(resp)
|
|
||||||
waitIndex := resp.Node.ModifiedIndex
|
|
||||||
prevIndex := findPrevIndex(indices, index)
|
|
||||||
|
|
||||||
// If there is no previous index then we have the lock.
|
|
||||||
if prevIndex == 0 {
|
|
||||||
success = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// Otherwise watch previous index until it's gone.
|
|
||||||
_, err = h.client.Watch(path.Join(keypath, strconv.Itoa(prevIndex)), waitIndex, false, nil, stopWatchChan)
|
|
||||||
if err == etcd.ErrWatchStoppedByUser {
|
|
||||||
break
|
|
||||||
} else if err != nil {
|
|
||||||
http.Error(w, "lock watch error: " + err.Error(), http.StatusInternalServerError)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for connection disconnect before we write the lock index.
|
|
||||||
select {
|
|
||||||
case <-stopWatchChan:
|
|
||||||
success = false
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop the ttl keep-alive.
|
|
||||||
close(stop)
|
|
||||||
|
|
||||||
if success {
|
|
||||||
// Write lock index to response body if we acquire the lock.
|
|
||||||
h.client.Update(indexpath, "-", uint64(ttl))
|
|
||||||
w.Write([]byte(strconv.Itoa(index)))
|
|
||||||
} else {
|
} else {
|
||||||
// Make sure key is deleted if we couldn't acquire.
|
index, err = h.createNode(keypath, value, ttl, closeChan, stopChan)
|
||||||
h.client.Delete(indexpath, false)
|
}
|
||||||
|
|
||||||
|
// Stop all goroutines.
|
||||||
|
close(stopChan)
|
||||||
|
|
||||||
|
// Write response.
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
} else {
|
||||||
|
w.Write([]byte(strconv.Itoa(index)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// createNode creates a new lock node and watches it until it is acquired or acquisition fails.
|
||||||
|
func (h *handler) createNode(keypath string, value string, ttl int, closeChan <- chan bool, stopChan chan bool) (int, error) {
|
||||||
|
// Default the value to "-" if it is blank.
|
||||||
|
if len(value) == 0 {
|
||||||
|
value = "-"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create an incrementing id for the lock.
|
||||||
|
resp, err := h.client.AddChild(keypath, value, uint64(ttl))
|
||||||
|
if err != nil {
|
||||||
|
return 0, errors.New("acquire lock index error: " + err.Error())
|
||||||
|
}
|
||||||
|
indexpath := resp.Node.Key
|
||||||
|
index, _ := strconv.Atoi(path.Base(indexpath))
|
||||||
|
|
||||||
|
// Keep updating TTL to make sure lock request is not expired before acquisition.
|
||||||
|
go h.ttlKeepAlive(indexpath, value, ttl, stopChan)
|
||||||
|
|
||||||
|
// Watch until we acquire or fail.
|
||||||
|
err = h.watch(keypath, index, closeChan)
|
||||||
|
|
||||||
|
// Check for connection disconnect before we write the lock index.
|
||||||
|
if err != nil {
|
||||||
|
select {
|
||||||
|
case <-closeChan:
|
||||||
|
err = errors.New("acquire lock error: user interrupted")
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update TTL one last time if acquired. Otherwise delete.
|
||||||
|
if err == nil {
|
||||||
|
h.client.Update(indexpath, value, uint64(ttl))
|
||||||
|
} else {
|
||||||
|
h.client.Delete(indexpath, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
return index, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// findExistingNode search for a node on the lock with the given value.
|
||||||
|
func (h *handler) findExistingNode(keypath string, value string) int {
|
||||||
|
if len(value) > 0 {
|
||||||
|
resp, err := h.client.Get(keypath, true, true)
|
||||||
|
if err == nil {
|
||||||
|
nodes := lockNodes{resp.Node.Nodes}
|
||||||
|
if node := nodes.FindByValue(value); node != nil {
|
||||||
|
index, _ := strconv.Atoi(path.Base(node.Key))
|
||||||
|
return index
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
// ttlKeepAlive continues to update a key's TTL until the stop channel is closed.
|
// ttlKeepAlive continues to update a key's TTL until the stop channel is closed.
|
||||||
func (h *handler) ttlKeepAlive(k string, ttl int, stop chan bool) {
|
func (h *handler) ttlKeepAlive(k string, value string, ttl int, stopChan chan bool) {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-time.After(time.Duration(ttl / 2) * time.Second):
|
case <-time.After(time.Duration(ttl / 2) * time.Second):
|
||||||
h.client.Update(k, "-", uint64(ttl))
|
h.client.Update(k, value, uint64(ttl))
|
||||||
case <-stop:
|
case <-stopChan:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// watch continuously waits for a given lock index to be acquired or until lock fails.
|
||||||
|
// Returns a boolean indicating success.
|
||||||
|
func (h *handler) watch(keypath string, index int, closeChan <- chan bool) error {
|
||||||
|
// Wrap close chan so we can pass it to Client.Watch().
|
||||||
|
stopWatchChan := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <- closeChan:
|
||||||
|
stopWatchChan <- true
|
||||||
|
case <- stopWatchChan:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
defer close(stopWatchChan)
|
||||||
|
|
||||||
|
for {
|
||||||
|
// Read all nodes for the lock.
|
||||||
|
resp, err := h.client.Get(keypath, true, true)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("lock watch lookup error: %s", err.Error())
|
||||||
|
}
|
||||||
|
waitIndex := resp.Node.ModifiedIndex
|
||||||
|
nodes := lockNodes{resp.Node.Nodes}
|
||||||
|
prevIndex := nodes.PrevIndex(index)
|
||||||
|
|
||||||
|
// If there is no previous index then we have the lock.
|
||||||
|
if prevIndex == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Watch previous index until it's gone.
|
||||||
|
_, err = h.client.Watch(path.Join(keypath, strconv.Itoa(prevIndex)), waitIndex, false, nil, stopWatchChan)
|
||||||
|
if err == etcd.ErrWatchStoppedByUser {
|
||||||
|
return fmt.Errorf("lock watch closed")
|
||||||
|
} else if err != nil {
|
||||||
|
return fmt.Errorf("lock watch error:%s", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -3,28 +3,41 @@ package v2
|
|||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"path"
|
"path"
|
||||||
"strconv"
|
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
)
|
)
|
||||||
|
|
||||||
// getIndexHandler retrieves the current lock index.
|
// getIndexHandler retrieves the current lock index.
|
||||||
|
// The "field" parameter specifies to read either the lock "index" or lock "value".
|
||||||
func (h *handler) getIndexHandler(w http.ResponseWriter, req *http.Request) {
|
func (h *handler) getIndexHandler(w http.ResponseWriter, req *http.Request) {
|
||||||
h.client.SyncCluster()
|
h.client.SyncCluster()
|
||||||
|
|
||||||
vars := mux.Vars(req)
|
vars := mux.Vars(req)
|
||||||
keypath := path.Join(prefix, vars["key"])
|
keypath := path.Join(prefix, vars["key"])
|
||||||
|
field := req.FormValue("field")
|
||||||
|
if len(field) == 0 {
|
||||||
|
field = "value"
|
||||||
|
}
|
||||||
|
|
||||||
// Read all indices.
|
// Read all indices.
|
||||||
resp, err := h.client.Get(keypath, true, true)
|
resp, err := h.client.Get(keypath, true, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "lock children lookup error: " + err.Error(), http.StatusInternalServerError)
|
http.Error(w, "read lock error: " + err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
nodes := lockNodes{resp.Node.Nodes}
|
||||||
|
|
||||||
// Write out the index of the last one to the response body.
|
// Write out the requested field.
|
||||||
indices := extractResponseIndices(resp)
|
if node := nodes.First(); node != nil {
|
||||||
if len(indices) > 0 {
|
switch field {
|
||||||
w.Write([]byte(strconv.Itoa(indices[0])))
|
case "index":
|
||||||
|
w.Write([]byte(path.Base(node.Key)))
|
||||||
|
|
||||||
|
case "value":
|
||||||
|
w.Write([]byte(node.Value))
|
||||||
|
|
||||||
|
default:
|
||||||
|
http.Error(w, "read lock error: invalid field: " + field, http.StatusInternalServerError)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2,9 +2,6 @@ package v2
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"path"
|
|
||||||
"strconv"
|
|
||||||
"sort"
|
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/coreos/go-etcd/etcd"
|
"github.com/coreos/go-etcd/etcd"
|
||||||
@ -27,32 +24,7 @@ func NewHandler(addr string) (http.Handler) {
|
|||||||
h.StrictSlash(false)
|
h.StrictSlash(false)
|
||||||
h.HandleFunc("/{key:.*}", h.getIndexHandler).Methods("GET")
|
h.HandleFunc("/{key:.*}", h.getIndexHandler).Methods("GET")
|
||||||
h.HandleFunc("/{key:.*}", h.acquireHandler).Methods("POST")
|
h.HandleFunc("/{key:.*}", h.acquireHandler).Methods("POST")
|
||||||
h.HandleFunc("/{key_with_index:.*}", h.renewLockHandler).Methods("PUT")
|
h.HandleFunc("/{key:.*}", h.renewLockHandler).Methods("PUT")
|
||||||
h.HandleFunc("/{key_with_index:.*}", h.releaseLockHandler).Methods("DELETE")
|
h.HandleFunc("/{key:.*}", h.releaseLockHandler).Methods("DELETE")
|
||||||
return h
|
return h
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// extractResponseIndices extracts a sorted list of indicies from a response.
|
|
||||||
func extractResponseIndices(resp *etcd.Response) []int {
|
|
||||||
var indices []int
|
|
||||||
for _, node := range resp.Node.Nodes {
|
|
||||||
if index, _ := strconv.Atoi(path.Base(node.Key)); index > 0 {
|
|
||||||
indices = append(indices, index)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
sort.Ints(indices)
|
|
||||||
return indices
|
|
||||||
}
|
|
||||||
|
|
||||||
// findPrevIndex retrieves the previous index before the given index.
|
|
||||||
func findPrevIndex(indices []int, idx int) int {
|
|
||||||
var prevIndex int
|
|
||||||
for _, index := range indices {
|
|
||||||
if index == idx {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
prevIndex = index
|
|
||||||
}
|
|
||||||
return prevIndex
|
|
||||||
}
|
|
||||||
|
57
mod/lock/v2/lock_nodes.go
Normal file
57
mod/lock/v2/lock_nodes.go
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
package v2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/coreos/go-etcd/etcd"
|
||||||
|
)
|
||||||
|
|
||||||
|
// lockNodes is a wrapper for go-etcd's Nodes to allow for sorting by numeric key.
|
||||||
|
type lockNodes struct {
|
||||||
|
etcd.Nodes
|
||||||
|
}
|
||||||
|
|
||||||
|
// Less sorts the nodes by key (numerically).
|
||||||
|
func (s lockNodes) Less(i, j int) bool {
|
||||||
|
a, _ := strconv.Atoi(path.Base(s.Nodes[i].Key))
|
||||||
|
b, _ := strconv.Atoi(path.Base(s.Nodes[j].Key))
|
||||||
|
return a < b
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieves the first node in the set of lock nodes.
|
||||||
|
func (s lockNodes) First() *etcd.Node {
|
||||||
|
sort.Sort(s)
|
||||||
|
if len(s.Nodes) > 0 {
|
||||||
|
return &s.Nodes[0]
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieves the first node with a given value.
|
||||||
|
func (s lockNodes) FindByValue(value string) *etcd.Node {
|
||||||
|
sort.Sort(s)
|
||||||
|
|
||||||
|
for _, node := range s.Nodes {
|
||||||
|
if node.Value == value {
|
||||||
|
return &node
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieves the index that occurs before a given index.
|
||||||
|
func (s lockNodes) PrevIndex(index int) int {
|
||||||
|
sort.Sort(s)
|
||||||
|
|
||||||
|
var prevIndex int
|
||||||
|
for _, node := range s.Nodes {
|
||||||
|
idx, _ := strconv.Atoi(path.Base(node.Key))
|
||||||
|
if index == idx {
|
||||||
|
return prevIndex
|
||||||
|
}
|
||||||
|
prevIndex = idx
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
@ -12,12 +12,39 @@ func (h *handler) releaseLockHandler(w http.ResponseWriter, req *http.Request) {
|
|||||||
h.client.SyncCluster()
|
h.client.SyncCluster()
|
||||||
|
|
||||||
vars := mux.Vars(req)
|
vars := mux.Vars(req)
|
||||||
keypath := path.Join(prefix, vars["key_with_index"])
|
keypath := path.Join(prefix, vars["key"])
|
||||||
|
|
||||||
|
// Read index and value parameters.
|
||||||
|
index := req.FormValue("index")
|
||||||
|
value := req.FormValue("value")
|
||||||
|
if len(index) == 0 && len(value) == 0 {
|
||||||
|
http.Error(w, "release lock error: index or value required", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
} else if len(index) != 0 && len(value) != 0 {
|
||||||
|
http.Error(w, "release lock error: index and value cannot both be specified", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look up index by value if index is missing.
|
||||||
|
if len(index) == 0 {
|
||||||
|
resp, err := h.client.Get(keypath, true, true)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "release lock index error: " + err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
nodes := lockNodes{resp.Node.Nodes}
|
||||||
|
node := nodes.FindByValue(value)
|
||||||
|
if node == nil {
|
||||||
|
http.Error(w, "release lock error: cannot find: " + value, http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
index = path.Base(node.Key)
|
||||||
|
}
|
||||||
|
|
||||||
// Delete the lock.
|
// Delete the lock.
|
||||||
_, err := h.client.Delete(keypath, false)
|
_, err := h.client.Delete(path.Join(keypath, index), false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "delete lock index error: " + err.Error(), http.StatusInternalServerError)
|
http.Error(w, "release lock error: " + err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -13,18 +13,55 @@ import (
|
|||||||
func (h *handler) renewLockHandler(w http.ResponseWriter, req *http.Request) {
|
func (h *handler) renewLockHandler(w http.ResponseWriter, req *http.Request) {
|
||||||
h.client.SyncCluster()
|
h.client.SyncCluster()
|
||||||
|
|
||||||
|
// Read the lock path.
|
||||||
vars := mux.Vars(req)
|
vars := mux.Vars(req)
|
||||||
keypath := path.Join(prefix, vars["key_with_index"])
|
keypath := path.Join(prefix, vars["key"])
|
||||||
|
|
||||||
|
// Parse new TTL parameter.
|
||||||
ttl, err := strconv.Atoi(req.FormValue("ttl"))
|
ttl, err := strconv.Atoi(req.FormValue("ttl"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "invalid ttl: " + err.Error(), http.StatusInternalServerError)
|
http.Error(w, "invalid ttl: " + err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Read and set defaults for index and value.
|
||||||
|
index := req.FormValue("index")
|
||||||
|
value := req.FormValue("value")
|
||||||
|
if len(index) == 0 && len(value) == 0 {
|
||||||
|
// The index or value is required.
|
||||||
|
http.Error(w, "renew lock error: index or value required", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(index) == 0 {
|
||||||
|
// If index is not specified then look it up by value.
|
||||||
|
resp, err := h.client.Get(keypath, true, true)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "renew lock index error: " + err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
nodes := lockNodes{resp.Node.Nodes}
|
||||||
|
node := nodes.FindByValue(value)
|
||||||
|
if node == nil {
|
||||||
|
http.Error(w, "renew lock error: cannot find: " + value, http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
index = path.Base(node.Key)
|
||||||
|
|
||||||
|
} else if len(value) == 0 {
|
||||||
|
// If value is not specified then default it to the previous value.
|
||||||
|
resp, err := h.client.Get(path.Join(keypath, index), true, false)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "renew lock value error: " + err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
value = resp.Node.Value
|
||||||
|
}
|
||||||
|
|
||||||
// Renew the lock, if it exists.
|
// Renew the lock, if it exists.
|
||||||
_, err = h.client.Update(keypath, "-", uint64(ttl))
|
_, err = h.client.Update(path.Join(keypath, index), value, uint64(ttl))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "renew lock index error: " + err.Error(), http.StatusInternalServerError)
|
http.Error(w, "renew lock error: " + err.Error(), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -14,7 +14,7 @@ import (
|
|||||||
func TestModLockAcquireAndRelease(t *testing.T) {
|
func TestModLockAcquireAndRelease(t *testing.T) {
|
||||||
tests.RunServer(func(s *server.Server) {
|
tests.RunServer(func(s *server.Server) {
|
||||||
// Acquire lock.
|
// Acquire lock.
|
||||||
body, err := testAcquireLock(s, "foo", 10)
|
body, err := testAcquireLock(s, "foo", "", 10)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, body, "2")
|
assert.Equal(t, body, "2")
|
||||||
|
|
||||||
@ -24,7 +24,7 @@ func TestModLockAcquireAndRelease(t *testing.T) {
|
|||||||
assert.Equal(t, body, "2")
|
assert.Equal(t, body, "2")
|
||||||
|
|
||||||
// Release lock.
|
// Release lock.
|
||||||
body, err = testReleaseLock(s, "foo", 2)
|
body, err = testReleaseLock(s, "foo", "2", "")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, body, "")
|
assert.Equal(t, body, "")
|
||||||
|
|
||||||
@ -42,7 +42,7 @@ func TestModLockBlockUntilAcquire(t *testing.T) {
|
|||||||
|
|
||||||
// Acquire lock #1.
|
// Acquire lock #1.
|
||||||
go func() {
|
go func() {
|
||||||
body, err := testAcquireLock(s, "foo", 10)
|
body, err := testAcquireLock(s, "foo", "", 10)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, body, "2")
|
assert.Equal(t, body, "2")
|
||||||
c <- true
|
c <- true
|
||||||
@ -50,11 +50,13 @@ func TestModLockBlockUntilAcquire(t *testing.T) {
|
|||||||
<- c
|
<- c
|
||||||
|
|
||||||
// Acquire lock #2.
|
// Acquire lock #2.
|
||||||
|
waiting := true
|
||||||
go func() {
|
go func() {
|
||||||
c <- true
|
c <- true
|
||||||
body, err := testAcquireLock(s, "foo", 10)
|
body, err := testAcquireLock(s, "foo", "", 10)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, body, "4")
|
assert.Equal(t, body, "4")
|
||||||
|
waiting = false
|
||||||
}()
|
}()
|
||||||
<- c
|
<- c
|
||||||
|
|
||||||
@ -65,8 +67,11 @@ func TestModLockBlockUntilAcquire(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, body, "2")
|
assert.Equal(t, body, "2")
|
||||||
|
|
||||||
|
// Check that we are still waiting for lock #2.
|
||||||
|
assert.Equal(t, waiting, true)
|
||||||
|
|
||||||
// Release lock #1.
|
// Release lock #1.
|
||||||
body, err = testReleaseLock(s, "foo", 2)
|
body, err = testReleaseLock(s, "foo", "2", "")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// Check that we have lock #2.
|
// Check that we have lock #2.
|
||||||
@ -75,7 +80,7 @@ func TestModLockBlockUntilAcquire(t *testing.T) {
|
|||||||
assert.Equal(t, body, "4")
|
assert.Equal(t, body, "4")
|
||||||
|
|
||||||
// Release lock #2.
|
// Release lock #2.
|
||||||
body, err = testReleaseLock(s, "foo", 4)
|
body, err = testReleaseLock(s, "foo", "4", "")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// Check that we have no lock.
|
// Check that we have no lock.
|
||||||
@ -92,7 +97,7 @@ func TestModLockExpireAndRelease(t *testing.T) {
|
|||||||
|
|
||||||
// Acquire lock #1.
|
// Acquire lock #1.
|
||||||
go func() {
|
go func() {
|
||||||
body, err := testAcquireLock(s, "foo", 2)
|
body, err := testAcquireLock(s, "foo", "", 2)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, body, "2")
|
assert.Equal(t, body, "2")
|
||||||
c <- true
|
c <- true
|
||||||
@ -102,7 +107,7 @@ func TestModLockExpireAndRelease(t *testing.T) {
|
|||||||
// Acquire lock #2.
|
// Acquire lock #2.
|
||||||
go func() {
|
go func() {
|
||||||
c <- true
|
c <- true
|
||||||
body, err := testAcquireLock(s, "foo", 10)
|
body, err := testAcquireLock(s, "foo", "", 10)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, body, "4")
|
assert.Equal(t, body, "4")
|
||||||
}()
|
}()
|
||||||
@ -129,7 +134,7 @@ func TestModLockExpireAndRelease(t *testing.T) {
|
|||||||
func TestModLockRenew(t *testing.T) {
|
func TestModLockRenew(t *testing.T) {
|
||||||
tests.RunServer(func(s *server.Server) {
|
tests.RunServer(func(s *server.Server) {
|
||||||
// Acquire lock.
|
// Acquire lock.
|
||||||
body, err := testAcquireLock(s, "foo", 3)
|
body, err := testAcquireLock(s, "foo", "", 3)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, body, "2")
|
assert.Equal(t, body, "2")
|
||||||
|
|
||||||
@ -141,7 +146,7 @@ func TestModLockRenew(t *testing.T) {
|
|||||||
assert.Equal(t, body, "2")
|
assert.Equal(t, body, "2")
|
||||||
|
|
||||||
// Renew lock.
|
// Renew lock.
|
||||||
body, err = testRenewLock(s, "foo", 2, 3)
|
body, err = testRenewLock(s, "foo", "2", "", 3)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, body, "")
|
assert.Equal(t, body, "")
|
||||||
|
|
||||||
@ -161,28 +166,59 @@ func TestModLockRenew(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Ensure that a lock can be acquired with a value and released by value.
|
||||||
|
func TestModLockAcquireAndReleaseByValue(t *testing.T) {
|
||||||
|
tests.RunServer(func(s *server.Server) {
|
||||||
|
// Acquire lock.
|
||||||
|
body, err := testAcquireLock(s, "foo", "XXX", 10)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, body, "2")
|
||||||
|
|
||||||
|
// Check that we have the lock.
|
||||||
|
body, err = testGetLockValue(s, "foo")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, body, "XXX")
|
||||||
|
|
||||||
|
// Release lock.
|
||||||
|
body, err = testReleaseLock(s, "foo", "", "XXX")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, body, "")
|
||||||
|
|
||||||
|
// Check that we released the lock.
|
||||||
|
body, err = testGetLockValue(s, "foo")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, body, "")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func testAcquireLock(s *server.Server, key string, ttl int) (string, error) {
|
|
||||||
resp, err := tests.PostForm(fmt.Sprintf("%s/mod/v2/lock/%s?ttl=%d", s.URL(), key, ttl), nil)
|
func testAcquireLock(s *server.Server, key string, value string, ttl int) (string, error) {
|
||||||
|
resp, err := tests.PostForm(fmt.Sprintf("%s/mod/v2/lock/%s?value=%s&ttl=%d", s.URL(), key, value, ttl), nil)
|
||||||
ret := tests.ReadBody(resp)
|
ret := tests.ReadBody(resp)
|
||||||
return string(ret), err
|
return string(ret), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func testGetLockIndex(s *server.Server, key string) (string, error) {
|
func testGetLockIndex(s *server.Server, key string) (string, error) {
|
||||||
|
resp, err := tests.Get(fmt.Sprintf("%s/mod/v2/lock/%s?field=index", s.URL(), key))
|
||||||
|
ret := tests.ReadBody(resp)
|
||||||
|
return string(ret), err
|
||||||
|
}
|
||||||
|
|
||||||
|
func testGetLockValue(s *server.Server, key string) (string, error) {
|
||||||
resp, err := tests.Get(fmt.Sprintf("%s/mod/v2/lock/%s", s.URL(), key))
|
resp, err := tests.Get(fmt.Sprintf("%s/mod/v2/lock/%s", s.URL(), key))
|
||||||
ret := tests.ReadBody(resp)
|
ret := tests.ReadBody(resp)
|
||||||
return string(ret), err
|
return string(ret), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func testReleaseLock(s *server.Server, key string, index int) (string, error) {
|
func testReleaseLock(s *server.Server, key string, index string, value string) (string, error) {
|
||||||
resp, err := tests.DeleteForm(fmt.Sprintf("%s/mod/v2/lock/%s/%d", s.URL(), key, index), nil)
|
resp, err := tests.DeleteForm(fmt.Sprintf("%s/mod/v2/lock/%s?index=%s&value=%s", s.URL(), key, index, value), nil)
|
||||||
ret := tests.ReadBody(resp)
|
ret := tests.ReadBody(resp)
|
||||||
return string(ret), err
|
return string(ret), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func testRenewLock(s *server.Server, key string, index int, ttl int) (string, error) {
|
func testRenewLock(s *server.Server, key string, index string, value string, ttl int) (string, error) {
|
||||||
resp, err := tests.PutForm(fmt.Sprintf("%s/mod/v2/lock/%s/%d?ttl=%d", s.URL(), key, index, ttl), nil)
|
resp, err := tests.PutForm(fmt.Sprintf("%s/mod/v2/lock/%s?index=%s&value=%s&ttl=%d", s.URL(), key, index, value, ttl), nil)
|
||||||
ret := tests.ReadBody(resp)
|
ret := tests.ReadBody(resp)
|
||||||
return string(ret), err
|
return string(ret), err
|
||||||
}
|
}
|
Loading…
x
Reference in New Issue
Block a user