mirror of
				https://github.com/etcd-io/etcd.git
				synced 2024-09-27 06:25:44 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			392 lines
		
	
	
		
			9.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			392 lines
		
	
	
		
			9.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright 2016 The etcd Authors
 | |
| //
 | |
| // Licensed under the Apache License, Version 2.0 (the "License");
 | |
| // you may not use this file except in compliance with the License.
 | |
| // You may obtain a copy of the License at
 | |
| //
 | |
| //     http://www.apache.org/licenses/LICENSE-2.0
 | |
| //
 | |
| // Unless required by applicable law or agreed to in writing, software
 | |
| // distributed under the License is distributed on an "AS IS" BASIS,
 | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| // See the License for the specific language governing permissions and
 | |
| // limitations under the License.
 | |
| 
 | |
| package concurrency
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"math"
 | |
| 
 | |
| 	v3 "go.etcd.io/etcd/client/v3"
 | |
| )
 | |
| 
 | |
| // STM is an interface for software transactional memory.
 | |
| type STM interface {
 | |
| 	// Get returns the value for a key and inserts the key in the txn's read set.
 | |
| 	// If Get fails, it aborts the transaction with an error, never returning.
 | |
| 	Get(key ...string) string
 | |
| 	// Put adds a value for a key to the write set.
 | |
| 	Put(key, val string, opts ...v3.OpOption)
 | |
| 	// Rev returns the revision of a key in the read set.
 | |
| 	Rev(key string) int64
 | |
| 	// Del deletes a key.
 | |
| 	Del(key string)
 | |
| 
 | |
| 	// commit attempts to apply the txn's changes to the server.
 | |
| 	commit() *v3.TxnResponse
 | |
| 	reset()
 | |
| }
 | |
| 
 | |
| // Isolation is an enumeration of transactional isolation levels which
 | |
| // describes how transactions should interfere and conflict.
 | |
| type Isolation int
 | |
| 
 | |
| const (
 | |
| 	// SerializableSnapshot provides serializable isolation and also checks
 | |
| 	// for write conflicts.
 | |
| 	SerializableSnapshot Isolation = iota
 | |
| 	// Serializable reads within the same transaction attempt return data
 | |
| 	// from the at the revision of the first read.
 | |
| 	Serializable
 | |
| 	// RepeatableReads reads within the same transaction attempt always
 | |
| 	// return the same data.
 | |
| 	RepeatableReads
 | |
| 	// ReadCommitted reads keys from any committed revision.
 | |
| 	ReadCommitted
 | |
| )
 | |
| 
 | |
| // stmError safely passes STM errors through panic to the STM error channel.
 | |
| type stmError struct{ err error }
 | |
| 
 | |
| type stmOptions struct {
 | |
| 	iso      Isolation
 | |
| 	ctx      context.Context
 | |
| 	prefetch []string
 | |
| }
 | |
| 
 | |
| type stmOption func(*stmOptions)
 | |
| 
 | |
| // WithIsolation specifies the transaction isolation level.
 | |
| func WithIsolation(lvl Isolation) stmOption {
 | |
| 	return func(so *stmOptions) { so.iso = lvl }
 | |
| }
 | |
| 
 | |
| // WithAbortContext specifies the context for permanently aborting the transaction.
 | |
| func WithAbortContext(ctx context.Context) stmOption {
 | |
| 	return func(so *stmOptions) { so.ctx = ctx }
 | |
| }
 | |
| 
 | |
| // WithPrefetch is a hint to prefetch a list of keys before trying to apply.
 | |
| // If an STM transaction will unconditionally fetch a set of keys, prefetching
 | |
| // those keys will save the round-trip cost from requesting each key one by one
 | |
| // with Get().
 | |
| func WithPrefetch(keys ...string) stmOption {
 | |
| 	return func(so *stmOptions) { so.prefetch = append(so.prefetch, keys...) }
 | |
| }
 | |
| 
 | |
| // NewSTM initiates a new STM instance, using serializable snapshot isolation by default.
 | |
| func NewSTM(c *v3.Client, apply func(STM) error, so ...stmOption) (*v3.TxnResponse, error) {
 | |
| 	opts := &stmOptions{ctx: c.Ctx()}
 | |
| 	for _, f := range so {
 | |
| 		f(opts)
 | |
| 	}
 | |
| 	if len(opts.prefetch) != 0 {
 | |
| 		f := apply
 | |
| 		apply = func(s STM) error {
 | |
| 			s.Get(opts.prefetch...)
 | |
| 			return f(s)
 | |
| 		}
 | |
| 	}
 | |
| 	return runSTM(mkSTM(c, opts), apply)
 | |
| }
 | |
| 
 | |
| func mkSTM(c *v3.Client, opts *stmOptions) STM {
 | |
| 	switch opts.iso {
 | |
| 	case SerializableSnapshot:
 | |
| 		s := &stmSerializable{
 | |
| 			stm:      stm{client: c, ctx: opts.ctx},
 | |
| 			prefetch: make(map[string]*v3.GetResponse),
 | |
| 		}
 | |
| 		s.conflicts = func() []v3.Cmp {
 | |
| 			return append(s.rset.cmps(), s.wset.cmps(s.rset.first()+1)...)
 | |
| 		}
 | |
| 		return s
 | |
| 	case Serializable:
 | |
| 		s := &stmSerializable{
 | |
| 			stm:      stm{client: c, ctx: opts.ctx},
 | |
| 			prefetch: make(map[string]*v3.GetResponse),
 | |
| 		}
 | |
| 		s.conflicts = func() []v3.Cmp { return s.rset.cmps() }
 | |
| 		return s
 | |
| 	case RepeatableReads:
 | |
| 		s := &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
 | |
| 		s.conflicts = func() []v3.Cmp { return s.rset.cmps() }
 | |
| 		return s
 | |
| 	case ReadCommitted:
 | |
| 		s := &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}}
 | |
| 		s.conflicts = func() []v3.Cmp { return nil }
 | |
| 		return s
 | |
| 	default:
 | |
| 		panic("unsupported stm")
 | |
| 	}
 | |
| }
 | |
| 
 | |
| type stmResponse struct {
 | |
| 	resp *v3.TxnResponse
 | |
| 	err  error
 | |
| }
 | |
| 
 | |
| func runSTM(s STM, apply func(STM) error) (*v3.TxnResponse, error) {
 | |
| 	outc := make(chan stmResponse, 1)
 | |
| 	go func() {
 | |
| 		defer func() {
 | |
| 			if r := recover(); r != nil {
 | |
| 				e, ok := r.(stmError)
 | |
| 				if !ok {
 | |
| 					// client apply panicked
 | |
| 					panic(r)
 | |
| 				}
 | |
| 				outc <- stmResponse{nil, e.err}
 | |
| 			}
 | |
| 		}()
 | |
| 		var out stmResponse
 | |
| 		for {
 | |
| 			s.reset()
 | |
| 			if out.err = apply(s); out.err != nil {
 | |
| 				break
 | |
| 			}
 | |
| 			if out.resp = s.commit(); out.resp != nil {
 | |
| 				break
 | |
| 			}
 | |
| 		}
 | |
| 		outc <- out
 | |
| 	}()
 | |
| 	r := <-outc
 | |
| 	return r.resp, r.err
 | |
| }
 | |
| 
 | |
| // stm implements repeatable-read software transactional memory over etcd
 | |
| type stm struct {
 | |
| 	client *v3.Client
 | |
| 	ctx    context.Context
 | |
| 	// rset holds read key values and revisions
 | |
| 	rset readSet
 | |
| 	// wset holds overwritten keys and their values
 | |
| 	wset writeSet
 | |
| 	// getOpts are the opts used for gets
 | |
| 	getOpts []v3.OpOption
 | |
| 	// conflicts computes the current conflicts on the txn
 | |
| 	conflicts func() []v3.Cmp
 | |
| }
 | |
| 
 | |
| type stmPut struct {
 | |
| 	val string
 | |
| 	op  v3.Op
 | |
| }
 | |
| 
 | |
| type readSet map[string]*v3.GetResponse
 | |
| 
 | |
| func (rs readSet) add(keys []string, txnresp *v3.TxnResponse) {
 | |
| 	for i, resp := range txnresp.Responses {
 | |
| 		rs[keys[i]] = (*v3.GetResponse)(resp.GetResponseRange())
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // first returns the store revision from the first fetch
 | |
| func (rs readSet) first() int64 {
 | |
| 	ret := int64(math.MaxInt64 - 1)
 | |
| 	for _, resp := range rs {
 | |
| 		if rev := resp.Header.Revision; rev < ret {
 | |
| 			ret = rev
 | |
| 		}
 | |
| 	}
 | |
| 	return ret
 | |
| }
 | |
| 
 | |
| // cmps guards the txn from updates to read set
 | |
| func (rs readSet) cmps() []v3.Cmp {
 | |
| 	cmps := make([]v3.Cmp, 0, len(rs))
 | |
| 	for k, rk := range rs {
 | |
| 		cmps = append(cmps, isKeyCurrent(k, rk))
 | |
| 	}
 | |
| 	return cmps
 | |
| }
 | |
| 
 | |
| type writeSet map[string]stmPut
 | |
| 
 | |
| func (ws writeSet) get(keys ...string) *stmPut {
 | |
| 	for _, key := range keys {
 | |
| 		if wv, ok := ws[key]; ok {
 | |
| 			return &wv
 | |
| 		}
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // cmps returns a cmp list testing no writes have happened past rev
 | |
| func (ws writeSet) cmps(rev int64) []v3.Cmp {
 | |
| 	cmps := make([]v3.Cmp, 0, len(ws))
 | |
| 	for key := range ws {
 | |
| 		cmps = append(cmps, v3.Compare(v3.ModRevision(key), "<", rev))
 | |
| 	}
 | |
| 	return cmps
 | |
| }
 | |
| 
 | |
| // puts is the list of ops for all pending writes
 | |
| func (ws writeSet) puts() []v3.Op {
 | |
| 	puts := make([]v3.Op, 0, len(ws))
 | |
| 	for _, v := range ws {
 | |
| 		puts = append(puts, v.op)
 | |
| 	}
 | |
| 	return puts
 | |
| }
 | |
| 
 | |
| func (s *stm) Get(keys ...string) string {
 | |
| 	if wv := s.wset.get(keys...); wv != nil {
 | |
| 		return wv.val
 | |
| 	}
 | |
| 	return respToValue(s.fetch(keys...))
 | |
| }
 | |
| 
 | |
| func (s *stm) Put(key, val string, opts ...v3.OpOption) {
 | |
| 	s.wset[key] = stmPut{val, v3.OpPut(key, val, opts...)}
 | |
| }
 | |
| 
 | |
| func (s *stm) Del(key string) { s.wset[key] = stmPut{"", v3.OpDelete(key)} }
 | |
| 
 | |
| func (s *stm) Rev(key string) int64 {
 | |
| 	if resp := s.fetch(key); resp != nil && len(resp.Kvs) != 0 {
 | |
| 		return resp.Kvs[0].ModRevision
 | |
| 	}
 | |
| 	return 0
 | |
| }
 | |
| 
 | |
| func (s *stm) commit() *v3.TxnResponse {
 | |
| 	txnresp, err := s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...).Commit()
 | |
| 	if err != nil {
 | |
| 		panic(stmError{err})
 | |
| 	}
 | |
| 	if txnresp.Succeeded {
 | |
| 		return txnresp
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (s *stm) fetch(keys ...string) *v3.GetResponse {
 | |
| 	if len(keys) == 0 {
 | |
| 		return nil
 | |
| 	}
 | |
| 	ops := make([]v3.Op, len(keys))
 | |
| 	for i, key := range keys {
 | |
| 		if resp, ok := s.rset[key]; ok {
 | |
| 			return resp
 | |
| 		}
 | |
| 		ops[i] = v3.OpGet(key, s.getOpts...)
 | |
| 	}
 | |
| 	txnresp, err := s.client.Txn(s.ctx).Then(ops...).Commit()
 | |
| 	if err != nil {
 | |
| 		panic(stmError{err})
 | |
| 	}
 | |
| 	s.rset.add(keys, txnresp)
 | |
| 	return (*v3.GetResponse)(txnresp.Responses[0].GetResponseRange())
 | |
| }
 | |
| 
 | |
| func (s *stm) reset() {
 | |
| 	s.rset = make(map[string]*v3.GetResponse)
 | |
| 	s.wset = make(map[string]stmPut)
 | |
| }
 | |
| 
 | |
| type stmSerializable struct {
 | |
| 	stm
 | |
| 	prefetch map[string]*v3.GetResponse
 | |
| }
 | |
| 
 | |
| func (s *stmSerializable) Get(keys ...string) string {
 | |
| 	if len(keys) == 0 {
 | |
| 		return ""
 | |
| 	}
 | |
| 
 | |
| 	if wv := s.wset.get(keys...); wv != nil {
 | |
| 		return wv.val
 | |
| 	}
 | |
| 	firstRead := len(s.rset) == 0
 | |
| 	for _, key := range keys {
 | |
| 		if resp, ok := s.prefetch[key]; ok {
 | |
| 			delete(s.prefetch, key)
 | |
| 			s.rset[key] = resp
 | |
| 		}
 | |
| 	}
 | |
| 	resp := s.stm.fetch(keys...)
 | |
| 	if firstRead {
 | |
| 		// txn's base revision is defined by the first read
 | |
| 		s.getOpts = []v3.OpOption{
 | |
| 			v3.WithRev(resp.Header.Revision),
 | |
| 			v3.WithSerializable(),
 | |
| 		}
 | |
| 	}
 | |
| 	return respToValue(resp)
 | |
| }
 | |
| 
 | |
| func (s *stmSerializable) Rev(key string) int64 {
 | |
| 	s.Get(key)
 | |
| 	return s.stm.Rev(key)
 | |
| }
 | |
| 
 | |
| func (s *stmSerializable) gets() ([]string, []v3.Op) {
 | |
| 	keys := make([]string, 0, len(s.rset))
 | |
| 	ops := make([]v3.Op, 0, len(s.rset))
 | |
| 	for k := range s.rset {
 | |
| 		keys = append(keys, k)
 | |
| 		ops = append(ops, v3.OpGet(k))
 | |
| 	}
 | |
| 	return keys, ops
 | |
| }
 | |
| 
 | |
| func (s *stmSerializable) commit() *v3.TxnResponse {
 | |
| 	keys, getops := s.gets()
 | |
| 	txn := s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...)
 | |
| 	// use Else to prefetch keys in case of conflict to save a round trip
 | |
| 	txnresp, err := txn.Else(getops...).Commit()
 | |
| 	if err != nil {
 | |
| 		panic(stmError{err})
 | |
| 	}
 | |
| 	if txnresp.Succeeded {
 | |
| 		return txnresp
 | |
| 	}
 | |
| 	// load prefetch with Else data
 | |
| 	s.rset.add(keys, txnresp)
 | |
| 	s.prefetch = s.rset
 | |
| 	s.getOpts = nil
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func isKeyCurrent(k string, r *v3.GetResponse) v3.Cmp {
 | |
| 	if len(r.Kvs) != 0 {
 | |
| 		return v3.Compare(v3.ModRevision(k), "=", r.Kvs[0].ModRevision)
 | |
| 	}
 | |
| 	return v3.Compare(v3.ModRevision(k), "=", 0)
 | |
| }
 | |
| 
 | |
| func respToValue(resp *v3.GetResponse) string {
 | |
| 	if resp == nil || len(resp.Kvs) == 0 {
 | |
| 		return ""
 | |
| 	}
 | |
| 	return string(resp.Kvs[0].Value)
 | |
| }
 | |
| 
 | |
| // NewSTMRepeatable is deprecated.
 | |
| func NewSTMRepeatable(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
 | |
| 	return NewSTM(c, apply, WithAbortContext(ctx), WithIsolation(RepeatableReads))
 | |
| }
 | |
| 
 | |
| // NewSTMSerializable is deprecated.
 | |
| func NewSTMSerializable(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
 | |
| 	return NewSTM(c, apply, WithAbortContext(ctx), WithIsolation(Serializable))
 | |
| }
 | |
| 
 | |
| // NewSTMReadCommitted is deprecated.
 | |
| func NewSTMReadCommitted(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
 | |
| 	return NewSTM(c, apply, WithAbortContext(ctx), WithIsolation(ReadCommitted))
 | |
| }
 | 
