Refactor mod/lock.

This commit is contained in:
Ben Johnson 2013-12-07 18:04:16 -07:00
parent 8442e7a0dc
commit 4bec461db1
7 changed files with 325 additions and 137 deletions

View File

@ -1,6 +1,8 @@
package v2
import (
"errors"
"fmt"
"net/http"
"path"
"strconv"
@ -12,6 +14,7 @@ import (
// acquireHandler attempts to acquire a lock on the given key.
// The "key" parameter specifies the resource to lock.
// The "value" parameter specifies a value to associate with the lock.
// The "ttl" parameter specifies how long the lock will persist for.
// The "timeout" parameter specifies how long the request should wait for the lock.
func (h *handler) acquireHandler(w http.ResponseWriter, req *http.Request) {
@ -20,109 +23,152 @@ func (h *handler) acquireHandler(w http.ResponseWriter, req *http.Request) {
// Setup connection watcher.
closeNotifier, _ := w.(http.CloseNotifier)
closeChan := closeNotifier.CloseNotify()
stopChan := make(chan bool)
// Parse "key" and "ttl" query parameters.
// Parse the lock "key".
vars := mux.Vars(req)
keypath := path.Join(prefix, vars["key"])
ttl, err := strconv.Atoi(req.FormValue("ttl"))
if err != nil {
http.Error(w, "invalid ttl: " + err.Error(), http.StatusInternalServerError)
return
}
value := req.FormValue("value")
// Parse "timeout" parameter.
var timeout int
if len(req.FormValue("timeout")) == 0 {
var err error
if req.FormValue("timeout") == "" {
timeout = -1
} else if timeout, err = strconv.Atoi(req.FormValue("timeout")); err != nil {
http.Error(w, "invalid timeout: " + err.Error(), http.StatusInternalServerError)
http.Error(w, "invalid timeout: " + req.FormValue("timeout"), http.StatusInternalServerError)
return
}
timeout = timeout + 1
// Create an incrementing id for the lock.
resp, err := h.client.AddChild(keypath, "-", uint64(ttl))
// Parse TTL.
ttl, err := strconv.Atoi(req.FormValue("ttl"))
if err != nil {
http.Error(w, "add lock index error: " + err.Error(), http.StatusInternalServerError)
http.Error(w, "invalid ttl: " + req.FormValue("ttl"), http.StatusInternalServerError)
return
}
indexpath := resp.Node.Key
// Keep updating TTL to make sure lock request is not expired before acquisition.
stop := make(chan bool)
go h.ttlKeepAlive(indexpath, ttl, stop)
// Monitor for broken connection.
stopWatchChan := make(chan bool)
go func() {
select {
case <-closeChan:
stopWatchChan <- true
case <-stop:
// Stop watching for connection disconnect.
}
}()
// Extract the lock index.
index, _ := strconv.Atoi(path.Base(resp.Node.Key))
// Wait until we successfully get a lock or we get a failure.
var success bool
for {
// Read all indices.
resp, err = h.client.Get(keypath, true, true)
if err != nil {
http.Error(w, "lock children lookup error: " + err.Error(), http.StatusInternalServerError)
break
}
indices := extractResponseIndices(resp)
waitIndex := resp.Node.ModifiedIndex
prevIndex := findPrevIndex(indices, index)
// If there is no previous index then we have the lock.
if prevIndex == 0 {
success = true
break
}
// Otherwise watch previous index until it's gone.
_, err = h.client.Watch(path.Join(keypath, strconv.Itoa(prevIndex)), waitIndex, false, nil, stopWatchChan)
if err == etcd.ErrWatchStoppedByUser {
break
} else if err != nil {
http.Error(w, "lock watch error: " + err.Error(), http.StatusInternalServerError)
break
}
}
// Check for connection disconnect before we write the lock index.
select {
case <-stopWatchChan:
success = false
default:
}
// Stop the ttl keep-alive.
close(stop)
if success {
// Write lock index to response body if we acquire the lock.
h.client.Update(indexpath, "-", uint64(ttl))
w.Write([]byte(strconv.Itoa(index)))
// If node exists then just watch it. Otherwise create the node and watch it.
index := h.findExistingNode(keypath, value)
if index > 0 {
err = h.watch(keypath, index, nil)
} else {
// Make sure key is deleted if we couldn't acquire.
h.client.Delete(indexpath, false)
index, err = h.createNode(keypath, value, ttl, closeChan, stopChan)
}
// Stop all goroutines.
close(stopChan)
// Write response.
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
} else {
w.Write([]byte(strconv.Itoa(index)))
}
}
// createNode creates a new lock node and watches it until it is acquired or acquisition fails.
func (h *handler) createNode(keypath string, value string, ttl int, closeChan <- chan bool, stopChan chan bool) (int, error) {
// Default the value to "-" if it is blank.
if len(value) == 0 {
value = "-"
}
// Create an incrementing id for the lock.
resp, err := h.client.AddChild(keypath, value, uint64(ttl))
if err != nil {
return 0, errors.New("acquire lock index error: " + err.Error())
}
indexpath := resp.Node.Key
index, _ := strconv.Atoi(path.Base(indexpath))
// Keep updating TTL to make sure lock request is not expired before acquisition.
go h.ttlKeepAlive(indexpath, value, ttl, stopChan)
// Watch until we acquire or fail.
err = h.watch(keypath, index, closeChan)
// Check for connection disconnect before we write the lock index.
if err != nil {
select {
case <-closeChan:
err = errors.New("acquire lock error: user interrupted")
default:
}
}
// Update TTL one last time if acquired. Otherwise delete.
if err == nil {
h.client.Update(indexpath, value, uint64(ttl))
} else {
h.client.Delete(indexpath, false)
}
return index, err
}
// findExistingNode search for a node on the lock with the given value.
func (h *handler) findExistingNode(keypath string, value string) int {
if len(value) > 0 {
resp, err := h.client.Get(keypath, true, true)
if err == nil {
nodes := lockNodes{resp.Node.Nodes}
if node := nodes.FindByValue(value); node != nil {
index, _ := strconv.Atoi(path.Base(node.Key))
return index
}
}
}
return 0
}
// ttlKeepAlive continues to update a key's TTL until the stop channel is closed.
func (h *handler) ttlKeepAlive(k string, ttl int, stop chan bool) {
func (h *handler) ttlKeepAlive(k string, value string, ttl int, stopChan chan bool) {
for {
select {
case <-time.After(time.Duration(ttl / 2) * time.Second):
h.client.Update(k, "-", uint64(ttl))
case <-stop:
h.client.Update(k, value, uint64(ttl))
case <-stopChan:
return
}
}
}
// watch continuously waits for a given lock index to be acquired or until lock fails.
// Returns a boolean indicating success.
func (h *handler) watch(keypath string, index int, closeChan <- chan bool) error {
// Wrap close chan so we can pass it to Client.Watch().
stopWatchChan := make(chan bool)
go func() {
select {
case <- closeChan:
stopWatchChan <- true
case <- stopWatchChan:
}
}()
defer close(stopWatchChan)
for {
// Read all nodes for the lock.
resp, err := h.client.Get(keypath, true, true)
if err != nil {
return fmt.Errorf("lock watch lookup error: %s", err.Error())
}
waitIndex := resp.Node.ModifiedIndex
nodes := lockNodes{resp.Node.Nodes}
prevIndex := nodes.PrevIndex(index)
// If there is no previous index then we have the lock.
if prevIndex == 0 {
return nil
}
// Watch previous index until it's gone.
_, err = h.client.Watch(path.Join(keypath, strconv.Itoa(prevIndex)), waitIndex, false, nil, stopWatchChan)
if err == etcd.ErrWatchStoppedByUser {
return fmt.Errorf("lock watch closed")
} else if err != nil {
return fmt.Errorf("lock watch error:%s", err.Error())
}
}
}

View File

@ -3,28 +3,41 @@ package v2
import (
"net/http"
"path"
"strconv"
"github.com/gorilla/mux"
)
// getIndexHandler retrieves the current lock index.
// The "field" parameter specifies to read either the lock "index" or lock "value".
func (h *handler) getIndexHandler(w http.ResponseWriter, req *http.Request) {
h.client.SyncCluster()
vars := mux.Vars(req)
keypath := path.Join(prefix, vars["key"])
field := req.FormValue("field")
if len(field) == 0 {
field = "value"
}
// Read all indices.
resp, err := h.client.Get(keypath, true, true)
if err != nil {
http.Error(w, "lock children lookup error: " + err.Error(), http.StatusInternalServerError)
http.Error(w, "read lock error: " + err.Error(), http.StatusInternalServerError)
return
}
nodes := lockNodes{resp.Node.Nodes}
// Write out the index of the last one to the response body.
indices := extractResponseIndices(resp)
if len(indices) > 0 {
w.Write([]byte(strconv.Itoa(indices[0])))
// Write out the requested field.
if node := nodes.First(); node != nil {
switch field {
case "index":
w.Write([]byte(path.Base(node.Key)))
case "value":
w.Write([]byte(node.Value))
default:
http.Error(w, "read lock error: invalid field: " + field, http.StatusInternalServerError)
}
}
}

View File

@ -2,9 +2,6 @@ package v2
import (
"net/http"
"path"
"strconv"
"sort"
"github.com/gorilla/mux"
"github.com/coreos/go-etcd/etcd"
@ -27,32 +24,7 @@ func NewHandler(addr string) (http.Handler) {
h.StrictSlash(false)
h.HandleFunc("/{key:.*}", h.getIndexHandler).Methods("GET")
h.HandleFunc("/{key:.*}", h.acquireHandler).Methods("POST")
h.HandleFunc("/{key_with_index:.*}", h.renewLockHandler).Methods("PUT")
h.HandleFunc("/{key_with_index:.*}", h.releaseLockHandler).Methods("DELETE")
h.HandleFunc("/{key:.*}", h.renewLockHandler).Methods("PUT")
h.HandleFunc("/{key:.*}", h.releaseLockHandler).Methods("DELETE")
return h
}
// extractResponseIndices extracts a sorted list of indicies from a response.
func extractResponseIndices(resp *etcd.Response) []int {
var indices []int
for _, node := range resp.Node.Nodes {
if index, _ := strconv.Atoi(path.Base(node.Key)); index > 0 {
indices = append(indices, index)
}
}
sort.Ints(indices)
return indices
}
// findPrevIndex retrieves the previous index before the given index.
func findPrevIndex(indices []int, idx int) int {
var prevIndex int
for _, index := range indices {
if index == idx {
break
}
prevIndex = index
}
return prevIndex
}

57
mod/lock/v2/lock_nodes.go Normal file
View File

@ -0,0 +1,57 @@
package v2
import (
"path"
"sort"
"strconv"
"github.com/coreos/go-etcd/etcd"
)
// lockNodes is a wrapper for go-etcd's Nodes to allow for sorting by numeric key.
type lockNodes struct {
etcd.Nodes
}
// Less sorts the nodes by key (numerically).
func (s lockNodes) Less(i, j int) bool {
a, _ := strconv.Atoi(path.Base(s.Nodes[i].Key))
b, _ := strconv.Atoi(path.Base(s.Nodes[j].Key))
return a < b
}
// Retrieves the first node in the set of lock nodes.
func (s lockNodes) First() *etcd.Node {
sort.Sort(s)
if len(s.Nodes) > 0 {
return &s.Nodes[0]
}
return nil
}
// Retrieves the first node with a given value.
func (s lockNodes) FindByValue(value string) *etcd.Node {
sort.Sort(s)
for _, node := range s.Nodes {
if node.Value == value {
return &node
}
}
return nil
}
// Retrieves the index that occurs before a given index.
func (s lockNodes) PrevIndex(index int) int {
sort.Sort(s)
var prevIndex int
for _, node := range s.Nodes {
idx, _ := strconv.Atoi(path.Base(node.Key))
if index == idx {
return prevIndex
}
prevIndex = idx
}
return 0
}

View File

@ -12,12 +12,39 @@ func (h *handler) releaseLockHandler(w http.ResponseWriter, req *http.Request) {
h.client.SyncCluster()
vars := mux.Vars(req)
keypath := path.Join(prefix, vars["key_with_index"])
keypath := path.Join(prefix, vars["key"])
// Read index and value parameters.
index := req.FormValue("index")
value := req.FormValue("value")
if len(index) == 0 && len(value) == 0 {
http.Error(w, "release lock error: index or value required", http.StatusInternalServerError)
return
} else if len(index) != 0 && len(value) != 0 {
http.Error(w, "release lock error: index and value cannot both be specified", http.StatusInternalServerError)
return
}
// Look up index by value if index is missing.
if len(index) == 0 {
resp, err := h.client.Get(keypath, true, true)
if err != nil {
http.Error(w, "release lock index error: " + err.Error(), http.StatusInternalServerError)
return
}
nodes := lockNodes{resp.Node.Nodes}
node := nodes.FindByValue(value)
if node == nil {
http.Error(w, "release lock error: cannot find: " + value, http.StatusInternalServerError)
return
}
index = path.Base(node.Key)
}
// Delete the lock.
_, err := h.client.Delete(keypath, false)
_, err := h.client.Delete(path.Join(keypath, index), false)
if err != nil {
http.Error(w, "delete lock index error: " + err.Error(), http.StatusInternalServerError)
http.Error(w, "release lock error: " + err.Error(), http.StatusInternalServerError)
return
}
}

View File

@ -13,18 +13,55 @@ import (
func (h *handler) renewLockHandler(w http.ResponseWriter, req *http.Request) {
h.client.SyncCluster()
// Read the lock path.
vars := mux.Vars(req)
keypath := path.Join(prefix, vars["key_with_index"])
keypath := path.Join(prefix, vars["key"])
// Parse new TTL parameter.
ttl, err := strconv.Atoi(req.FormValue("ttl"))
if err != nil {
http.Error(w, "invalid ttl: " + err.Error(), http.StatusInternalServerError)
return
}
// Read and set defaults for index and value.
index := req.FormValue("index")
value := req.FormValue("value")
if len(index) == 0 && len(value) == 0 {
// The index or value is required.
http.Error(w, "renew lock error: index or value required", http.StatusInternalServerError)
return
}
if len(index) == 0 {
// If index is not specified then look it up by value.
resp, err := h.client.Get(keypath, true, true)
if err != nil {
http.Error(w, "renew lock index error: " + err.Error(), http.StatusInternalServerError)
return
}
nodes := lockNodes{resp.Node.Nodes}
node := nodes.FindByValue(value)
if node == nil {
http.Error(w, "renew lock error: cannot find: " + value, http.StatusInternalServerError)
return
}
index = path.Base(node.Key)
} else if len(value) == 0 {
// If value is not specified then default it to the previous value.
resp, err := h.client.Get(path.Join(keypath, index), true, false)
if err != nil {
http.Error(w, "renew lock value error: " + err.Error(), http.StatusInternalServerError)
return
}
value = resp.Node.Value
}
// Renew the lock, if it exists.
_, err = h.client.Update(keypath, "-", uint64(ttl))
_, err = h.client.Update(path.Join(keypath, index), value, uint64(ttl))
if err != nil {
http.Error(w, "renew lock index error: " + err.Error(), http.StatusInternalServerError)
http.Error(w, "renew lock error: " + err.Error(), http.StatusInternalServerError)
return
}
}

View File

@ -14,7 +14,7 @@ import (
func TestModLockAcquireAndRelease(t *testing.T) {
tests.RunServer(func(s *server.Server) {
// Acquire lock.
body, err := testAcquireLock(s, "foo", 10)
body, err := testAcquireLock(s, "foo", "", 10)
assert.NoError(t, err)
assert.Equal(t, body, "2")
@ -24,7 +24,7 @@ func TestModLockAcquireAndRelease(t *testing.T) {
assert.Equal(t, body, "2")
// Release lock.
body, err = testReleaseLock(s, "foo", 2)
body, err = testReleaseLock(s, "foo", "2", "")
assert.NoError(t, err)
assert.Equal(t, body, "")
@ -42,7 +42,7 @@ func TestModLockBlockUntilAcquire(t *testing.T) {
// Acquire lock #1.
go func() {
body, err := testAcquireLock(s, "foo", 10)
body, err := testAcquireLock(s, "foo", "", 10)
assert.NoError(t, err)
assert.Equal(t, body, "2")
c <- true
@ -50,11 +50,13 @@ func TestModLockBlockUntilAcquire(t *testing.T) {
<- c
// Acquire lock #2.
waiting := true
go func() {
c <- true
body, err := testAcquireLock(s, "foo", 10)
body, err := testAcquireLock(s, "foo", "", 10)
assert.NoError(t, err)
assert.Equal(t, body, "4")
waiting = false
}()
<- c
@ -65,8 +67,11 @@ func TestModLockBlockUntilAcquire(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, body, "2")
// Check that we are still waiting for lock #2.
assert.Equal(t, waiting, true)
// Release lock #1.
body, err = testReleaseLock(s, "foo", 2)
body, err = testReleaseLock(s, "foo", "2", "")
assert.NoError(t, err)
// Check that we have lock #2.
@ -75,7 +80,7 @@ func TestModLockBlockUntilAcquire(t *testing.T) {
assert.Equal(t, body, "4")
// Release lock #2.
body, err = testReleaseLock(s, "foo", 4)
body, err = testReleaseLock(s, "foo", "4", "")
assert.NoError(t, err)
// Check that we have no lock.
@ -92,7 +97,7 @@ func TestModLockExpireAndRelease(t *testing.T) {
// Acquire lock #1.
go func() {
body, err := testAcquireLock(s, "foo", 2)
body, err := testAcquireLock(s, "foo", "", 2)
assert.NoError(t, err)
assert.Equal(t, body, "2")
c <- true
@ -102,7 +107,7 @@ func TestModLockExpireAndRelease(t *testing.T) {
// Acquire lock #2.
go func() {
c <- true
body, err := testAcquireLock(s, "foo", 10)
body, err := testAcquireLock(s, "foo", "", 10)
assert.NoError(t, err)
assert.Equal(t, body, "4")
}()
@ -129,7 +134,7 @@ func TestModLockExpireAndRelease(t *testing.T) {
func TestModLockRenew(t *testing.T) {
tests.RunServer(func(s *server.Server) {
// Acquire lock.
body, err := testAcquireLock(s, "foo", 3)
body, err := testAcquireLock(s, "foo", "", 3)
assert.NoError(t, err)
assert.Equal(t, body, "2")
@ -141,7 +146,7 @@ func TestModLockRenew(t *testing.T) {
assert.Equal(t, body, "2")
// Renew lock.
body, err = testRenewLock(s, "foo", 2, 3)
body, err = testRenewLock(s, "foo", "2", "", 3)
assert.NoError(t, err)
assert.Equal(t, body, "")
@ -161,28 +166,59 @@ func TestModLockRenew(t *testing.T) {
})
}
// Ensure that a lock can be acquired with a value and released by value.
func TestModLockAcquireAndReleaseByValue(t *testing.T) {
tests.RunServer(func(s *server.Server) {
// Acquire lock.
body, err := testAcquireLock(s, "foo", "XXX", 10)
assert.NoError(t, err)
assert.Equal(t, body, "2")
// Check that we have the lock.
body, err = testGetLockValue(s, "foo")
assert.NoError(t, err)
assert.Equal(t, body, "XXX")
// Release lock.
body, err = testReleaseLock(s, "foo", "", "XXX")
assert.NoError(t, err)
assert.Equal(t, body, "")
// Check that we released the lock.
body, err = testGetLockValue(s, "foo")
assert.NoError(t, err)
assert.Equal(t, body, "")
})
}
func testAcquireLock(s *server.Server, key string, ttl int) (string, error) {
resp, err := tests.PostForm(fmt.Sprintf("%s/mod/v2/lock/%s?ttl=%d", s.URL(), key, ttl), nil)
func testAcquireLock(s *server.Server, key string, value string, ttl int) (string, error) {
resp, err := tests.PostForm(fmt.Sprintf("%s/mod/v2/lock/%s?value=%s&ttl=%d", s.URL(), key, value, ttl), nil)
ret := tests.ReadBody(resp)
return string(ret), err
}
func testGetLockIndex(s *server.Server, key string) (string, error) {
resp, err := tests.Get(fmt.Sprintf("%s/mod/v2/lock/%s?field=index", s.URL(), key))
ret := tests.ReadBody(resp)
return string(ret), err
}
func testGetLockValue(s *server.Server, key string) (string, error) {
resp, err := tests.Get(fmt.Sprintf("%s/mod/v2/lock/%s", s.URL(), key))
ret := tests.ReadBody(resp)
return string(ret), err
}
func testReleaseLock(s *server.Server, key string, index int) (string, error) {
resp, err := tests.DeleteForm(fmt.Sprintf("%s/mod/v2/lock/%s/%d", s.URL(), key, index), nil)
func testReleaseLock(s *server.Server, key string, index string, value string) (string, error) {
resp, err := tests.DeleteForm(fmt.Sprintf("%s/mod/v2/lock/%s?index=%s&value=%s", s.URL(), key, index, value), nil)
ret := tests.ReadBody(resp)
return string(ret), err
}
func testRenewLock(s *server.Server, key string, index int, ttl int) (string, error) {
resp, err := tests.PutForm(fmt.Sprintf("%s/mod/v2/lock/%s/%d?ttl=%d", s.URL(), key, index, ttl), nil)
func testRenewLock(s *server.Server, key string, index string, value string, ttl int) (string, error) {
resp, err := tests.PutForm(fmt.Sprintf("%s/mod/v2/lock/%s?index=%s&value=%s&ttl=%d", s.URL(), key, index, value, ttl), nil)
ret := tests.ReadBody(resp)
return string(ret), err
}