diff --git a/etcdserver/api/v3rpc/key.go b/etcdserver/api/v3rpc/key.go index 3f977b749..973346592 100644 --- a/etcdserver/api/v3rpc/key.go +++ b/etcdserver/api/v3rpc/key.go @@ -16,11 +16,10 @@ package v3rpc import ( - "sort" - "github.com/coreos/etcd/etcdserver" "github.com/coreos/etcd/etcdserver/api/v3rpc/rpctypes" pb "github.com/coreos/etcd/etcdserver/etcdserverpb" + "github.com/coreos/etcd/pkg/adt" "github.com/coreos/pkg/capnslog" "golang.org/x/net/context" ) @@ -89,6 +88,13 @@ func (s *kvServer) Txn(ctx context.Context, r *pb.TxnRequest) (*pb.TxnResponse, if err := checkTxnRequest(r, int(s.maxTxnOps)); err != nil { return nil, err } + // check for forbidden put/del overlaps after checking request to avoid quadratic blowup + if _, _, err := checkIntervals(r.Success); err != nil { + return nil, err + } + if _, _, err := checkIntervals(r.Failure); err != nil { + return nil, err + } resp, err := s.kv.Txn(ctx, r) if err != nil { @@ -137,7 +143,14 @@ func checkDeleteRequest(r *pb.DeleteRangeRequest) error { } func checkTxnRequest(r *pb.TxnRequest, maxTxnOps int) error { - if len(r.Compare) > maxTxnOps || len(r.Success) > maxTxnOps || len(r.Failure) > maxTxnOps { + opc := len(r.Compare) + if opc < len(r.Success) { + opc = len(r.Success) + } + if opc < len(r.Failure) { + opc = len(r.Failure) + } + if opc > maxTxnOps { return rpctypes.ErrGRPCTooManyOps } @@ -146,58 +159,29 @@ func checkTxnRequest(r *pb.TxnRequest, maxTxnOps int) error { return rpctypes.ErrGRPCEmptyKey } } - for _, u := range r.Success { - if err := checkRequestOp(u); err != nil { + if err := checkRequestOp(u, maxTxnOps-opc); err != nil { return err } } - if err := checkRequestDupKeys(r.Success); err != nil { - return err + for _, u := range r.Failure { + if err := checkRequestOp(u, maxTxnOps-opc); err != nil { + return err + } } - for _, u := range r.Failure { - if err := checkRequestOp(u); err != nil { - return err - } - } - return checkRequestDupKeys(r.Failure) + return nil } -// checkRequestDupKeys gives rpctypes.ErrGRPCDuplicateKey if the same key is modified twice -func checkRequestDupKeys(reqs []*pb.RequestOp) error { - // check put overlap - keys := make(map[string]struct{}) - for _, requ := range reqs { - tv, ok := requ.Request.(*pb.RequestOp_RequestPut) - if !ok { - continue - } - preq := tv.RequestPut - if preq == nil { - continue - } - if _, ok := keys[string(preq.Key)]; ok { - return rpctypes.ErrGRPCDuplicateKey - } - keys[string(preq.Key)] = struct{}{} - } +// checkIntervals tests whether puts and deletes overlap for a list of ops. If +// there is an overlap, returns an error. If no overlap, return put and delete +// sets for recursive evaluation. +func checkIntervals(reqs []*pb.RequestOp) (map[string]struct{}, adt.IntervalTree, error) { + var dels adt.IntervalTree - // no need to check deletes if no puts; delete overlaps are permitted - if len(keys) == 0 { - return nil - } - - // sort keys for range checking - sortedKeys := []string{} - for k := range keys { - sortedKeys = append(sortedKeys, k) - } - sort.Strings(sortedKeys) - - // check put overlap with deletes - for _, requ := range reqs { - tv, ok := requ.Request.(*pb.RequestOp_RequestDeleteRange) + // collect deletes from this level; build first to check lower level overlapped puts + for _, req := range reqs { + tv, ok := req.Request.(*pb.RequestOp_RequestDeleteRange) if !ok { continue } @@ -205,41 +189,87 @@ func checkRequestDupKeys(reqs []*pb.RequestOp) error { if dreq == nil { continue } - if dreq.RangeEnd == nil { - if _, found := keys[string(dreq.Key)]; found { - return rpctypes.ErrGRPCDuplicateKey - } + var iv adt.Interval + if len(dreq.RangeEnd) != 0 { + iv = adt.NewStringAffineInterval(string(dreq.Key), string(dreq.RangeEnd)) } else { - lo := sort.SearchStrings(sortedKeys, string(dreq.Key)) - hi := sort.SearchStrings(sortedKeys, string(dreq.RangeEnd)) - if lo != hi { - // element between lo and hi => overlap - return rpctypes.ErrGRPCDuplicateKey - } + iv = adt.NewStringAffinePoint(string(dreq.Key)) } + dels.Insert(iv, struct{}{}) } - return nil + // collect children puts/deletes + puts := make(map[string]struct{}) + for _, req := range reqs { + tv, ok := req.Request.(*pb.RequestOp_RequestTxn) + if !ok { + continue + } + putsThen, delsThen, err := checkIntervals(tv.RequestTxn.Success) + if err != nil { + return nil, dels, err + } + putsElse, delsElse, err := checkIntervals(tv.RequestTxn.Failure) + if err != nil { + return nil, dels, err + } + for k := range putsThen { + if _, ok := puts[k]; ok { + return nil, dels, rpctypes.ErrGRPCDuplicateKey + } + if dels.Intersects(adt.NewStringAffinePoint(k)) { + return nil, dels, rpctypes.ErrGRPCDuplicateKey + } + puts[k] = struct{}{} + } + for k := range putsElse { + if _, ok := puts[k]; ok { + // if key is from putsThen, overlap is OK since + // either then/else are mutually exclusive + if _, isSafe := putsThen[k]; !isSafe { + return nil, dels, rpctypes.ErrGRPCDuplicateKey + } + } + if dels.Intersects(adt.NewStringAffinePoint(k)) { + return nil, dels, rpctypes.ErrGRPCDuplicateKey + } + puts[k] = struct{}{} + } + dels.Union(delsThen, adt.NewStringAffineInterval("\x00", "")) + dels.Union(delsElse, adt.NewStringAffineInterval("\x00", "")) + } + + // collect and check this level's puts + for _, req := range reqs { + tv, ok := req.Request.(*pb.RequestOp_RequestPut) + if !ok || tv.RequestPut == nil { + continue + } + k := string(tv.RequestPut.Key) + if _, ok := puts[k]; ok { + return nil, dels, rpctypes.ErrGRPCDuplicateKey + } + if dels.Intersects(adt.NewStringAffinePoint(k)) { + return nil, dels, rpctypes.ErrGRPCDuplicateKey + } + puts[k] = struct{}{} + } + return puts, dels, nil } -func checkRequestOp(u *pb.RequestOp) error { +func checkRequestOp(u *pb.RequestOp, maxTxnOps int) error { // TODO: ensure only one of the field is set. switch uv := u.Request.(type) { case *pb.RequestOp_RequestRange: - if uv.RequestRange != nil { - return checkRangeRequest(uv.RequestRange) - } + return checkRangeRequest(uv.RequestRange) case *pb.RequestOp_RequestPut: - if uv.RequestPut != nil { - return checkPutRequest(uv.RequestPut) - } + return checkPutRequest(uv.RequestPut) case *pb.RequestOp_RequestDeleteRange: - if uv.RequestDeleteRange != nil { - return checkDeleteRequest(uv.RequestDeleteRange) - } + return checkDeleteRequest(uv.RequestDeleteRange) + case *pb.RequestOp_RequestTxn: + return checkTxnRequest(uv.RequestTxn, maxTxnOps) default: // empty op / nil entry return rpctypes.ErrGRPCKeyNotFound } - return nil } diff --git a/etcdserver/apply.go b/etcdserver/apply.go index 707ac89f0..3aa8d71e6 100644 --- a/etcdserver/apply.go +++ b/etcdserver/apply.go @@ -76,14 +76,30 @@ type applierV3 interface { RoleList(ua *pb.AuthRoleListRequest) (*pb.AuthRoleListResponse, error) } +type checkReqFunc func(mvcc.ReadView, *pb.RequestOp) error + type applierV3backend struct { s *EtcdServer + + checkPut checkReqFunc + checkRange checkReqFunc +} + +func (s *EtcdServer) newApplierV3Backend() applierV3 { + base := &applierV3backend{s: s} + base.checkPut = func(rv mvcc.ReadView, req *pb.RequestOp) error { + return base.checkRequestPut(rv, req) + } + base.checkRange = func(rv mvcc.ReadView, req *pb.RequestOp) error { + return base.checkRequestRange(rv, req) + } + return base } func (s *EtcdServer) newApplierV3() applierV3 { return newAuthApplierV3( s.AuthStore(), - newQuotaApplierV3(s, &applierV3backend{s}), + newQuotaApplierV3(s, s.newApplierV3Backend()), s.lessor, ) } @@ -315,24 +331,19 @@ func (a *applierV3backend) Txn(rt *pb.TxnRequest) (*pb.TxnResponse, error) { isWrite := !isTxnReadonly(rt) txn := mvcc.NewReadOnlyTxnWrite(a.s.KV().Read()) - reqs, ok := a.compareToOps(txn, rt) + txnPath := compareToPath(txn, rt) if isWrite { - if err := a.checkRequestPut(txn, reqs); err != nil { + if _, err := checkRequests(txn, rt, txnPath, a.checkPut); err != nil { txn.End() return nil, err } } - if err := checkRequestRange(txn, reqs); err != nil { + if _, err := checkRequests(txn, rt, txnPath, a.checkRange); err != nil { txn.End() return nil, err } - resps := make([]*pb.ResponseOp, len(reqs)) - txnResp := &pb.TxnResponse{ - Responses: resps, - Succeeded: ok, - Header: &pb.ResponseHeader{}, - } + txnResp, _ := newTxnResp(rt, txnPath) // When executing mutable txn ops, etcd must hold the txn lock so // readers do not see any intermediate results. Since writes are @@ -342,9 +353,7 @@ func (a *applierV3backend) Txn(rt *pb.TxnRequest) (*pb.TxnResponse, error) { txn.End() txn = a.s.KV().Write() } - for i := range reqs { - resps[i] = a.applyUnion(txn, reqs[i]) - } + a.applyTxn(txn, rt, txnPath, txnResp) rev := txn.Rev() if len(txn.Changes()) != 0 { rev++ @@ -355,13 +364,60 @@ func (a *applierV3backend) Txn(rt *pb.TxnRequest) (*pb.TxnResponse, error) { return txnResp, nil } -func (a *applierV3backend) compareToOps(rv mvcc.ReadView, rt *pb.TxnRequest) ([]*pb.RequestOp, bool) { - for _, c := range rt.Compare { - if !applyCompare(rv, c) { - return rt.Failure, false +// newTxnResp allocates a txn response for a txn request given a path. +func newTxnResp(rt *pb.TxnRequest, txnPath []bool) (txnResp *pb.TxnResponse, txnCount int) { + reqs := rt.Success + if !txnPath[0] { + reqs = rt.Failure + } + resps := make([]*pb.ResponseOp, len(reqs)) + txnResp = &pb.TxnResponse{ + Responses: resps, + Succeeded: txnPath[0], + Header: &pb.ResponseHeader{}, + } + for i, req := range reqs { + switch tv := req.Request.(type) { + case *pb.RequestOp_RequestRange: + resps[i] = &pb.ResponseOp{Response: &pb.ResponseOp_ResponseRange{}} + case *pb.RequestOp_RequestPut: + resps[i] = &pb.ResponseOp{Response: &pb.ResponseOp_ResponsePut{}} + case *pb.RequestOp_RequestDeleteRange: + resps[i] = &pb.ResponseOp{Response: &pb.ResponseOp_ResponseDeleteRange{}} + case *pb.RequestOp_RequestTxn: + resp, txns := newTxnResp(tv.RequestTxn, txnPath[1:]) + resps[i] = &pb.ResponseOp{Response: &pb.ResponseOp_ResponseTxn{ResponseTxn: resp}} + txnPath = txnPath[1+txns:] + txnCount += txns + 1 + default: } } - return rt.Success, true + return txnResp, txnCount +} + +func compareToPath(rv mvcc.ReadView, rt *pb.TxnRequest) []bool { + txnPath := make([]bool, 1) + ops := rt.Success + if txnPath[0] = applyCompares(rv, rt.Compare); !txnPath[0] { + ops = rt.Failure + } + for _, op := range ops { + tv, ok := op.Request.(*pb.RequestOp_RequestTxn) + if !ok || tv.RequestTxn == nil { + continue + } + txnPath = append(txnPath, compareToPath(rv, tv.RequestTxn)...) + } + return txnPath +} + +func applyCompares(rv mvcc.ReadView, cmps []*pb.Compare) bool { + for _, c := range cmps { + if !applyCompare(rv, c) { + return false + } + } + return true } // applyCompare applies the compare request. @@ -431,38 +487,42 @@ func compareKV(c *pb.Compare, ckv mvccpb.KeyValue) bool { return true } -func (a *applierV3backend) applyUnion(txn mvcc.TxnWrite, union *pb.RequestOp) *pb.ResponseOp { - switch tv := union.Request.(type) { - case *pb.RequestOp_RequestRange: - if tv.RequestRange != nil { +func (a *applierV3backend) applyTxn(txn mvcc.TxnWrite, rt *pb.TxnRequest, txnPath []bool, tresp *pb.TxnResponse) (txns int) { + reqs := rt.Success + if !txnPath[0] { + reqs = rt.Failure + } + for i, req := range reqs { + respi := tresp.Responses[i].Response + switch tv := req.Request.(type) { + case *pb.RequestOp_RequestRange: resp, err := a.Range(txn, tv.RequestRange) if err != nil { plog.Panicf("unexpected error during txn: %v", err) } - return &pb.ResponseOp{Response: &pb.ResponseOp_ResponseRange{ResponseRange: resp}} - } - case *pb.RequestOp_RequestPut: - if tv.RequestPut != nil { + respi.(*pb.ResponseOp_ResponseRange).ResponseRange = resp + case *pb.RequestOp_RequestPut: resp, err := a.Put(txn, tv.RequestPut) if err != nil { plog.Panicf("unexpected error during txn: %v", err) } - return &pb.ResponseOp{Response: &pb.ResponseOp_ResponsePut{ResponsePut: resp}} - } - case *pb.RequestOp_RequestDeleteRange: - if tv.RequestDeleteRange != nil { + respi.(*pb.ResponseOp_ResponsePut).ResponsePut = resp + case *pb.RequestOp_RequestDeleteRange: resp, err := a.DeleteRange(txn, tv.RequestDeleteRange) if err != nil { plog.Panicf("unexpected error during txn: %v", err) } - return &pb.ResponseOp{Response: &pb.ResponseOp_ResponseDeleteRange{ResponseDeleteRange: resp}} + respi.(*pb.ResponseOp_ResponseDeleteRange).ResponseDeleteRange = resp + case *pb.RequestOp_RequestTxn: + resp := respi.(*pb.ResponseOp_ResponseTxn).ResponseTxn + applyTxns := a.applyTxn(txn, tv.RequestTxn, txnPath[1:], resp) + txns += applyTxns + 1 + txnPath = txnPath[applyTxns+1:] + default: + // empty union } - default: - // empty union - return nil } - return nil - + return txns } func (a *applierV3backend) Compaction(compaction *pb.CompactionRequest) (*pb.CompactionResponse, <-chan struct{}, error) { @@ -768,53 +828,66 @@ func (s *kvSortByValue) Less(i, j int) bool { return bytes.Compare(s.kvs[i].Value, s.kvs[j].Value) < 0 } -func (a *applierV3backend) checkRequestPut(rv mvcc.ReadView, reqs []*pb.RequestOp) error { - for _, requ := range reqs { - tv, ok := requ.Request.(*pb.RequestOp_RequestPut) - if !ok { - continue - } - preq := tv.RequestPut - if preq == nil { - continue - } - if preq.IgnoreValue || preq.IgnoreLease { - // expects previous key-value, error if not exist - rr, err := rv.Range(preq.Key, nil, mvcc.RangeOptions{}) +func checkRequests(rv mvcc.ReadView, rt *pb.TxnRequest, txnPath []bool, f checkReqFunc) (int, error) { + txnCount := 0 + reqs := rt.Success + if !txnPath[0] { + reqs = rt.Failure + } + for _, req := range reqs { + if tv, ok := req.Request.(*pb.RequestOp_RequestTxn); ok && tv.RequestTxn != nil { + txns, err := checkRequests(rv, tv.RequestTxn, txnPath[1:], f) if err != nil { - return err + return 0, err } - if rr == nil || len(rr.KVs) == 0 { - return ErrKeyNotFound - } - } - if lease.LeaseID(preq.Lease) == lease.NoLease { + txnCount += txns + 1 + txnPath = txnPath[txns+1:] continue } - if l := a.s.lessor.Lookup(lease.LeaseID(preq.Lease)); l == nil { + if err := f(rv, req); err != nil { + return 0, err + } + } + return txnCount, nil +} + +func (a *applierV3backend) checkRequestPut(rv mvcc.ReadView, reqOp *pb.RequestOp) error { + tv, ok := reqOp.Request.(*pb.RequestOp_RequestPut) + if !ok || tv.RequestPut == nil { + return nil + } + req := tv.RequestPut + if req.IgnoreValue || req.IgnoreLease { + // expects previous key-value, error if not exist + rr, err := rv.Range(req.Key, nil, mvcc.RangeOptions{}) + if err != nil { + return err + } + if rr == nil || len(rr.KVs) == 0 { + return ErrKeyNotFound + } + } + if lease.LeaseID(req.Lease) != lease.NoLease { + if l := a.s.lessor.Lookup(lease.LeaseID(req.Lease)); l == nil { return lease.ErrLeaseNotFound } } return nil } -func checkRequestRange(rv mvcc.ReadView, reqs []*pb.RequestOp) error { - for _, requ := range reqs { - tv, ok := requ.Request.(*pb.RequestOp_RequestRange) - if !ok { - continue - } - greq := tv.RequestRange - if greq == nil || greq.Revision == 0 { - continue - } - - if greq.Revision > rv.Rev() { - return mvcc.ErrFutureRev - } - if greq.Revision < rv.FirstRev() { - return mvcc.ErrCompacted - } +func (a *applierV3backend) checkRequestRange(rv mvcc.ReadView, reqOp *pb.RequestOp) error { + tv, ok := reqOp.Request.(*pb.RequestOp_RequestRange) + if !ok || tv.RequestRange == nil { + return nil + } + req := tv.RequestRange + switch { + case req.Revision == 0: + return nil + case req.Revision > rv.Rev(): + return mvcc.ErrFutureRev + case req.Revision < rv.FirstRev(): + return mvcc.ErrCompacted } return nil } diff --git a/etcdserver/server.go b/etcdserver/server.go index ac18b881e..f2f3874d5 100644 --- a/etcdserver/server.go +++ b/etcdserver/server.go @@ -474,7 +474,7 @@ func NewServer(cfg ServerConfig) (srv *EtcdServer, err error) { srv.compactor.Run() } - srv.applyV3Base = &applierV3backend{srv} + srv.applyV3Base = srv.newApplierV3Backend() if err = srv.restoreAlarms(); err != nil { return nil, err }