diff --git a/command.go b/command.go index dd288f3bf..ba8c140f6 100644 --- a/command.go +++ b/command.go @@ -8,6 +8,7 @@ package main import ( "github.com/benbjohnson/go-raft" "encoding/json" + "time" ) // A command represents an action to be taken on the replicated state machine. @@ -34,7 +35,7 @@ func (c *SetCommand) CommandName() string { // Set the value of key to value func (c *SetCommand) Apply(server *raft.Server) ([]byte, error) { - res := s.Set(c.Key, c.Value) + res := s.Set(c.Key, c.Value, time.Unix(0, 0)) return json.Marshal(res) } diff --git a/store.go b/store.go index d26554169..dabbc2a54 100644 --- a/store.go +++ b/store.go @@ -3,26 +3,35 @@ package main import ( "path" "encoding/json" + "time" + "fmt" ) // CONSTANTS const ( ERROR = -1 + iota SET - GET DELETE + GET ) type Store struct { - Nodes map[string]string `json:"nodes"` + Nodes map[string]Node `json:"nodes"` +} + +type Node struct { + Value string + ExpireTime time.Time + update chan time.Time } type Response struct { - Action int `json:action` - Key string `json:key` - OldValue string `json:oldValue` - NewValue string `json:newValue` - Exist bool `json:exist` + Action int `json:"action"` + Key string `json:"key"` + OldValue string `json:"oldValue"` + NewValue string `json:"newValue"` + Exist bool `json:"exist"` + Expiration time.Time `json:"expiration"` } @@ -36,25 +45,75 @@ func init() { // make a new stroe func createStore() *Store{ s := new(Store) - s.Nodes = make(map[string]string) + s.Nodes = make(map[string]Node) return s } // set the key to value, return the old value if the key exists -func (s *Store) Set(key string, value string) Response { +func (s *Store) Set(key string, value string, expireTime time.Time) Response { + key = path.Clean(key) - oldValue, ok := s.Nodes[key] + var expire bool = false + + expire = !expireTime.Equal(time.Unix(0,0)) + + // when the slow follower receive the set command + // the key may be expired, we need also to delete + // the previous value of key + if expire && expireTime.Sub(time.Now()) < 0 { + return s.Delete(key) + } + + node, ok := s.Nodes[key] if ok { - s.Nodes[key] = value - w.notify(SET, key, oldValue, value, true) - return Response{SET, key, oldValue, value, true} + update := make(chan time.Time) + s.Nodes[key] = Node{value, expireTime, update} + w.notify(SET, key, node.Value, value, true) + + // node is not permanent before + if !node.ExpireTime.Equal(time.Unix(0,0)) { + node.update <- expireTime + } else { + // if current node is not permanent + if expire { + go s.expire(key, update, expireTime) + } + } + + return Response{SET, key, node.Value, value, true, time.Unix(0, 0)} } else { - s.Nodes[key] = value + update := make(chan time.Time) + s.Nodes[key] = Node{value, expireTime, update} w.notify(SET, key, "", value, false) - return Response{SET, key, "", value, false} + if expire { + go s.expire(key, update, expireTime) + } + return Response{SET, key, "", value, false, time.Unix(0, 0)} + } +} + +// delete the key when it expires +func (s *Store) expire(key string, update chan time.Time, expireTime time.Time) { + duration := expireTime.Sub(time.Now()) + + for { + select { + // timeout delte key + case <-time.After(duration): + fmt.Println("expired at ", time.Now()) + s.Delete(key) + return + case updateTime := <-update: + //update duration + if updateTime.Equal(time.Unix(0,0)) { + fmt.Println("node became stable") + return + } + duration = updateTime.Sub(time.Now()) + } } } @@ -62,12 +121,12 @@ func (s *Store) Set(key string, value string) Response { func (s *Store) Get(key string) Response { key = path.Clean(key) - value, ok := s.Nodes[key] + node, ok := s.Nodes[key] if ok { - return Response{GET, key, value, value, true} + return Response{GET, key, node.Value, node.Value, true, node.ExpireTime} } else { - return Response{GET, key, "", value, false} + return Response{GET, key, "", "", false, time.Unix(0, 0)} } } @@ -75,16 +134,16 @@ func (s *Store) Get(key string) Response { func (s *Store) Delete(key string) Response { key = path.Clean(key) - oldValue, ok := s.Nodes[key] + node, ok := s.Nodes[key] if ok { delete(s.Nodes, key) - w.notify(DELETE, key, oldValue, "", true) + w.notify(DELETE, key, node.Value, "", true) - return Response{DELETE, key, oldValue, "", true} + return Response{DELETE, key, node.Value, "", true, node.ExpireTime} } else { - return Response{DELETE, key, "", "", false} + return Response{DELETE, key, "", "", false, time.Unix(0, 0)} } } diff --git a/store_test.go b/store_test.go index 705b901f9..6d1f77e10 100644 --- a/store_test.go +++ b/store_test.go @@ -2,29 +2,31 @@ package main import ( "testing" + "time" + "fmt" ) func TestStoreGet(t *testing.T) { - s.Set("foo", "bar") + s.Set("foo", "bar", time.Unix(0, 0)) - value, err := s.Get("foo") + res := s.Get("foo") - if err!= nil || value != "bar" { + if res.NewValue != "bar" { t.Fatalf("Cannot get stored value") } s.Delete("foo") - value, err = s.Get("foo") + res = s.Get("foo") - if err == nil{ + if res.Exist { t.Fatalf("Got deleted value") } } func TestSaveAndRecovery(t *testing.T) { - s.Set("foo", "bar") + s.Set("foo", "bar", time.Unix(0, 0)) state, err := s.Save() @@ -35,10 +37,78 @@ func TestSaveAndRecovery(t *testing.T) { newStore := createStore() newStore.Recovery(state) - value, err := newStore.Get("foo") + res := newStore.Get("foo") - if err!= nil || value != "bar" { + if res.OldValue != "bar" { t.Fatalf("Cannot recovery") } + s.Delete("foo") + +} + +func TestExpire(t *testing.T) { + fmt.Println(time.Now()) + fmt.Println("TEST EXPIRE") + + // test expire + s.Set("foo", "bar", time.Now().Add(time.Second * 1)) + time.Sleep(2*time.Second) + + res := s.Get("foo") + + if res.Exist { + t.Fatalf("Got expired value") + } + + //test change expire time + s.Set("foo", "bar", time.Now().Add(time.Second * 10)) + + res = s.Get("foo") + + if !res.Exist { + t.Fatalf("Cannot get Value") + } + + s.Set("foo", "barbar", time.Now().Add(time.Second * 1)) + + time.Sleep(2 * time.Second) + + res = s.Get("foo") + + if res.Exist { + t.Fatalf("Got expired value") + } + + + // test change expire to stable + s.Set("foo", "bar", time.Now().Add(time.Second * 1)) + + s.Set("foo", "bar", time.Unix(0,0)) + + time.Sleep(2*time.Second) + + res = s.Get("foo") + + if !res.Exist { + t.Fatalf("Cannot get Value") + } + + // test stable to expire + s.Set("foo", "bar", time.Now().Add(time.Second * 1)) + time.Sleep(2*time.Second) + res = s.Get("foo") + + if res.Exist { + t.Fatalf("Got expired value") + } + + // test set older node + s.Set("foo", "bar", time.Now().Add(-time.Second * 1)) + res = s.Get("foo") + + if res.Exist { + t.Fatalf("Got expired value") + } + } diff --git a/watcher.go b/watcher.go index 43682b696..f766ea84f 100644 --- a/watcher.go +++ b/watcher.go @@ -3,6 +3,8 @@ package main import ( "path" "strings" + //"fmt" + "time" ) @@ -46,20 +48,18 @@ func (w *Watcher) add(prefix string, c chan Response) error { func (w *Watcher) notify(action int, key string, oldValue string, newValue string, exist bool) error { key = path.Clean(key) segments := strings.Split(key, "/") - currPath := "/" // walk through all the pathes for _, segment := range segments { - - currPath := path.Join(currPath, segment) + currPath = path.Join(currPath, segment) chans, ok := w.chanMap[currPath] if ok { - debug("Notify at ", currPath) + debug("Notify at %s", currPath) - n := Response {action, key, oldValue, newValue, exist} + n := Response {action, key, oldValue, newValue, exist, time.Unix(0, 0)} // notify all the watchers for _, c := range chans { diff --git a/watcher_test.go b/watcher_test.go index f7197eb9c..f6da49719 100644 --- a/watcher_test.go +++ b/watcher_test.go @@ -3,23 +3,24 @@ package main import ( "testing" "fmt" + "time" ) func TestWatch(t *testing.T) { // watcher := createWatcher() - c := make(chan Notification) - d := make(chan Notification) + c := make(chan Response) + d := make(chan Response) w.add("/", c) go say(c) w.add("/prefix/", d) go say(d) - s.Set("/prefix/foo", "bar") + s.Set("/prefix/foo", "bar", time.Unix(0, 0)) } -func say(c chan Notification) { +func say(c chan Response) { result := <-c - if result.action != -1 { + if result.Action != -1 { fmt.Println("yes") } else { fmt.Println("no")