diff --git a/etcdserver2/request.pb.go b/etcdserver2/request.pb.go index f8ff72fe5..d2a4bd18c 100644 --- a/etcdserver2/request.pb.go +++ b/etcdserver2/request.pb.go @@ -35,7 +35,7 @@ type Request struct { Dir bool `protobuf:"varint,5,req,name=dir" json:"dir"` PrevValue string `protobuf:"bytes,6,req,name=prevValue" json:"prevValue"` PrevIndex uint64 `protobuf:"varint,7,req,name=prevIndex" json:"prevIndex"` - PrevExists bool `protobuf:"varint,8,req,name=prevExists" json:"prevExists"` + PrevExists *bool `protobuf:"varint,8,req,name=prevExists" json:"prevExists,omitempty"` Expiration int64 `protobuf:"varint,9,req,name=expiration" json:"expiration"` Wait bool `protobuf:"varint,10,req,name=wait" json:"wait"` Since uint64 `protobuf:"varint,11,req,name=since" json:"since"` @@ -220,7 +220,8 @@ func (m *Request) Unmarshal(data []byte) error { break } } - m.PrevExists = bool(v != 0) + b := bool(v != 0) + m.PrevExists = &b case 9: if wireType != 0 { return code_google_com_p_gogoprotobuf_proto.ErrWrongType @@ -339,7 +340,9 @@ func (m *Request) Size() (n int) { l = len(m.PrevValue) n += 1 + l + sovRequest(uint64(l)) n += 1 + sovRequest(uint64(m.PrevIndex)) - n += 2 + if m.PrevExists != nil { + n += 2 + } n += 1 + sovRequest(uint64(m.Expiration)) n += 2 n += 1 + sovRequest(uint64(m.Since)) @@ -409,14 +412,16 @@ func (m *Request) MarshalTo(data []byte) (n int, err error) { data[i] = 0x38 i++ i = encodeVarintRequest(data, i, uint64(m.PrevIndex)) - data[i] = 0x40 - i++ - if m.PrevExists { - data[i] = 1 - } else { - data[i] = 0 + if m.PrevExists != nil { + data[i] = 0x40 + i++ + if *m.PrevExists { + data[i] = 1 + } else { + data[i] = 0 + } + i++ } - i++ data[i] = 0x48 i++ i = encodeVarintRequest(data, i, uint64(m.Expiration)) diff --git a/etcdserver2/request.proto b/etcdserver2/request.proto index 69203365d..eaac41e2a 100644 --- a/etcdserver2/request.proto +++ b/etcdserver2/request.proto @@ -15,7 +15,7 @@ message Request { required bool dir = 5 [(gogoproto.nullable) = false]; required string prevValue = 6 [(gogoproto.nullable) = false]; required uint64 prevIndex = 7 [(gogoproto.nullable) = false]; - required bool prevExists = 8 [(gogoproto.nullable) = false]; + required bool prevExists = 8 [(gogoproto.nullable) = true]; required int64 expiration = 9 [(gogoproto.nullable) = false]; required bool wait = 10 [(gogoproto.nullable) = false]; required uint64 since = 11 [(gogoproto.nullable) = false]; diff --git a/etcdserver2/server.go b/etcdserver2/server.go index 9dfff8531..5c7acbea2 100644 --- a/etcdserver2/server.go +++ b/etcdserver2/server.go @@ -119,13 +119,18 @@ func (s *Server) apply(ctx context.Context, e raft.Entry) (*store.Event, error) case "POST": return s.st.Create(r.Path, r.Dir, r.Val, true, expr) case "PUT": + exists, set := getBool(r.PrevExists) switch { - case r.PrevExists: - return s.st.Update(r.Path, r.Val, expr) + case set: + if exists { + return s.st.Update(r.Path, r.Val, expr) + } else { + return s.st.Create(r.Path, r.Dir, r.Val, false, expr) + } case r.PrevIndex > 0 || r.PrevValue != "": return s.st.CompareAndSwap(r.Path, r.PrevValue, r.PrevIndex, r.Val, expr) default: - return s.st.Create(r.Path, r.Dir, r.Val, false, expr) + return s.st.Set(r.Path, r.Dir, r.Val, expr) } case "DELETE": switch { @@ -138,3 +143,10 @@ func (s *Server) apply(ctx context.Context, e raft.Entry) (*store.Event, error) return nil, ErrUnknownMethod } } + +func getBool(v *bool) (vv bool, set bool) { + if v == nil { + return false, false + } + return *v, true +}