[NOD-721] Add defers (#638)

* [NOD-721] Defer unlocks

* [NOD-721] Add functions with locks to rpcmodel

* [NOD-721] Defer unlocks

* [NOD-721] Add filterDataWithLock function

* [NOD-721] Defer unlocks

* [NOD-721] Defer .Close()

* [NOD-721] Fix access to wsc.filterData without a lock

* [NOD-721] De-anonymize some anonymous functions

* [NOD-721] Remove redundant assignments

* [NOD-721] Remove redundant assignments

* [NOD-721] Remove redundant assignments

* [NOD-721] Get rid of submitOld, and break handleGetBlockTemplateLongPoll to smaller functions

* [NOD-721] Rename existsUnspentOutpoint->existsUnspentOutpointNoLock, existsUnspentOutpointWithLock->existsUnspentOutpoint

* [NOD-721] Rename filterDataWithLock->FilterData

* [NOD-721] Fixed comments
This commit is contained in:
Ori Newman 2020-02-24 09:19:44 +02:00 committed by GitHub
parent 98987f4a8f
commit de9aa39cc5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 338 additions and 308 deletions

View File

@ -43,8 +43,8 @@ func newBlockIndex(db database.DB, dagParams *dagconfig.Params) *blockIndex {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (bi *blockIndex) HaveBlock(hash *daghash.Hash) bool { func (bi *blockIndex) HaveBlock(hash *daghash.Hash) bool {
bi.RLock() bi.RLock()
defer bi.RUnlock()
_, hasBlock := bi.index[*hash] _, hasBlock := bi.index[*hash]
bi.RUnlock()
return hasBlock return hasBlock
} }
@ -54,8 +54,8 @@ func (bi *blockIndex) HaveBlock(hash *daghash.Hash) bool {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (bi *blockIndex) LookupNode(hash *daghash.Hash) *blockNode { func (bi *blockIndex) LookupNode(hash *daghash.Hash) *blockNode {
bi.RLock() bi.RLock()
defer bi.RUnlock()
node := bi.index[*hash] node := bi.index[*hash]
bi.RUnlock()
return node return node
} }
@ -65,9 +65,9 @@ func (bi *blockIndex) LookupNode(hash *daghash.Hash) *blockNode {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (bi *blockIndex) AddNode(node *blockNode) { func (bi *blockIndex) AddNode(node *blockNode) {
bi.Lock() bi.Lock()
defer bi.Unlock()
bi.addNode(node) bi.addNode(node)
bi.dirty[node] = struct{}{} bi.dirty[node] = struct{}{}
bi.Unlock()
} }
// addNode adds the provided node to the block index, but does not mark it as // addNode adds the provided node to the block index, but does not mark it as
@ -83,8 +83,8 @@ func (bi *blockIndex) addNode(node *blockNode) {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (bi *blockIndex) NodeStatus(node *blockNode) blockStatus { func (bi *blockIndex) NodeStatus(node *blockNode) blockStatus {
bi.RLock() bi.RLock()
defer bi.RUnlock()
status := node.status status := node.status
bi.RUnlock()
return status return status
} }
@ -95,9 +95,9 @@ func (bi *blockIndex) NodeStatus(node *blockNode) blockStatus {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (bi *blockIndex) SetStatusFlags(node *blockNode, flags blockStatus) { func (bi *blockIndex) SetStatusFlags(node *blockNode, flags blockStatus) {
bi.Lock() bi.Lock()
defer bi.Unlock()
node.status |= flags node.status |= flags
bi.dirty[node] = struct{}{} bi.dirty[node] = struct{}{}
bi.Unlock()
} }
// UnsetStatusFlags flips the provided status flags on the block node to off, // UnsetStatusFlags flips the provided status flags on the block node to off,
@ -106,9 +106,9 @@ func (bi *blockIndex) SetStatusFlags(node *blockNode, flags blockStatus) {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (bi *blockIndex) UnsetStatusFlags(node *blockNode, flags blockStatus) { func (bi *blockIndex) UnsetStatusFlags(node *blockNode, flags blockStatus) {
bi.Lock() bi.Lock()
defer bi.Unlock()
node.status &^= flags node.status &^= flags
bi.dirty[node] = struct{}{} bi.dirty[node] = struct{}{}
bi.Unlock()
} }
// flushToDB writes all dirty block nodes to the database. If all writes // flushToDB writes all dirty block nodes to the database. If all writes

View File

@ -194,8 +194,8 @@ func (dag *BlockDAG) IsKnownOrphan(hash *daghash.Hash) bool {
// Protect concurrent access. Using a read lock only so multiple // Protect concurrent access. Using a read lock only so multiple
// readers can query without blocking each other. // readers can query without blocking each other.
dag.orphanLock.RLock() dag.orphanLock.RLock()
defer dag.orphanLock.RUnlock()
_, exists := dag.orphans[*hash] _, exists := dag.orphans[*hash]
dag.orphanLock.RUnlock()
return exists return exists
} }

View File

@ -771,6 +771,7 @@ func (idx *AddrIndex) indexUnconfirmedAddresses(scriptPubKey []byte, tx *util.Tx
// Add a mapping from the address to the transaction. // Add a mapping from the address to the transaction.
idx.unconfirmedLock.Lock() idx.unconfirmedLock.Lock()
defer idx.unconfirmedLock.Unlock()
addrIndexEntry := idx.txnsByAddr[addrKey] addrIndexEntry := idx.txnsByAddr[addrKey]
if addrIndexEntry == nil { if addrIndexEntry == nil {
addrIndexEntry = make(map[daghash.TxID]*util.Tx) addrIndexEntry = make(map[daghash.TxID]*util.Tx)
@ -785,7 +786,6 @@ func (idx *AddrIndex) indexUnconfirmedAddresses(scriptPubKey []byte, tx *util.Tx
idx.addrsByTx[*tx.ID()] = addrsByTxEntry idx.addrsByTx[*tx.ID()] = addrsByTxEntry
} }
addrsByTxEntry[addrKey] = struct{}{} addrsByTxEntry[addrKey] = struct{}{}
idx.unconfirmedLock.Unlock()
} }
// AddUnconfirmedTx adds all addresses related to the transaction to the // AddUnconfirmedTx adds all addresses related to the transaction to the

View File

@ -57,8 +57,8 @@ type Notification struct {
// NotificationType for details on the types and contents of notifications. // NotificationType for details on the types and contents of notifications.
func (dag *BlockDAG) Subscribe(callback NotificationCallback) { func (dag *BlockDAG) Subscribe(callback NotificationCallback) {
dag.notificationsLock.Lock() dag.notificationsLock.Lock()
defer dag.notificationsLock.Unlock()
dag.notifications = append(dag.notifications, callback) dag.notifications = append(dag.notifications, callback)
dag.notificationsLock.Unlock()
} }
// sendNotification sends a notification with the passed type and data if the // sendNotification sends a notification with the passed type and data if the
@ -68,10 +68,10 @@ func (dag *BlockDAG) sendNotification(typ NotificationType, data interface{}) {
// Generate and send the notification. // Generate and send the notification.
n := Notification{Type: typ, Data: data} n := Notification{Type: typ, Data: data}
dag.notificationsLock.RLock() dag.notificationsLock.RLock()
defer dag.notificationsLock.RUnlock()
for _, callback := range dag.notifications { for _, callback := range dag.notifications {
callback(&n) callback(&n)
} }
dag.notificationsLock.RUnlock()
} }
// BlockAddedNotificationData defines data to be sent along with a BlockAdded // BlockAddedNotificationData defines data to be sent along with a BlockAdded

View File

@ -113,8 +113,8 @@ func (v *virtualBlock) updateSelectedParentSet(oldSelectedParent *blockNode) *ch
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (v *virtualBlock) SetTips(tips blockSet) { func (v *virtualBlock) SetTips(tips blockSet) {
v.mtx.Lock() v.mtx.Lock()
defer v.mtx.Unlock()
v.setTips(tips) v.setTips(tips)
v.mtx.Unlock()
} }
// addTip adds the given tip to the set of tips in the virtual block. // addTip adds the given tip to the set of tips in the virtual block.

View File

@ -95,11 +95,12 @@ func sendPostRequest(marshalledJSON []byte, cfg *ConfigFlags) ([]byte, error) {
} }
// Read the raw bytes and close the response. // Read the raw bytes and close the response.
respBytes, err := ioutil.ReadAll(httpResponse.Body) respBytes, err := func() ([]byte, error) {
httpResponse.Body.Close() defer httpResponse.Body.Close()
return ioutil.ReadAll(httpResponse.Body)
}()
if err != nil { if err != nil {
err = errors.Errorf("error reading json reply: %s", err) return nil, errors.Wrap(err, "error reading json reply")
return nil, err
} }
// Handle unsuccessful HTTP responses // Handle unsuccessful HTTP responses

View File

@ -89,8 +89,8 @@ type ConnReq struct {
// updateState updates the state of the connection request. // updateState updates the state of the connection request.
func (c *ConnReq) updateState(state ConnState) { func (c *ConnReq) updateState(state ConnState) {
c.stateMtx.Lock() c.stateMtx.Lock()
defer c.stateMtx.Unlock()
c.state = state c.state = state
c.stateMtx.Unlock()
} }
// ID returns a unique identifier for the connection request. // ID returns a unique identifier for the connection request.
@ -101,8 +101,8 @@ func (c *ConnReq) ID() uint64 {
// State is the connection state of the requested connection. // State is the connection state of the requested connection.
func (c *ConnReq) State() ConnState { func (c *ConnReq) State() ConnState {
c.stateMtx.RLock() c.stateMtx.RLock()
defer c.stateMtx.RUnlock()
state := c.state state := c.state
c.stateMtx.RUnlock()
return state return state
} }

View File

@ -68,9 +68,9 @@ type DynamicBanScore struct {
// String returns the ban score as a human-readable string. // String returns the ban score as a human-readable string.
func (s *DynamicBanScore) String() string { func (s *DynamicBanScore) String() string {
s.mtx.Lock() s.mtx.Lock()
defer s.mtx.Unlock()
r := fmt.Sprintf("persistent %d + transient %f at %d = %d as of now", r := fmt.Sprintf("persistent %d + transient %f at %d = %d as of now",
s.persistent, s.transient, s.lastUnix, s.Int()) s.persistent, s.transient, s.lastUnix, s.Int())
s.mtx.Unlock()
return r return r
} }
@ -80,8 +80,8 @@ func (s *DynamicBanScore) String() string {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (s *DynamicBanScore) Int() uint32 { func (s *DynamicBanScore) Int() uint32 {
s.mtx.Lock() s.mtx.Lock()
defer s.mtx.Unlock()
r := s.int(time.Now()) r := s.int(time.Now())
s.mtx.Unlock()
return r return r
} }
@ -91,8 +91,8 @@ func (s *DynamicBanScore) Int() uint32 {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (s *DynamicBanScore) Increase(persistent, transient uint32) uint32 { func (s *DynamicBanScore) Increase(persistent, transient uint32) uint32 {
s.mtx.Lock() s.mtx.Lock()
defer s.mtx.Unlock()
r := s.increase(persistent, transient, time.Now()) r := s.increase(persistent, transient, time.Now())
s.mtx.Unlock()
return r return r
} }
@ -101,10 +101,10 @@ func (s *DynamicBanScore) Increase(persistent, transient uint32) uint32 {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (s *DynamicBanScore) Reset() { func (s *DynamicBanScore) Reset() {
s.mtx.Lock() s.mtx.Lock()
defer s.mtx.Unlock()
s.persistent = 0 s.persistent = 0
s.transient = 0 s.transient = 0
s.lastUnix = 0 s.lastUnix = 0
s.mtx.Unlock()
} }
// int returns the ban score, the sum of the persistent and decaying scores at a // int returns the ban score, the sum of the persistent and decaying scores at a

View File

@ -335,13 +335,13 @@ func (b *Backend) printf(lvl Level, tag string, format string, args ...interface
func (b *Backend) write(lvl Level, bytesToWrite []byte) { func (b *Backend) write(lvl Level, bytesToWrite []byte) {
b.mu.Lock() b.mu.Lock()
defer b.mu.Unlock()
os.Stdout.Write(bytesToWrite) os.Stdout.Write(bytesToWrite)
for _, r := range b.rotators { for _, r := range b.rotators {
if lvl >= r.logLevel { if lvl >= r.logLevel {
r.Write(bytesToWrite) r.Write(bytesToWrite)
} }
} }
b.mu.Unlock()
} }
// Close finalizes all log rotators for this backend // Close finalizes all log rotators for this backend

View File

@ -375,8 +375,8 @@ func (mp *TxPool) isTransactionInPool(hash *daghash.TxID) bool {
func (mp *TxPool) IsTransactionInPool(hash *daghash.TxID) bool { func (mp *TxPool) IsTransactionInPool(hash *daghash.TxID) bool {
// Protect concurrent access. // Protect concurrent access.
mp.mtx.RLock() mp.mtx.RLock()
defer mp.mtx.RUnlock()
inPool := mp.isTransactionInPool(hash) inPool := mp.isTransactionInPool(hash)
mp.mtx.RUnlock()
return inPool return inPool
} }
@ -423,8 +423,8 @@ func (mp *TxPool) isOrphanInPool(hash *daghash.TxID) bool {
func (mp *TxPool) IsOrphanInPool(hash *daghash.TxID) bool { func (mp *TxPool) IsOrphanInPool(hash *daghash.TxID) bool {
// Protect concurrent access. // Protect concurrent access.
mp.mtx.RLock() mp.mtx.RLock()
defer mp.mtx.RUnlock()
inPool := mp.isOrphanInPool(hash) inPool := mp.isOrphanInPool(hash)
mp.mtx.RUnlock()
return inPool return inPool
} }
@ -444,8 +444,8 @@ func (mp *TxPool) haveTransaction(hash *daghash.TxID) bool {
func (mp *TxPool) HaveTransaction(hash *daghash.TxID) bool { func (mp *TxPool) HaveTransaction(hash *daghash.TxID) bool {
// Protect concurrent access. // Protect concurrent access.
mp.mtx.RLock() mp.mtx.RLock()
defer mp.mtx.RUnlock()
haveTx := mp.haveTransaction(hash) haveTx := mp.haveTransaction(hash)
mp.mtx.RUnlock()
return haveTx return haveTx
} }
@ -745,8 +745,8 @@ func (mp *TxPool) checkPoolDoubleSpend(tx *util.Tx) error {
// be returned, if not nil will be returned. // be returned, if not nil will be returned.
func (mp *TxPool) CheckSpend(op wire.Outpoint) *util.Tx { func (mp *TxPool) CheckSpend(op wire.Outpoint) *util.Tx {
mp.mtx.RLock() mp.mtx.RLock()
defer mp.mtx.RUnlock()
txR := mp.outpoints[op] txR := mp.outpoints[op]
mp.mtx.RUnlock()
return txR return txR
} }
@ -1207,8 +1207,8 @@ func (mp *TxPool) ProcessTransaction(tx *util.Tx, allowOrphan bool, tag Tag) ([]
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (mp *TxPool) Count() int { func (mp *TxPool) Count() int {
mp.mtx.RLock() mp.mtx.RLock()
defer mp.mtx.RUnlock()
count := len(mp.pool) count := len(mp.pool)
mp.mtx.RUnlock()
return count return count
} }
@ -1229,6 +1229,7 @@ func (mp *TxPool) DepCount() int {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (mp *TxPool) TxIDs() []*daghash.TxID { func (mp *TxPool) TxIDs() []*daghash.TxID {
mp.mtx.RLock() mp.mtx.RLock()
defer mp.mtx.RUnlock()
ids := make([]*daghash.TxID, len(mp.pool)) ids := make([]*daghash.TxID, len(mp.pool))
i := 0 i := 0
for txID := range mp.pool { for txID := range mp.pool {
@ -1236,7 +1237,6 @@ func (mp *TxPool) TxIDs() []*daghash.TxID {
ids[i] = &idCopy ids[i] = &idCopy
i++ i++
} }
mp.mtx.RUnlock()
return ids return ids
} }
@ -1247,13 +1247,13 @@ func (mp *TxPool) TxIDs() []*daghash.TxID {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (mp *TxPool) TxDescs() []*TxDesc { func (mp *TxPool) TxDescs() []*TxDesc {
mp.mtx.RLock() mp.mtx.RLock()
defer mp.mtx.RUnlock()
descs := make([]*TxDesc, len(mp.pool)) descs := make([]*TxDesc, len(mp.pool))
i := 0 i := 0
for _, desc := range mp.pool { for _, desc := range mp.pool {
descs[i] = desc descs[i] = desc
i++ i++
} }
mp.mtx.RUnlock()
return descs return descs
} }
@ -1265,13 +1265,13 @@ func (mp *TxPool) TxDescs() []*TxDesc {
// concurrent access as required by the interface contract. // concurrent access as required by the interface contract.
func (mp *TxPool) MiningDescs() []*mining.TxDesc { func (mp *TxPool) MiningDescs() []*mining.TxDesc {
mp.mtx.RLock() mp.mtx.RLock()
defer mp.mtx.RUnlock()
descs := make([]*mining.TxDesc, len(mp.pool)) descs := make([]*mining.TxDesc, len(mp.pool))
i := 0 i := 0
for _, desc := range mp.pool { for _, desc := range mp.pool {
descs[i] = &desc.TxDesc descs[i] = &desc.TxDesc
i++ i++
} }
mp.mtx.RUnlock()
return descs return descs
} }

View File

@ -41,24 +41,23 @@ type fakeDAG struct {
// instance. // instance.
func (s *fakeDAG) BlueScore() uint64 { func (s *fakeDAG) BlueScore() uint64 {
s.RLock() s.RLock()
blueScore := s.currentBlueScore defer s.RUnlock()
s.RUnlock() return s.currentBlueScore
return blueScore
} }
// SetBlueScore sets the current blueScore associated with the fake DAG instance. // SetBlueScore sets the current blueScore associated with the fake DAG instance.
func (s *fakeDAG) SetBlueScore(blueScore uint64) { func (s *fakeDAG) SetBlueScore(blueScore uint64) {
s.Lock() s.Lock()
defer s.Unlock()
s.currentBlueScore = blueScore s.currentBlueScore = blueScore
s.Unlock()
} }
// MedianTimePast returns the current median time past associated with the fake // MedianTimePast returns the current median time past associated with the fake
// DAG instance. // DAG instance.
func (s *fakeDAG) MedianTimePast() time.Time { func (s *fakeDAG) MedianTimePast() time.Time {
s.RLock() s.RLock()
defer s.RUnlock()
mtp := s.medianTimePast mtp := s.medianTimePast
s.RUnlock()
return mtp return mtp
} }
@ -66,8 +65,8 @@ func (s *fakeDAG) MedianTimePast() time.Time {
// DAG instance. // DAG instance.
func (s *fakeDAG) SetMedianTimePast(mtp time.Time) { func (s *fakeDAG) SetMedianTimePast(mtp time.Time) {
s.Lock() s.Lock()
defer s.Unlock()
s.medianTimePast = mtp s.medianTimePast = mtp
s.Unlock()
} }
func calcSequenceLock(tx *util.Tx, func calcSequenceLock(tx *util.Tx,
@ -1319,9 +1318,11 @@ func TestOrphanChainRemoval(t *testing.T) {
// Remove the first orphan that starts the orphan chain without the // Remove the first orphan that starts the orphan chain without the
// remove redeemer flag set and ensure that only the first orphan was // remove redeemer flag set and ensure that only the first orphan was
// removed. // removed.
harness.txPool.mtx.Lock() func() {
harness.txPool.removeOrphan(chainedTxns[1], false) harness.txPool.mtx.Lock()
harness.txPool.mtx.Unlock() defer harness.txPool.mtx.Unlock()
harness.txPool.removeOrphan(chainedTxns[1], false)
}()
testPoolMembership(tc, chainedTxns[1], false, false, false) testPoolMembership(tc, chainedTxns[1], false, false, false)
for _, tx := range chainedTxns[2 : maxOrphans+1] { for _, tx := range chainedTxns[2 : maxOrphans+1] {
testPoolMembership(tc, tx, true, false, false) testPoolMembership(tc, tx, true, false, false)
@ -1329,9 +1330,11 @@ func TestOrphanChainRemoval(t *testing.T) {
// Remove the first remaining orphan that starts the orphan chain with // Remove the first remaining orphan that starts the orphan chain with
// the remove redeemer flag set and ensure they are all removed. // the remove redeemer flag set and ensure they are all removed.
harness.txPool.mtx.Lock() func() {
harness.txPool.removeOrphan(chainedTxns[2], true) harness.txPool.mtx.Lock()
harness.txPool.mtx.Unlock() defer harness.txPool.mtx.Unlock()
harness.txPool.removeOrphan(chainedTxns[2], true)
}()
for _, tx := range chainedTxns[2 : maxOrphans+1] { for _, tx := range chainedTxns[2 : maxOrphans+1] {
testPoolMembership(tc, tx, false, false, false) testPoolMembership(tc, tx, false, false, false)
} }

View File

@ -50,8 +50,8 @@ func (m *mruInventoryMap) String() string {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (m *mruInventoryMap) Exists(iv *wire.InvVect) bool { func (m *mruInventoryMap) Exists(iv *wire.InvVect) bool {
m.invMtx.Lock() m.invMtx.Lock()
defer m.invMtx.Unlock()
_, exists := m.invMap[*iv] _, exists := m.invMap[*iv]
m.invMtx.Unlock()
return exists return exists
} }
@ -106,11 +106,11 @@ func (m *mruInventoryMap) Add(iv *wire.InvVect) {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (m *mruInventoryMap) Delete(iv *wire.InvVect) { func (m *mruInventoryMap) Delete(iv *wire.InvVect) {
m.invMtx.Lock() m.invMtx.Lock()
defer m.invMtx.Unlock()
if node, exists := m.invMap[*iv]; exists { if node, exists := m.invMap[*iv]; exists {
m.invList.Remove(node) m.invList.Remove(node)
delete(m.invMap, *iv) delete(m.invMap, *iv)
} }
m.invMtx.Unlock()
} }
// newMruInventoryMap returns a new inventory map that is limited to the number // newMruInventoryMap returns a new inventory map that is limited to the number

View File

@ -48,8 +48,8 @@ func (m *mruNonceMap) String() string {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (m *mruNonceMap) Exists(nonce uint64) bool { func (m *mruNonceMap) Exists(nonce uint64) bool {
m.mtx.Lock() m.mtx.Lock()
defer m.mtx.Unlock()
_, exists := m.nonceMap[nonce] _, exists := m.nonceMap[nonce]
m.mtx.Unlock()
return exists return exists
} }
@ -104,11 +104,11 @@ func (m *mruNonceMap) Add(nonce uint64) {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (m *mruNonceMap) Delete(nonce uint64) { func (m *mruNonceMap) Delete(nonce uint64) {
m.mtx.Lock() m.mtx.Lock()
defer m.mtx.Unlock()
if node, exists := m.nonceMap[nonce]; exists { if node, exists := m.nonceMap[nonce]; exists {
m.nonceList.Remove(node) m.nonceList.Remove(node)
delete(m.nonceMap, nonce) delete(m.nonceMap, nonce)
} }
m.mtx.Unlock()
} }
// newMruNonceMap returns a new nonce map that is limited to the number of // newMruNonceMap returns a new nonce map that is limited to the number of

View File

@ -456,14 +456,16 @@ func (p *Peer) AddKnownInventory(invVect *wire.InvVect) {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (p *Peer) StatsSnapshot() *StatsSnap { func (p *Peer) StatsSnapshot() *StatsSnap {
p.statsMtx.RLock() p.statsMtx.RLock()
defer p.statsMtx.RUnlock()
p.flagsMtx.Lock() p.flagsMtx.Lock()
defer p.flagsMtx.Unlock()
id := p.id id := p.id
addr := p.addr addr := p.addr
userAgent := p.userAgent userAgent := p.userAgent
services := p.services services := p.services
protocolVersion := p.advertisedProtoVer protocolVersion := p.advertisedProtoVer
p.flagsMtx.Unlock()
// Get a copy of all relevant flags and stats. // Get a copy of all relevant flags and stats.
statsSnap := &StatsSnap{ statsSnap := &StatsSnap{
@ -485,7 +487,6 @@ func (p *Peer) StatsSnapshot() *StatsSnap {
LastPingTime: p.lastPingTime, LastPingTime: p.lastPingTime,
} }
p.statsMtx.RUnlock()
return statsSnap return statsSnap
} }
@ -494,10 +495,8 @@ func (p *Peer) StatsSnapshot() *StatsSnap {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (p *Peer) ID() int32 { func (p *Peer) ID() int32 {
p.flagsMtx.Lock() p.flagsMtx.Lock()
id := p.id defer p.flagsMtx.Unlock()
p.flagsMtx.Unlock() return p.id
return id
} }
// NA returns the peer network address. // NA returns the peer network address.
@ -505,10 +504,8 @@ func (p *Peer) ID() int32 {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (p *Peer) NA() *wire.NetAddress { func (p *Peer) NA() *wire.NetAddress {
p.flagsMtx.Lock() p.flagsMtx.Lock()
na := p.na defer p.flagsMtx.Unlock()
p.flagsMtx.Unlock() return p.na
return na
} }
// Addr returns the peer address. // Addr returns the peer address.
@ -532,10 +529,8 @@ func (p *Peer) Inbound() bool {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (p *Peer) Services() wire.ServiceFlag { func (p *Peer) Services() wire.ServiceFlag {
p.flagsMtx.Lock() p.flagsMtx.Lock()
services := p.services defer p.flagsMtx.Unlock()
p.flagsMtx.Unlock() return p.services
return services
} }
// UserAgent returns the user agent of the remote peer. // UserAgent returns the user agent of the remote peer.
@ -543,19 +538,15 @@ func (p *Peer) Services() wire.ServiceFlag {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (p *Peer) UserAgent() string { func (p *Peer) UserAgent() string {
p.flagsMtx.Lock() p.flagsMtx.Lock()
userAgent := p.userAgent defer p.flagsMtx.Unlock()
p.flagsMtx.Unlock() return p.userAgent
return userAgent
} }
// SubnetworkID returns peer subnetwork ID // SubnetworkID returns peer subnetwork ID
func (p *Peer) SubnetworkID() *subnetworkid.SubnetworkID { func (p *Peer) SubnetworkID() *subnetworkid.SubnetworkID {
p.flagsMtx.Lock() p.flagsMtx.Lock()
subnetworkID := p.cfg.SubnetworkID defer p.flagsMtx.Unlock()
p.flagsMtx.Unlock() return p.cfg.SubnetworkID
return subnetworkID
} }
// LastPingNonce returns the last ping nonce of the remote peer. // LastPingNonce returns the last ping nonce of the remote peer.
@ -563,10 +554,8 @@ func (p *Peer) SubnetworkID() *subnetworkid.SubnetworkID {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (p *Peer) LastPingNonce() uint64 { func (p *Peer) LastPingNonce() uint64 {
p.statsMtx.RLock() p.statsMtx.RLock()
lastPingNonce := p.lastPingNonce defer p.statsMtx.RUnlock()
p.statsMtx.RUnlock() return p.lastPingNonce
return lastPingNonce
} }
// LastPingTime returns the last ping time of the remote peer. // LastPingTime returns the last ping time of the remote peer.
@ -574,10 +563,8 @@ func (p *Peer) LastPingNonce() uint64 {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (p *Peer) LastPingTime() time.Time { func (p *Peer) LastPingTime() time.Time {
p.statsMtx.RLock() p.statsMtx.RLock()
lastPingTime := p.lastPingTime defer p.statsMtx.RUnlock()
p.statsMtx.RUnlock() return p.lastPingTime
return lastPingTime
} }
// LastPingMicros returns the last ping micros of the remote peer. // LastPingMicros returns the last ping micros of the remote peer.
@ -585,10 +572,8 @@ func (p *Peer) LastPingTime() time.Time {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (p *Peer) LastPingMicros() int64 { func (p *Peer) LastPingMicros() int64 {
p.statsMtx.RLock() p.statsMtx.RLock()
lastPingMicros := p.lastPingMicros defer p.statsMtx.RUnlock()
p.statsMtx.RUnlock() return p.lastPingMicros
return lastPingMicros
} }
// VersionKnown returns the whether or not the version of a peer is known // VersionKnown returns the whether or not the version of a peer is known
@ -597,10 +582,8 @@ func (p *Peer) LastPingMicros() int64 {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (p *Peer) VersionKnown() bool { func (p *Peer) VersionKnown() bool {
p.flagsMtx.Lock() p.flagsMtx.Lock()
versionKnown := p.versionKnown defer p.flagsMtx.Unlock()
p.flagsMtx.Unlock() return p.versionKnown
return versionKnown
} }
// VerAckReceived returns whether or not a verack message was received by the // VerAckReceived returns whether or not a verack message was received by the
@ -609,10 +592,8 @@ func (p *Peer) VersionKnown() bool {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (p *Peer) VerAckReceived() bool { func (p *Peer) VerAckReceived() bool {
p.flagsMtx.Lock() p.flagsMtx.Lock()
verAckReceived := p.verAckReceived defer p.flagsMtx.Unlock()
p.flagsMtx.Unlock() return p.verAckReceived
return verAckReceived
} }
// ProtocolVersion returns the negotiated peer protocol version. // ProtocolVersion returns the negotiated peer protocol version.
@ -620,10 +601,8 @@ func (p *Peer) VerAckReceived() bool {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (p *Peer) ProtocolVersion() uint32 { func (p *Peer) ProtocolVersion() uint32 {
p.flagsMtx.Lock() p.flagsMtx.Lock()
protocolVersion := p.protocolVersion defer p.flagsMtx.Unlock()
p.flagsMtx.Unlock() return p.protocolVersion
return protocolVersion
} }
// SelectedTipHash returns the selected tip of the peer. // SelectedTipHash returns the selected tip of the peer.
@ -631,14 +610,14 @@ func (p *Peer) ProtocolVersion() uint32 {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (p *Peer) SelectedTipHash() *daghash.Hash { func (p *Peer) SelectedTipHash() *daghash.Hash {
p.statsMtx.RLock() p.statsMtx.RLock()
selectedTipHash := p.selectedTipHash defer p.statsMtx.RUnlock()
p.statsMtx.RUnlock() return p.selectedTipHash
return selectedTipHash
} }
// SetSelectedTipHash sets the selected tip of the peer. // SetSelectedTipHash sets the selected tip of the peer.
func (p *Peer) SetSelectedTipHash(selectedTipHash *daghash.Hash) { func (p *Peer) SetSelectedTipHash(selectedTipHash *daghash.Hash) {
p.statsMtx.Lock()
defer p.statsMtx.Unlock()
p.selectedTipHash = selectedTipHash p.selectedTipHash = selectedTipHash
} }
@ -683,10 +662,8 @@ func (p *Peer) BytesReceived() uint64 {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (p *Peer) TimeConnected() time.Time { func (p *Peer) TimeConnected() time.Time {
p.statsMtx.RLock() p.statsMtx.RLock()
timeConnected := p.timeConnected defer p.statsMtx.RUnlock()
p.statsMtx.RUnlock() return p.timeConnected
return timeConnected
} }
// TimeOffset returns the number of seconds the local time was offset from the // TimeOffset returns the number of seconds the local time was offset from the
@ -696,10 +673,8 @@ func (p *Peer) TimeConnected() time.Time {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (p *Peer) TimeOffset() int64 { func (p *Peer) TimeOffset() int64 {
p.statsMtx.RLock() p.statsMtx.RLock()
timeOffset := p.timeOffset defer p.statsMtx.RUnlock()
p.statsMtx.RUnlock() return p.timeOffset
return timeOffset
} }
// localVersionMsg creates a version message that can be used to send to the // localVersionMsg creates a version message that can be used to send to the
@ -804,19 +779,21 @@ func (p *Peer) PushGetBlockLocatorMsg(highHash, lowHash *daghash.Hash) {
p.QueueMessage(msg, nil) p.QueueMessage(msg, nil)
} }
func (p *Peer) isDuplicateGetBlockInvsMsg(lowHash, highHash *daghash.Hash) bool {
p.prevGetBlockInvsMtx.Lock()
defer p.prevGetBlockInvsMtx.Unlock()
return p.prevGetBlockInvsHigh != nil && p.prevGetBlockInvsLow != nil &&
lowHash != nil && highHash.IsEqual(p.prevGetBlockInvsHigh) &&
lowHash.IsEqual(p.prevGetBlockInvsLow)
}
// PushGetBlockInvsMsg sends a getblockinvs message for the provided block locator // PushGetBlockInvsMsg sends a getblockinvs message for the provided block locator
// and high hash. It will ignore back-to-back duplicate requests. // and high hash. It will ignore back-to-back duplicate requests.
// //
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (p *Peer) PushGetBlockInvsMsg(lowHash, highHash *daghash.Hash) error { func (p *Peer) PushGetBlockInvsMsg(lowHash, highHash *daghash.Hash) error {
// Filter duplicate getblockinvs requests. // Filter duplicate getblockinvs requests.
p.prevGetBlockInvsMtx.Lock() if p.isDuplicateGetBlockInvsMsg(lowHash, highHash) {
isDuplicate := p.prevGetBlockInvsHigh != nil && p.prevGetBlockInvsLow != nil &&
lowHash != nil && highHash.IsEqual(p.prevGetBlockInvsHigh) &&
lowHash.IsEqual(p.prevGetBlockInvsLow)
p.prevGetBlockInvsMtx.Unlock()
if isDuplicate {
log.Tracef("Filtering duplicate [getblockinvs] with low "+ log.Tracef("Filtering duplicate [getblockinvs] with low "+
"hash %s, high hash %s", lowHash, highHash) "hash %s, high hash %s", lowHash, highHash)
return nil return nil
@ -829,9 +806,9 @@ func (p *Peer) PushGetBlockInvsMsg(lowHash, highHash *daghash.Hash) error {
// Update the previous getblockinvs request information for filtering // Update the previous getblockinvs request information for filtering
// duplicates. // duplicates.
p.prevGetBlockInvsMtx.Lock() p.prevGetBlockInvsMtx.Lock()
defer p.prevGetBlockInvsMtx.Unlock()
p.prevGetBlockInvsLow = lowHash p.prevGetBlockInvsLow = lowHash
p.prevGetBlockInvsHigh = highHash p.prevGetBlockInvsHigh = highHash
p.prevGetBlockInvsMtx.Unlock()
return nil return nil
} }
@ -913,15 +890,26 @@ func (p *Peer) handleRemoteVersionMsg(msg *wire.MsgVersion) error {
return errors.New("incompatible subnetworks") return errors.New("incompatible subnetworks")
} }
// Updating a bunch of stats including block based stats, and the p.updateStatsFromVersionMsg(msg)
// peer's time offset. p.updateFlagsFromVersionMsg(msg)
return nil
}
// updateStatsFromVersionMsg updates a bunch of stats including block based stats, and the
// peer's time offset.
func (p *Peer) updateStatsFromVersionMsg(msg *wire.MsgVersion) {
p.statsMtx.Lock() p.statsMtx.Lock()
defer p.statsMtx.Unlock()
p.selectedTipHash = msg.SelectedTipHash p.selectedTipHash = msg.SelectedTipHash
p.timeOffset = msg.Timestamp.Unix() - time.Now().Unix() p.timeOffset = msg.Timestamp.Unix() - time.Now().Unix()
p.statsMtx.Unlock() }
func (p *Peer) updateFlagsFromVersionMsg(msg *wire.MsgVersion) {
// Negotiate the protocol version. // Negotiate the protocol version.
p.flagsMtx.Lock() p.flagsMtx.Lock()
defer p.flagsMtx.Unlock()
p.advertisedProtoVer = uint32(msg.ProtocolVersion) p.advertisedProtoVer = uint32(msg.ProtocolVersion)
p.protocolVersion = minUint32(p.protocolVersion, p.advertisedProtoVer) p.protocolVersion = minUint32(p.protocolVersion, p.advertisedProtoVer)
p.versionKnown = true p.versionKnown = true
@ -937,10 +925,6 @@ func (p *Peer) handleRemoteVersionMsg(msg *wire.MsgVersion) error {
// Set the remote peer's user agent. // Set the remote peer's user agent.
p.userAgent = msg.UserAgent p.userAgent = msg.UserAgent
p.flagsMtx.Unlock()
return nil
} }
// handlePingMsg is invoked when a peer receives a ping kaspa message. For // handlePingMsg is invoked when a peer receives a ping kaspa message. For
@ -963,12 +947,12 @@ func (p *Peer) handlePongMsg(msg *wire.MsgPong) {
// without large usage of the ping rpc call since we ping infrequently // without large usage of the ping rpc call since we ping infrequently
// enough that if they overlap we would have timed out the peer. // enough that if they overlap we would have timed out the peer.
p.statsMtx.Lock() p.statsMtx.Lock()
defer p.statsMtx.Unlock()
if p.lastPingNonce != 0 && msg.Nonce == p.lastPingNonce { if p.lastPingNonce != 0 && msg.Nonce == p.lastPingNonce {
p.lastPingMicros = time.Since(p.lastPingTime).Nanoseconds() p.lastPingMicros = time.Since(p.lastPingTime).Nanoseconds()
p.lastPingMicros /= 1000 // convert to usec. p.lastPingMicros /= 1000 // convert to usec.
p.lastPingNonce = 0 p.lastPingNonce = 0
} }
p.statsMtx.Unlock()
} }
// readMessage reads the next kaspa message from the peer with logging. // readMessage reads the next kaspa message from the peer with logging.
@ -1346,9 +1330,7 @@ out:
"disconnecting", p) "disconnecting", p)
break out break out
} }
p.flagsMtx.Lock() p.markVerAckReceived()
p.verAckReceived = true
p.flagsMtx.Unlock()
if p.cfg.Listeners.OnVerAck != nil { if p.cfg.Listeners.OnVerAck != nil {
p.cfg.Listeners.OnVerAck(p, msg) p.cfg.Listeners.OnVerAck(p, msg)
} }
@ -1475,6 +1457,12 @@ out:
log.Tracef("Peer input handler done for %s", p) log.Tracef("Peer input handler done for %s", p)
} }
func (p *Peer) markVerAckReceived() {
p.flagsMtx.Lock()
defer p.flagsMtx.Unlock()
p.verAckReceived = true
}
// queueHandler handles the queuing of outgoing data for the peer. This runs as // queueHandler handles the queuing of outgoing data for the peer. This runs as
// a muxer for various sources of input so we can ensure that server and peer // a muxer for various sources of input so we can ensure that server and peer
// handlers will not block on us sending a message. That data is then passed on // handlers will not block on us sending a message. That data is then passed on
@ -1622,10 +1610,12 @@ out:
case msg := <-p.sendQueue: case msg := <-p.sendQueue:
switch m := msg.msg.(type) { switch m := msg.msg.(type) {
case *wire.MsgPing: case *wire.MsgPing:
p.statsMtx.Lock() func() {
p.lastPingNonce = m.Nonce p.statsMtx.Lock()
p.lastPingTime = time.Now() defer p.statsMtx.Unlock()
p.statsMtx.Unlock() p.lastPingNonce = m.Nonce
p.lastPingTime = time.Now()
}()
} }
p.stallControl <- stallControlMsg{sccSendMessage, msg.msg} p.stallControl <- stallControlMsg{sccSendMessage, msg.msg}

View File

@ -442,9 +442,8 @@ out:
// is being reassigned during a reconnect. // is being reassigned during a reconnect.
func (c *Client) disconnectChan() <-chan struct{} { func (c *Client) disconnectChan() <-chan struct{} {
c.mtx.Lock() c.mtx.Lock()
ch := c.disconnect defer c.mtx.Unlock()
c.mtx.Unlock() return c.disconnect
return ch
} }
// wsOutHandler handles all outgoing messages for the websocket connection. It // wsOutHandler handles all outgoing messages for the websocket connection. It
@ -511,9 +510,11 @@ func (c *Client) reregisterNtfns() error {
// the notification state (while not under the lock of course) which // the notification state (while not under the lock of course) which
// also register it with the remote RPC server, so this prevents double // also register it with the remote RPC server, so this prevents double
// registrations. // registrations.
c.ntfnStateLock.Lock() stateCopy := func() *notificationState {
stateCopy := c.ntfnState.Copy() c.ntfnStateLock.Lock()
c.ntfnStateLock.Unlock() defer c.ntfnStateLock.Unlock()
return c.ntfnState.Copy()
}()
// Reregister notifyblocks if needed. // Reregister notifyblocks if needed.
if stateCopy.notifyBlocks { if stateCopy.notifyBlocks {
@ -550,23 +551,9 @@ var ignoreResends = map[string]struct{}{
"rescan": {}, "rescan": {},
} }
// resendRequests resends any requests that had not completed when the client func (c *Client) collectResendRequests() []*jsonRequest {
// disconnected. It is intended to be called once the client has reconnected as
// a separate goroutine.
func (c *Client) resendRequests() {
// Set the notification state back up. If anything goes wrong,
// disconnect the client.
if err := c.reregisterNtfns(); err != nil {
log.Warnf("Unable to re-establish notification state: %s", err)
c.Disconnect()
return
}
// Since it's possible to block on send and more requests might be
// added by the caller while resending, make a copy of all of the
// requests that need to be resent now and work from the copy. This
// also allows the lock to be released quickly.
c.requestLock.Lock() c.requestLock.Lock()
defer c.requestLock.Unlock()
resendReqs := make([]*jsonRequest, 0, c.requestList.Len()) resendReqs := make([]*jsonRequest, 0, c.requestList.Len())
var nextElem *list.Element var nextElem *list.Element
for e := c.requestList.Front(); e != nil; e = nextElem { for e := c.requestList.Front(); e != nil; e = nextElem {
@ -583,7 +570,26 @@ func (c *Client) resendRequests() {
resendReqs = append(resendReqs, jReq) resendReqs = append(resendReqs, jReq)
} }
} }
c.requestLock.Unlock() return resendReqs
}
// resendRequests resends any requests that had not completed when the client
// disconnected. It is intended to be called once the client has reconnected as
// a separate goroutine.
func (c *Client) resendRequests() {
// Set the notification state back up. If anything goes wrong,
// disconnect the client.
if err := c.reregisterNtfns(); err != nil {
log.Warnf("Unable to re-establish notification state: %s", err)
c.Disconnect()
return
}
// Since it's possible to block on send and more requests might be
// added by the caller while resending, make a copy of all of the
// requests that need to be resent now and work from the copy. This
// also allows the lock to be released quickly.
resendReqs := c.collectResendRequests()
for _, jReq := range resendReqs { for _, jReq := range resendReqs {
// Stop resending commands if the client disconnected again // Stop resending commands if the client disconnected again
@ -654,10 +660,12 @@ out:
c.wsConn = wsConn c.wsConn = wsConn
c.retryCount = 0 c.retryCount = 0
c.mtx.Lock() func() {
c.disconnect = make(chan struct{}) c.mtx.Lock()
c.disconnected = false defer c.mtx.Unlock()
c.mtx.Unlock() c.disconnect = make(chan struct{})
c.disconnected = false
}()
// Start processing input and output for the // Start processing input and output for the
// new connection. // new connection.
@ -689,11 +697,14 @@ func (c *Client) handleSendPostMessage(details *sendPostDetails) {
} }
// Read the raw bytes and close the response. // Read the raw bytes and close the response.
respBytes, err := ioutil.ReadAll(httpResponse.Body) respBytes, err := func() ([]byte, error) {
httpResponse.Body.Close() defer httpResponse.Body.Close()
return ioutil.ReadAll(httpResponse.Body)
}()
if err != nil { if err != nil {
err = errors.Errorf("error reading json reply: %s", err) jReq.responseChan <- &response{
jReq.responseChan <- &response{err: err} err: errors.Wrap(err, "error reading json reply"),
}
return return
} }

View File

@ -10,15 +10,38 @@ import (
"strings" "strings"
) )
func concreteTypeToMethodWithRLock(rt reflect.Type) (string, bool) {
registerLock.RLock()
defer registerLock.RUnlock()
method, ok := concreteTypeToMethod[rt]
return method, ok
}
func methodToInfoWithRLock(method string) (methodInfo, bool) {
registerLock.RLock()
defer registerLock.RUnlock()
info, ok := methodToInfo[method]
return info, ok
}
func methodConcreteTypeAndInfoWithRLock(method string) (reflect.Type, methodInfo, bool) {
registerLock.RLock()
defer registerLock.RUnlock()
rtp, ok := methodToConcreteType[method]
if !ok {
return nil, methodInfo{}, false
}
info := methodToInfo[method]
return rtp, info, ok
}
// CommandMethod returns the method for the passed command. The provided command // CommandMethod returns the method for the passed command. The provided command
// type must be a registered type. All commands provided by this package are // type must be a registered type. All commands provided by this package are
// registered by default. // registered by default.
func CommandMethod(cmd interface{}) (string, error) { func CommandMethod(cmd interface{}) (string, error) {
// Look up the cmd type and error out if not registered. // Look up the cmd type and error out if not registered.
rt := reflect.TypeOf(cmd) rt := reflect.TypeOf(cmd)
registerLock.RLock() method, ok := concreteTypeToMethodWithRLock(rt)
method, ok := concreteTypeToMethod[rt]
registerLock.RUnlock()
if !ok { if !ok {
str := fmt.Sprintf("%q is not registered", method) str := fmt.Sprintf("%q is not registered", method)
return "", makeError(ErrUnregisteredMethod, str) return "", makeError(ErrUnregisteredMethod, str)
@ -33,9 +56,7 @@ func CommandMethod(cmd interface{}) (string, error) {
func MethodUsageFlags(method string) (UsageFlag, error) { func MethodUsageFlags(method string) (UsageFlag, error) {
// Look up details about the provided method and error out if not // Look up details about the provided method and error out if not
// registered. // registered.
registerLock.RLock() info, ok := methodToInfoWithRLock(method)
info, ok := methodToInfo[method]
registerLock.RUnlock()
if !ok { if !ok {
str := fmt.Sprintf("%q is not registered", method) str := fmt.Sprintf("%q is not registered", method)
return 0, makeError(ErrUnregisteredMethod, str) return 0, makeError(ErrUnregisteredMethod, str)
@ -225,9 +246,9 @@ func MethodUsageText(method string) (string, error) {
// Look up details about the provided method and error out if not // Look up details about the provided method and error out if not
// registered. // registered.
registerLock.RLock() registerLock.RLock()
defer registerLock.RUnlock()
rtp, ok := methodToConcreteType[method] rtp, ok := methodToConcreteType[method]
info := methodToInfo[method] info := methodToInfo[method]
registerLock.RUnlock()
if !ok { if !ok {
str := fmt.Sprintf("%q is not registered", method) str := fmt.Sprintf("%q is not registered", method)
return "", makeError(ErrUnregisteredMethod, str) return "", makeError(ErrUnregisteredMethod, str)
@ -241,9 +262,7 @@ func MethodUsageText(method string) (string, error) {
// Generate and store the usage string for future calls and return it. // Generate and store the usage string for future calls and return it.
usage := methodUsageText(rtp, info.defaults, method) usage := methodUsageText(rtp, info.defaults, method)
registerLock.Lock()
info.usage = usage info.usage = usage
methodToInfo[method] = info methodToInfo[method] = info
registerLock.Unlock()
return usage, nil return usage, nil
} }

View File

@ -39,9 +39,7 @@ func makeParams(rt reflect.Type, rv reflect.Value) []interface{} {
func MarshalCommand(id interface{}, cmd interface{}) ([]byte, error) { func MarshalCommand(id interface{}, cmd interface{}) ([]byte, error) {
// Look up the cmd type and error out if not registered. // Look up the cmd type and error out if not registered.
rt := reflect.TypeOf(cmd) rt := reflect.TypeOf(cmd)
registerLock.RLock() method, ok := concreteTypeToMethodWithRLock(rt)
method, ok := concreteTypeToMethod[rt]
registerLock.RUnlock()
if !ok { if !ok {
str := fmt.Sprintf("%q is not registered", method) str := fmt.Sprintf("%q is not registered", method)
return nil, makeError(ErrUnregisteredMethod, str) return nil, makeError(ErrUnregisteredMethod, str)
@ -109,10 +107,7 @@ func populateDefaults(numParams int, info *methodInfo, rv reflect.Value) {
// so long as the method type contained within the marshalled request is // so long as the method type contained within the marshalled request is
// registered. // registered.
func UnmarshalCommand(r *Request) (interface{}, error) { func UnmarshalCommand(r *Request) (interface{}, error) {
registerLock.RLock() rtp, info, ok := methodConcreteTypeAndInfoWithRLock(r.Method)
rtp, ok := methodToConcreteType[r.Method]
info := methodToInfo[r.Method]
registerLock.RUnlock()
if !ok { if !ok {
str := fmt.Sprintf("%q is not registered", r.Method) str := fmt.Sprintf("%q is not registered", r.Method)
return nil, makeError(ErrUnregisteredMethod, str) return nil, makeError(ErrUnregisteredMethod, str)
@ -513,10 +508,7 @@ func assignField(paramNum int, fieldName string, dest reflect.Value, src reflect
func NewCommand(method string, args ...interface{}) (interface{}, error) { func NewCommand(method string, args ...interface{}) (interface{}, error) {
// Look up details about the provided method. Any methods that aren't // Look up details about the provided method. Any methods that aren't
// registered are an error. // registered are an error.
registerLock.RLock() rtp, info, ok := methodConcreteTypeAndInfoWithRLock(method)
rtp, ok := methodToConcreteType[method]
info := methodToInfo[method]
registerLock.RUnlock()
if !ok { if !ok {
str := fmt.Sprintf("%q is not registered", method) str := fmt.Sprintf("%q is not registered", method)
return nil, makeError(ErrUnregisteredMethod, str) return nil, makeError(ErrUnregisteredMethod, str)

View File

@ -507,10 +507,7 @@ func isValidResultType(kind reflect.Kind) bool {
func GenerateHelp(method string, descs map[string]string, resultTypes ...interface{}) (string, error) { func GenerateHelp(method string, descs map[string]string, resultTypes ...interface{}) (string, error) {
// Look up details about the provided method and error out if not // Look up details about the provided method and error out if not
// registered. // registered.
registerLock.RLock() rtp, info, ok := methodConcreteTypeAndInfoWithRLock(method)
rtp, ok := methodToConcreteType[method]
info := methodToInfo[method]
registerLock.RUnlock()
if !ok { if !ok {
str := fmt.Sprintf("%q is not registered", method) str := fmt.Sprintf("%q is not registered", method)
return "", makeError(ErrUnregisteredMethod, str) return "", makeError(ErrUnregisteredMethod, str)

View File

@ -155,7 +155,6 @@ type GetBlockTemplateResult struct {
// Optional long polling from BIP 0022. // Optional long polling from BIP 0022.
LongPollID string `json:"longPollId,omitempty"` LongPollID string `json:"longPollId,omitempty"`
LongPollURI string `json:"longPollUri,omitempty"` LongPollURI string `json:"longPollUri,omitempty"`
SubmitOld *bool `json:"submitOld,omitempty"`
// Basic pool extension from BIP 0023. // Basic pool extension from BIP 0023.
Target string `json:"target,omitempty"` Target string `json:"target,omitempty"`

View File

@ -303,8 +303,8 @@ func (sp *Peer) addressKnown(na *wire.NetAddress) bool {
// It is safe for concurrent access. // It is safe for concurrent access.
func (sp *Peer) setDisableRelayTx(disable bool) { func (sp *Peer) setDisableRelayTx(disable bool) {
sp.relayMtx.Lock() sp.relayMtx.Lock()
defer sp.relayMtx.Unlock()
sp.DisableRelayTx = disable sp.DisableRelayTx = disable
sp.relayMtx.Unlock()
} }
// relayTxDisabled returns whether or not relaying of transactions for the given // relayTxDisabled returns whether or not relaying of transactions for the given
@ -312,10 +312,8 @@ func (sp *Peer) setDisableRelayTx(disable bool) {
// It is safe for concurrent access. // It is safe for concurrent access.
func (sp *Peer) relayTxDisabled() bool { func (sp *Peer) relayTxDisabled() bool {
sp.relayMtx.Lock() sp.relayMtx.Lock()
isDisabled := sp.DisableRelayTx defer sp.relayMtx.Unlock()
sp.relayMtx.Unlock() return sp.DisableRelayTx
return isDisabled
} }
// pushAddrMsg sends an addr message to the connected peer using the provided // pushAddrMsg sends an addr message to the connected peer using the provided

View File

@ -204,7 +204,7 @@ func handleGetBlockTemplateRequest(s *Server, request *rpcmodel.TemplateRequest,
if err := state.updateBlockTemplate(s, useCoinbaseValue); err != nil { if err := state.updateBlockTemplate(s, useCoinbaseValue); err != nil {
return nil, err return nil, err
} }
return state.blockTemplateResult(s.cfg.DAG, useCoinbaseValue, nil) return state.blockTemplateResult(s.cfg.DAG, useCoinbaseValue)
} }
// handleGetBlockTemplateLongPoll is a helper for handleGetBlockTemplateRequest // handleGetBlockTemplateLongPoll is a helper for handleGetBlockTemplateRequest
@ -217,66 +217,23 @@ func handleGetBlockTemplateRequest(s *Server, request *rpcmodel.TemplateRequest,
// has passed without finding a solution. // has passed without finding a solution.
func handleGetBlockTemplateLongPoll(s *Server, longPollID string, useCoinbaseValue bool, closeChan <-chan struct{}) (interface{}, error) { func handleGetBlockTemplateLongPoll(s *Server, longPollID string, useCoinbaseValue bool, closeChan <-chan struct{}) (interface{}, error) {
state := s.gbtWorkState state := s.gbtWorkState
state.Lock()
// The state unlock is intentionally not deferred here since it needs to
// be manually unlocked before waiting for a notification about block
// template changes.
if err := state.updateBlockTemplate(s, useCoinbaseValue); err != nil { result, longPollChan, err := blockTemplateOrLongPollChan(s, longPollID, useCoinbaseValue)
state.Unlock() if err != nil {
return nil, err return nil, err
} }
// Just return the current block template if the long poll ID provided by if result != nil {
// the caller is invalid.
parentHashes, lastGenerated, err := decodeLongPollID(longPollID)
if err != nil {
result, err := state.blockTemplateResult(s.cfg.DAG, useCoinbaseValue, nil)
if err != nil {
state.Unlock()
return nil, err
}
state.Unlock()
return result, nil return result, nil
} }
// Return the block template now if the specific block template
// identified by the long poll ID no longer matches the current block
// template as this means the provided template is stale.
areHashesEqual := daghash.AreEqual(state.template.Block.Header.ParentHashes, parentHashes)
if !areHashesEqual ||
lastGenerated != state.lastGenerated.Unix() {
// Include whether or not it is valid to submit work against the
// old block template depending on whether or not a solution has
// already been found and added to the block DAG.
submitOld := areHashesEqual
result, err := state.blockTemplateResult(s.cfg.DAG, useCoinbaseValue,
&submitOld)
if err != nil {
state.Unlock()
return nil, err
}
state.Unlock()
return result, nil
}
// Register the parent hashes and last generated time for notifications
// Get a channel that will be notified when the template associated with
// the provided ID is stale and a new block template should be returned to
// the caller.
longPollChan := state.templateUpdateChan(parentHashes, lastGenerated)
state.Unlock()
select { select {
// When the client closes before it's time to send a reply, just return // When the client closes before it's time to send a reply, just return
// now so the goroutine doesn't hang around. // now so the goroutine doesn't hang around.
case <-closeChan: case <-closeChan:
return nil, ErrClientQuit return nil, ErrClientQuit
// Wait until signal received to send the reply. // Wait until signal received to send the reply.
case <-longPollChan: case <-longPollChan:
// Fallthrough // Fallthrough
} }
@ -292,8 +249,7 @@ func handleGetBlockTemplateLongPoll(s *Server, longPollID string, useCoinbaseVal
// Include whether or not it is valid to submit work against the old // Include whether or not it is valid to submit work against the old
// block template depending on whether or not a solution has already // block template depending on whether or not a solution has already
// been found and added to the block DAG. // been found and added to the block DAG.
submitOld := areHashesEqual result, err = state.blockTemplateResult(s.cfg.DAG, useCoinbaseValue)
result, err := state.blockTemplateResult(s.cfg.DAG, useCoinbaseValue, &submitOld)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -301,6 +257,61 @@ func handleGetBlockTemplateLongPoll(s *Server, longPollID string, useCoinbaseVal
return result, nil return result, nil
} }
// blockTemplateOrLongPollChan returns a block template if the
// template identified by the provided long poll ID is stale or
// invalid. Otherwise, it returns a channel that will notify
// when there's a more current template.
func blockTemplateOrLongPollChan(s *Server, longPollID string, useCoinbaseValue bool) (*rpcmodel.GetBlockTemplateResult, chan struct{}, error) {
state := s.gbtWorkState
state.Lock()
defer state.Unlock()
// The state unlock is intentionally not deferred here since it needs to
// be manually unlocked before waiting for a notification about block
// template changes.
if err := state.updateBlockTemplate(s, useCoinbaseValue); err != nil {
return nil, nil, err
}
// Just return the current block template if the long poll ID provided by
// the caller is invalid.
parentHashes, lastGenerated, err := decodeLongPollID(longPollID)
if err != nil {
result, err := state.blockTemplateResult(s.cfg.DAG, useCoinbaseValue)
if err != nil {
return nil, nil, err
}
return result, nil, nil
}
// Return the block template now if the specific block template
// identified by the long poll ID no longer matches the current block
// template as this means the provided template is stale.
areHashesEqual := daghash.AreEqual(state.template.Block.Header.ParentHashes, parentHashes)
if !areHashesEqual ||
lastGenerated != state.lastGenerated.Unix() {
// Include whether or not it is valid to submit work against the
// old block template depending on whether or not a solution has
// already been found and added to the block DAG.
result, err := state.blockTemplateResult(s.cfg.DAG, useCoinbaseValue)
if err != nil {
return nil, nil, err
}
return result, nil, nil
}
// Register the parent hashes and last generated time for notifications
// Get a channel that will be notified when the template associated with
// the provided ID is stale and a new block template should be returned to
// the caller.
longPollChan := state.templateUpdateChan(parentHashes, lastGenerated)
return nil, longPollChan, nil
}
// handleGetBlockTemplateProposal is a helper for handleGetBlockTemplate which // handleGetBlockTemplateProposal is a helper for handleGetBlockTemplate which
// deals with block proposals. // deals with block proposals.
func handleGetBlockTemplateProposal(s *Server, request *rpcmodel.TemplateRequest) (interface{}, error) { func handleGetBlockTemplateProposal(s *Server, request *rpcmodel.TemplateRequest) (interface{}, error) {
@ -693,7 +704,7 @@ func (state *gbtWorkState) updateBlockTemplate(s *Server, useCoinbaseValue bool)
// and returned to the caller. // and returned to the caller.
// //
// This function MUST be called with the state locked. // This function MUST be called with the state locked.
func (state *gbtWorkState) blockTemplateResult(dag *blockdag.BlockDAG, useCoinbaseValue bool, submitOld *bool) (*rpcmodel.GetBlockTemplateResult, error) { func (state *gbtWorkState) blockTemplateResult(dag *blockdag.BlockDAG, useCoinbaseValue bool) (*rpcmodel.GetBlockTemplateResult, error) {
// Ensure the timestamps are still in valid range for the template. // Ensure the timestamps are still in valid range for the template.
// This should really only ever happen if the local clock is changed // This should really only ever happen if the local clock is changed
// after the template is generated, but it's important to avoid serving // after the template is generated, but it's important to avoid serving
@ -779,7 +790,6 @@ func (state *gbtWorkState) blockTemplateResult(dag *blockdag.BlockDAG, useCoinba
UTXOCommitment: header.UTXOCommitment.String(), UTXOCommitment: header.UTXOCommitment.String(),
Version: header.Version, Version: header.Version,
LongPollID: longPollID, LongPollID: longPollID,
SubmitOld: submitOld,
Target: targetDifficulty, Target: targetDifficulty,
MinTime: state.minTimestamp.Unix(), MinTime: state.minTimestamp.Unix(),
MaxTime: maxTime.Unix(), MaxTime: maxTime.Unix(),

View File

@ -30,22 +30,28 @@ func handleLoadTxFilter(wsc *wsClient, icmd interface{}) (interface{}, error) {
params := wsc.server.cfg.DAGParams params := wsc.server.cfg.DAGParams
wsc.Lock() reloadedFilterData := func() bool {
if cmd.Reload || wsc.filterData == nil { wsc.Lock()
wsc.filterData = newWSClientFilter(cmd.Addresses, outpoints, defer wsc.Unlock()
params) if cmd.Reload || wsc.filterData == nil {
wsc.Unlock() wsc.filterData = newWSClientFilter(cmd.Addresses, outpoints,
} else { params)
wsc.Unlock() return true
}
return false
}()
wsc.filterData.mu.Lock() if !reloadedFilterData {
for _, a := range cmd.Addresses { func() {
wsc.filterData.addAddressStr(a, params) wsc.filterData.mu.Lock()
} defer wsc.filterData.mu.Unlock()
for i := range outpoints { for _, a := range cmd.Addresses {
wsc.filterData.addUnspentOutpoint(&outpoints[i]) wsc.filterData.addAddressStr(a, params)
} }
wsc.filterData.mu.Unlock() for i := range outpoints {
wsc.filterData.addUnspentOutpoint(&outpoints[i])
}
}()
} }
return nil, nil return nil, nil

View File

@ -16,6 +16,7 @@ func rescanBlockFilter(filter *wsClientFilter, block *util.Block, params *dagcon
var transactions []string var transactions []string
filter.mu.Lock() filter.mu.Lock()
defer filter.mu.Unlock()
for _, tx := range block.Transactions() { for _, tx := range block.Transactions() {
msgTx := tx.MsgTx() msgTx := tx.MsgTx()
@ -26,7 +27,7 @@ func rescanBlockFilter(filter *wsClientFilter, block *util.Block, params *dagcon
// Scan inputs if not a coinbase transaction. // Scan inputs if not a coinbase transaction.
if !msgTx.IsCoinBase() { if !msgTx.IsCoinBase() {
for _, input := range msgTx.TxIn { for _, input := range msgTx.TxIn {
if !filter.existsUnspentOutpoint(&input.PreviousOutpoint) { if !filter.existsUnspentOutpointNoLock(&input.PreviousOutpoint) {
continue continue
} }
if !added { if !added {
@ -65,7 +66,6 @@ func rescanBlockFilter(filter *wsClientFilter, block *util.Block, params *dagcon
} }
} }
} }
filter.mu.Unlock()
return transactions return transactions
} }

View File

@ -17,9 +17,7 @@ func handleRescanBlocks(wsc *wsClient, icmd interface{}) (interface{}, error) {
} }
// Load client's transaction filter. Must exist in order to continue. // Load client's transaction filter. Must exist in order to continue.
wsc.Lock() filter := wsc.FilterData()
filter := wsc.filterData
wsc.Unlock()
if filter == nil { if filter == nil {
return nil, &rpcmodel.RPCError{ return nil, &rpcmodel.RPCError{
Code: rpcmodel.ErrRPCMisc, Code: rpcmodel.ErrRPCMisc,

View File

@ -185,9 +185,12 @@ func (s *Server) httpStatusLine(req *http.Request, code int) string {
if !proto11 { if !proto11 {
key = -key key = -key
} }
s.statusLock.RLock() line, ok := func() (string, bool) {
line, ok := s.statusLines[key] s.statusLock.RLock()
s.statusLock.RUnlock() defer s.statusLock.RUnlock()
line, ok := s.statusLines[key]
return line, ok
}()
if ok { if ok {
return line return line
} }
@ -202,8 +205,8 @@ func (s *Server) httpStatusLine(req *http.Request, code int) string {
if text != "" { if text != "" {
line = proto + " " + codeStr + " " + text + "\r\n" line = proto + " " + codeStr + " " + text + "\r\n"
s.statusLock.Lock() s.statusLock.Lock()
defer s.statusLock.Unlock()
s.statusLines[key] = line s.statusLines[key] = line
s.statusLock.Unlock()
} else { } else {
text = "status code " + codeStr text = "status code " + codeStr
line = proto + " " + codeStr + " " + text + "\r\n" line = proto + " " + codeStr + " " + text + "\r\n"

View File

@ -352,15 +352,21 @@ func (f *wsClientFilter) addUnspentOutpoint(op *wire.Outpoint) {
f.unspent[*op] = struct{}{} f.unspent[*op] = struct{}{}
} }
// existsUnspentOutpoint returns true if the passed outpoint has been added to // existsUnspentOutpointNoLock returns true if the passed outpoint has been added to
// the wsClientFilter. // the wsClientFilter.
// //
// NOTE: This extension was ported from github.com/decred/dcrd // NOTE: This extension was ported from github.com/decred/dcrd
func (f *wsClientFilter) existsUnspentOutpoint(op *wire.Outpoint) bool { func (f *wsClientFilter) existsUnspentOutpointNoLock(op *wire.Outpoint) bool {
_, ok := f.unspent[*op] _, ok := f.unspent[*op]
return ok return ok
} }
func (f *wsClientFilter) existsUnspentOutpoint(op *wire.Outpoint) bool {
f.mu.Lock()
defer f.mu.Unlock()
return f.existsUnspentOutpointNoLock(op)
}
// Notification types // Notification types
type notificationBlockAdded util.Block type notificationBlockAdded util.Block
type notificationChainChanged struct { type notificationChainChanged struct {
@ -568,17 +574,13 @@ func (m *wsNotificationManager) subscribedClients(tx *util.Tx,
msgTx := tx.MsgTx() msgTx := tx.MsgTx()
for _, input := range msgTx.TxIn { for _, input := range msgTx.TxIn {
for quitChan, wsc := range clients { for quitChan, wsc := range clients {
wsc.Lock() filter := wsc.FilterData()
filter := wsc.filterData
wsc.Unlock()
if filter == nil { if filter == nil {
continue continue
} }
filter.mu.Lock()
if filter.existsUnspentOutpoint(&input.PreviousOutpoint) { if filter.existsUnspentOutpoint(&input.PreviousOutpoint) {
subscribed[quitChan] = struct{}{} subscribed[quitChan] = struct{}{}
} }
filter.mu.Unlock()
} }
} }
@ -591,22 +593,22 @@ func (m *wsNotificationManager) subscribedClients(tx *util.Tx,
continue continue
} }
for quitChan, wsc := range clients { for quitChan, wsc := range clients {
wsc.Lock() filter := wsc.FilterData()
filter := wsc.filterData
wsc.Unlock()
if filter == nil { if filter == nil {
continue continue
} }
filter.mu.Lock() func() {
if filter.existsAddress(addr) { filter.mu.Lock()
subscribed[quitChan] = struct{}{} defer filter.mu.Unlock()
op := wire.Outpoint{ if filter.existsAddress(addr) {
TxID: *tx.ID(), subscribed[quitChan] = struct{}{}
Index: uint32(i), op := wire.Outpoint{
TxID: *tx.ID(),
Index: uint32(i),
}
filter.addUnspentOutpoint(&op)
} }
filter.addUnspentOutpoint(&op) }()
}
filter.mu.Unlock()
} }
} }
@ -1250,10 +1252,8 @@ func (c *wsClient) QueueNotification(marshalledJSON []byte) error {
// Disconnected returns whether or not the websocket client is disconnected. // Disconnected returns whether or not the websocket client is disconnected.
func (c *wsClient) Disconnected() bool { func (c *wsClient) Disconnected() bool {
c.Lock() c.Lock()
isDisconnected := c.disconnected defer c.Unlock()
c.Unlock() return c.disconnected
return isDisconnected
} }
// Disconnect disconnects the websocket client. // Disconnect disconnects the websocket client.
@ -1289,6 +1289,13 @@ func (c *wsClient) WaitForShutdown() {
c.wg.Wait() c.wg.Wait()
} }
// FilterData returns the websocket client filter data.
func (c *wsClient) FilterData() *wsClientFilter {
c.Lock()
defer c.Unlock()
return c.filterData
}
// newWebsocketClient returns a new websocket client given the notification // newWebsocketClient returns a new websocket client given the notification
// manager, websocket connection, remote address, and whether or not the client // manager, websocket connection, remote address, and whether or not the client
// has already been authenticated (via HTTP Basic access authentication). The // has already been authenticated (via HTTP Basic access authentication). The

View File

@ -57,8 +57,8 @@ func NewSigCache(maxEntries uint) *SigCache {
// unless there exists a writer, adding an entry to the SigCache. // unless there exists a writer, adding an entry to the SigCache.
func (s *SigCache) Exists(sigHash daghash.Hash, sig *ecc.Signature, pubKey *ecc.PublicKey) bool { func (s *SigCache) Exists(sigHash daghash.Hash, sig *ecc.Signature, pubKey *ecc.PublicKey) bool {
s.RLock() s.RLock()
defer s.RUnlock()
entry, ok := s.validSigs[sigHash] entry, ok := s.validSigs[sigHash]
s.RUnlock()
return ok && entry.pubKey.IsEqual(pubKey) && entry.sig.IsEqual(sig) return ok && entry.pubKey.IsEqual(pubKey) && entry.sig.IsEqual(sig)
} }

View File

@ -89,8 +89,8 @@ func LoadFilter(filter *wire.MsgFilterLoad) *Filter {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (bf *Filter) IsLoaded() bool { func (bf *Filter) IsLoaded() bool {
bf.mtx.Lock() bf.mtx.Lock()
defer bf.mtx.Unlock()
loaded := bf.msgFilterLoad != nil loaded := bf.msgFilterLoad != nil
bf.mtx.Unlock()
return loaded return loaded
} }
@ -99,8 +99,8 @@ func (bf *Filter) IsLoaded() bool {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (bf *Filter) Reload(filter *wire.MsgFilterLoad) { func (bf *Filter) Reload(filter *wire.MsgFilterLoad) {
bf.mtx.Lock() bf.mtx.Lock()
defer bf.mtx.Unlock()
bf.msgFilterLoad = filter bf.msgFilterLoad = filter
bf.mtx.Unlock()
} }
// Unload unloads the bloom filter. // Unload unloads the bloom filter.
@ -108,8 +108,8 @@ func (bf *Filter) Reload(filter *wire.MsgFilterLoad) {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (bf *Filter) Unload() { func (bf *Filter) Unload() {
bf.mtx.Lock() bf.mtx.Lock()
defer bf.mtx.Unlock()
bf.msgFilterLoad = nil bf.msgFilterLoad = nil
bf.mtx.Unlock()
} }
// hash returns the bit offset in the bloom filter which corresponds to the // hash returns the bit offset in the bloom filter which corresponds to the
@ -156,9 +156,8 @@ func (bf *Filter) matches(data []byte) bool {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (bf *Filter) Matches(data []byte) bool { func (bf *Filter) Matches(data []byte) bool {
bf.mtx.Lock() bf.mtx.Lock()
match := bf.matches(data) defer bf.mtx.Unlock()
bf.mtx.Unlock() return bf.matches(data)
return match
} }
// matchesOutpoint returns true if the bloom filter might contain the passed // matchesOutpoint returns true if the bloom filter might contain the passed
@ -180,9 +179,8 @@ func (bf *Filter) matchesOutpoint(outpoint *wire.Outpoint) bool {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (bf *Filter) MatchesOutpoint(outpoint *wire.Outpoint) bool { func (bf *Filter) MatchesOutpoint(outpoint *wire.Outpoint) bool {
bf.mtx.Lock() bf.mtx.Lock()
match := bf.matchesOutpoint(outpoint) defer bf.mtx.Unlock()
bf.mtx.Unlock() return bf.matchesOutpoint(outpoint)
return match
} }
// add adds the passed byte slice to the bloom filter. // add adds the passed byte slice to the bloom filter.
@ -211,8 +209,8 @@ func (bf *Filter) add(data []byte) {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (bf *Filter) Add(data []byte) { func (bf *Filter) Add(data []byte) {
bf.mtx.Lock() bf.mtx.Lock()
defer bf.mtx.Unlock()
bf.add(data) bf.add(data)
bf.mtx.Unlock()
} }
// AddHash adds the passed daghash.Hash to the Filter. // AddHash adds the passed daghash.Hash to the Filter.
@ -220,8 +218,8 @@ func (bf *Filter) Add(data []byte) {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (bf *Filter) AddHash(hash *daghash.Hash) { func (bf *Filter) AddHash(hash *daghash.Hash) {
bf.mtx.Lock() bf.mtx.Lock()
defer bf.mtx.Unlock()
bf.add(hash[:]) bf.add(hash[:])
bf.mtx.Unlock()
} }
// addOutpoint adds the passed transaction outpoint to the bloom filter. // addOutpoint adds the passed transaction outpoint to the bloom filter.
@ -241,8 +239,8 @@ func (bf *Filter) addOutpoint(outpoint *wire.Outpoint) {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (bf *Filter) AddOutpoint(outpoint *wire.Outpoint) { func (bf *Filter) AddOutpoint(outpoint *wire.Outpoint) {
bf.mtx.Lock() bf.mtx.Lock()
defer bf.mtx.Unlock()
bf.addOutpoint(outpoint) bf.addOutpoint(outpoint)
bf.mtx.Unlock()
} }
// maybeAddOutpoint potentially adds the passed outpoint to the bloom filter // maybeAddOutpoint potentially adds the passed outpoint to the bloom filter
@ -330,9 +328,8 @@ func (bf *Filter) matchTxAndUpdate(tx *util.Tx) bool {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (bf *Filter) MatchTxAndUpdate(tx *util.Tx) bool { func (bf *Filter) MatchTxAndUpdate(tx *util.Tx) bool {
bf.mtx.Lock() bf.mtx.Lock()
match := bf.matchTxAndUpdate(tx) defer bf.mtx.Unlock()
bf.mtx.Unlock() return bf.matchTxAndUpdate(tx)
return match
} }
// MsgFilterLoad returns the underlying wire.MsgFilterLoad for the bloom // MsgFilterLoad returns the underlying wire.MsgFilterLoad for the bloom
@ -341,7 +338,6 @@ func (bf *Filter) MatchTxAndUpdate(tx *util.Tx) bool {
// This function is safe for concurrent access. // This function is safe for concurrent access.
func (bf *Filter) MsgFilterLoad() *wire.MsgFilterLoad { func (bf *Filter) MsgFilterLoad() *wire.MsgFilterLoad {
bf.mtx.Lock() bf.mtx.Lock()
msg := bf.msgFilterLoad defer bf.mtx.Unlock()
bf.mtx.Unlock() return bf.msgFilterLoad
return msg
} }

View File

@ -84,11 +84,11 @@ func (wg *waitGroup) done() {
// if there is a listener to wg.releaseWait. // if there is a listener to wg.releaseWait.
if atomic.LoadInt64(&wg.counter) == 0 { if atomic.LoadInt64(&wg.counter) == 0 {
wg.isReleaseWaitWaitingLock.Lock() wg.isReleaseWaitWaitingLock.Lock()
defer wg.isReleaseWaitWaitingLock.Unlock()
if atomic.LoadInt64(&wg.isReleaseWaitWaiting) == 1 { if atomic.LoadInt64(&wg.isReleaseWaitWaiting) == 1 {
wg.releaseWait <- struct{}{} wg.releaseWait <- struct{}{}
<-wg.releaseDone <-wg.releaseDone
} }
wg.isReleaseWaitWaitingLock.Unlock()
} }
} }
@ -96,6 +96,7 @@ func (wg *waitGroup) wait() {
wg.mainWaitLock.Lock() wg.mainWaitLock.Lock()
defer wg.mainWaitLock.Unlock() defer wg.mainWaitLock.Unlock()
wg.isReleaseWaitWaitingLock.Lock() wg.isReleaseWaitWaitingLock.Lock()
defer wg.isReleaseWaitWaitingLock.Unlock()
for atomic.LoadInt64(&wg.counter) != 0 { for atomic.LoadInt64(&wg.counter) != 0 {
atomic.StoreInt64(&wg.isReleaseWaitWaiting, 1) atomic.StoreInt64(&wg.isReleaseWaitWaiting, 1)
wg.isReleaseWaitWaitingLock.Unlock() wg.isReleaseWaitWaitingLock.Unlock()
@ -104,5 +105,4 @@ func (wg *waitGroup) wait() {
wg.releaseDone <- struct{}{} wg.releaseDone <- struct{}{}
wg.isReleaseWaitWaitingLock.Lock() wg.isReleaseWaitWaitingLock.Lock()
} }
wg.isReleaseWaitWaitingLock.Unlock()
} }