diff --git a/blockdag/blockindex.go b/blockdag/blockindex.go index ae237d634..83aa9951d 100644 --- a/blockdag/blockindex.go +++ b/blockdag/blockindex.go @@ -43,8 +43,8 @@ func newBlockIndex(db database.DB, dagParams *dagconfig.Params) *blockIndex { // This function is safe for concurrent access. func (bi *blockIndex) HaveBlock(hash *daghash.Hash) bool { bi.RLock() + defer bi.RUnlock() _, hasBlock := bi.index[*hash] - bi.RUnlock() return hasBlock } @@ -54,8 +54,8 @@ func (bi *blockIndex) HaveBlock(hash *daghash.Hash) bool { // This function is safe for concurrent access. func (bi *blockIndex) LookupNode(hash *daghash.Hash) *blockNode { bi.RLock() + defer bi.RUnlock() node := bi.index[*hash] - bi.RUnlock() return node } @@ -65,9 +65,9 @@ func (bi *blockIndex) LookupNode(hash *daghash.Hash) *blockNode { // This function is safe for concurrent access. func (bi *blockIndex) AddNode(node *blockNode) { bi.Lock() + defer bi.Unlock() bi.addNode(node) bi.dirty[node] = struct{}{} - bi.Unlock() } // addNode adds the provided node to the block index, but does not mark it as @@ -83,8 +83,8 @@ func (bi *blockIndex) addNode(node *blockNode) { // This function is safe for concurrent access. func (bi *blockIndex) NodeStatus(node *blockNode) blockStatus { bi.RLock() + defer bi.RUnlock() status := node.status - bi.RUnlock() return status } @@ -95,9 +95,9 @@ func (bi *blockIndex) NodeStatus(node *blockNode) blockStatus { // This function is safe for concurrent access. func (bi *blockIndex) SetStatusFlags(node *blockNode, flags blockStatus) { bi.Lock() + defer bi.Unlock() node.status |= flags bi.dirty[node] = struct{}{} - bi.Unlock() } // UnsetStatusFlags flips the provided status flags on the block node to off, @@ -106,9 +106,9 @@ func (bi *blockIndex) SetStatusFlags(node *blockNode, flags blockStatus) { // This function is safe for concurrent access. func (bi *blockIndex) UnsetStatusFlags(node *blockNode, flags blockStatus) { bi.Lock() + defer bi.Unlock() node.status &^= flags bi.dirty[node] = struct{}{} - bi.Unlock() } // flushToDB writes all dirty block nodes to the database. If all writes diff --git a/blockdag/dag.go b/blockdag/dag.go index a13c44352..5ef25a519 100644 --- a/blockdag/dag.go +++ b/blockdag/dag.go @@ -194,8 +194,8 @@ func (dag *BlockDAG) IsKnownOrphan(hash *daghash.Hash) bool { // Protect concurrent access. Using a read lock only so multiple // readers can query without blocking each other. dag.orphanLock.RLock() + defer dag.orphanLock.RUnlock() _, exists := dag.orphans[*hash] - dag.orphanLock.RUnlock() return exists } diff --git a/blockdag/indexers/addrindex.go b/blockdag/indexers/addrindex.go index 09e1533e1..279f016b3 100644 --- a/blockdag/indexers/addrindex.go +++ b/blockdag/indexers/addrindex.go @@ -771,6 +771,7 @@ func (idx *AddrIndex) indexUnconfirmedAddresses(scriptPubKey []byte, tx *util.Tx // Add a mapping from the address to the transaction. 
idx.unconfirmedLock.Lock() + defer idx.unconfirmedLock.Unlock() addrIndexEntry := idx.txnsByAddr[addrKey] if addrIndexEntry == nil { addrIndexEntry = make(map[daghash.TxID]*util.Tx) @@ -785,7 +786,6 @@ func (idx *AddrIndex) indexUnconfirmedAddresses(scriptPubKey []byte, tx *util.Tx idx.addrsByTx[*tx.ID()] = addrsByTxEntry } addrsByTxEntry[addrKey] = struct{}{} - idx.unconfirmedLock.Unlock() } // AddUnconfirmedTx adds all addresses related to the transaction to the diff --git a/blockdag/notifications.go b/blockdag/notifications.go index 16a9be48d..8699b589b 100644 --- a/blockdag/notifications.go +++ b/blockdag/notifications.go @@ -57,8 +57,8 @@ type Notification struct { // NotificationType for details on the types and contents of notifications. func (dag *BlockDAG) Subscribe(callback NotificationCallback) { dag.notificationsLock.Lock() + defer dag.notificationsLock.Unlock() dag.notifications = append(dag.notifications, callback) - dag.notificationsLock.Unlock() } // sendNotification sends a notification with the passed type and data if the @@ -68,10 +68,10 @@ func (dag *BlockDAG) sendNotification(typ NotificationType, data interface{}) { // Generate and send the notification. n := Notification{Type: typ, Data: data} dag.notificationsLock.RLock() + defer dag.notificationsLock.RUnlock() for _, callback := range dag.notifications { callback(&n) } - dag.notificationsLock.RUnlock() } // BlockAddedNotificationData defines data to be sent along with a BlockAdded diff --git a/blockdag/virtualblock.go b/blockdag/virtualblock.go index 102b8acca..17a98600c 100644 --- a/blockdag/virtualblock.go +++ b/blockdag/virtualblock.go @@ -113,8 +113,8 @@ func (v *virtualBlock) updateSelectedParentSet(oldSelectedParent *blockNode) *ch // This function is safe for concurrent access. func (v *virtualBlock) SetTips(tips blockSet) { v.mtx.Lock() + defer v.mtx.Unlock() v.setTips(tips) - v.mtx.Unlock() } // addTip adds the given tip to the set of tips in the virtual block. diff --git a/cmd/kaspactl/httpclient.go b/cmd/kaspactl/httpclient.go index eadd0f274..c6d133bb4 100644 --- a/cmd/kaspactl/httpclient.go +++ b/cmd/kaspactl/httpclient.go @@ -95,11 +95,12 @@ func sendPostRequest(marshalledJSON []byte, cfg *ConfigFlags) ([]byte, error) { } // Read the raw bytes and close the response. - respBytes, err := ioutil.ReadAll(httpResponse.Body) - httpResponse.Body.Close() + respBytes, err := func() ([]byte, error) { + defer httpResponse.Body.Close() + return ioutil.ReadAll(httpResponse.Body) + }() if err != nil { - err = errors.Errorf("error reading json reply: %s", err) - return nil, err + return nil, errors.Wrap(err, "error reading json reply") } // Handle unsuccessful HTTP responses diff --git a/connmgr/connmanager.go b/connmgr/connmanager.go index a4cbd4b32..316e3119c 100644 --- a/connmgr/connmanager.go +++ b/connmgr/connmanager.go @@ -89,8 +89,8 @@ type ConnReq struct { // updateState updates the state of the connection request. func (c *ConnReq) updateState(state ConnState) { c.stateMtx.Lock() + defer c.stateMtx.Unlock() c.state = state - c.stateMtx.Unlock() } // ID returns a unique identifier for the connection request. @@ -101,8 +101,8 @@ func (c *ConnReq) ID() uint64 { // State is the connection state of the requested connection. 
func (c *ConnReq) State() ConnState { c.stateMtx.RLock() + defer c.stateMtx.RUnlock() state := c.state - c.stateMtx.RUnlock() return state } diff --git a/connmgr/dynamicbanscore.go b/connmgr/dynamicbanscore.go index 0431d0a0f..23d7e2bff 100644 --- a/connmgr/dynamicbanscore.go +++ b/connmgr/dynamicbanscore.go @@ -68,9 +68,9 @@ type DynamicBanScore struct { // String returns the ban score as a human-readable string. func (s *DynamicBanScore) String() string { s.mtx.Lock() + defer s.mtx.Unlock() r := fmt.Sprintf("persistent %d + transient %f at %d = %d as of now", s.persistent, s.transient, s.lastUnix, s.Int()) - s.mtx.Unlock() return r } @@ -80,8 +80,8 @@ func (s *DynamicBanScore) String() string { // This function is safe for concurrent access. func (s *DynamicBanScore) Int() uint32 { s.mtx.Lock() + defer s.mtx.Unlock() r := s.int(time.Now()) - s.mtx.Unlock() return r } @@ -91,8 +91,8 @@ func (s *DynamicBanScore) Int() uint32 { // This function is safe for concurrent access. func (s *DynamicBanScore) Increase(persistent, transient uint32) uint32 { s.mtx.Lock() + defer s.mtx.Unlock() r := s.increase(persistent, transient, time.Now()) - s.mtx.Unlock() return r } @@ -101,10 +101,10 @@ func (s *DynamicBanScore) Increase(persistent, transient uint32) uint32 { // This function is safe for concurrent access. func (s *DynamicBanScore) Reset() { s.mtx.Lock() + defer s.mtx.Unlock() s.persistent = 0 s.transient = 0 s.lastUnix = 0 - s.mtx.Unlock() } // int returns the ban score, the sum of the persistent and decaying scores at a diff --git a/logs/logs.go b/logs/logs.go index 85a03f0bd..efe17220d 100644 --- a/logs/logs.go +++ b/logs/logs.go @@ -335,13 +335,13 @@ func (b *Backend) printf(lvl Level, tag string, format string, args ...interface func (b *Backend) write(lvl Level, bytesToWrite []byte) { b.mu.Lock() + defer b.mu.Unlock() os.Stdout.Write(bytesToWrite) for _, r := range b.rotators { if lvl >= r.logLevel { r.Write(bytesToWrite) } } - b.mu.Unlock() } // Close finalizes all log rotators for this backend diff --git a/mempool/mempool.go b/mempool/mempool.go index 4260edbb5..acbab0a9f 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -375,8 +375,8 @@ func (mp *TxPool) isTransactionInPool(hash *daghash.TxID) bool { func (mp *TxPool) IsTransactionInPool(hash *daghash.TxID) bool { // Protect concurrent access. mp.mtx.RLock() + defer mp.mtx.RUnlock() inPool := mp.isTransactionInPool(hash) - mp.mtx.RUnlock() return inPool } @@ -423,8 +423,8 @@ func (mp *TxPool) isOrphanInPool(hash *daghash.TxID) bool { func (mp *TxPool) IsOrphanInPool(hash *daghash.TxID) bool { // Protect concurrent access. mp.mtx.RLock() + defer mp.mtx.RUnlock() inPool := mp.isOrphanInPool(hash) - mp.mtx.RUnlock() return inPool } @@ -444,8 +444,8 @@ func (mp *TxPool) haveTransaction(hash *daghash.TxID) bool { func (mp *TxPool) HaveTransaction(hash *daghash.TxID) bool { // Protect concurrent access. mp.mtx.RLock() + defer mp.mtx.RUnlock() haveTx := mp.haveTransaction(hash) - mp.mtx.RUnlock() return haveTx } @@ -745,8 +745,8 @@ func (mp *TxPool) checkPoolDoubleSpend(tx *util.Tx) error { // be returned, if not nil will be returned. func (mp *TxPool) CheckSpend(op wire.Outpoint) *util.Tx { mp.mtx.RLock() + defer mp.mtx.RUnlock() txR := mp.outpoints[op] - mp.mtx.RUnlock() return txR } @@ -1207,8 +1207,8 @@ func (mp *TxPool) ProcessTransaction(tx *util.Tx, allowOrphan bool, tag Tag) ([] // This function is safe for concurrent access. 
func (mp *TxPool) Count() int { mp.mtx.RLock() + defer mp.mtx.RUnlock() count := len(mp.pool) - mp.mtx.RUnlock() return count } @@ -1229,6 +1229,7 @@ func (mp *TxPool) DepCount() int { // This function is safe for concurrent access. func (mp *TxPool) TxIDs() []*daghash.TxID { mp.mtx.RLock() + defer mp.mtx.RUnlock() ids := make([]*daghash.TxID, len(mp.pool)) i := 0 for txID := range mp.pool { @@ -1236,7 +1237,6 @@ func (mp *TxPool) TxIDs() []*daghash.TxID { ids[i] = &idCopy i++ } - mp.mtx.RUnlock() return ids } @@ -1247,13 +1247,13 @@ func (mp *TxPool) TxIDs() []*daghash.TxID { // This function is safe for concurrent access. func (mp *TxPool) TxDescs() []*TxDesc { mp.mtx.RLock() + defer mp.mtx.RUnlock() descs := make([]*TxDesc, len(mp.pool)) i := 0 for _, desc := range mp.pool { descs[i] = desc i++ } - mp.mtx.RUnlock() return descs } @@ -1265,13 +1265,13 @@ func (mp *TxPool) TxDescs() []*TxDesc { // concurrent access as required by the interface contract. func (mp *TxPool) MiningDescs() []*mining.TxDesc { mp.mtx.RLock() + defer mp.mtx.RUnlock() descs := make([]*mining.TxDesc, len(mp.pool)) i := 0 for _, desc := range mp.pool { descs[i] = &desc.TxDesc i++ } - mp.mtx.RUnlock() return descs } diff --git a/mempool/mempool_test.go b/mempool/mempool_test.go index 10ffa7697..fc326d929 100644 --- a/mempool/mempool_test.go +++ b/mempool/mempool_test.go @@ -41,24 +41,23 @@ type fakeDAG struct { // instance. func (s *fakeDAG) BlueScore() uint64 { s.RLock() - blueScore := s.currentBlueScore - s.RUnlock() - return blueScore + defer s.RUnlock() + return s.currentBlueScore } // SetBlueScore sets the current blueScore associated with the fake DAG instance. func (s *fakeDAG) SetBlueScore(blueScore uint64) { s.Lock() + defer s.Unlock() s.currentBlueScore = blueScore - s.Unlock() } // MedianTimePast returns the current median time past associated with the fake // DAG instance. func (s *fakeDAG) MedianTimePast() time.Time { s.RLock() + defer s.RUnlock() mtp := s.medianTimePast - s.RUnlock() return mtp } @@ -66,8 +65,8 @@ func (s *fakeDAG) MedianTimePast() time.Time { // DAG instance. func (s *fakeDAG) SetMedianTimePast(mtp time.Time) { s.Lock() + defer s.Unlock() s.medianTimePast = mtp - s.Unlock() } func calcSequenceLock(tx *util.Tx, @@ -1319,9 +1318,11 @@ func TestOrphanChainRemoval(t *testing.T) { // Remove the first orphan that starts the orphan chain without the // remove redeemer flag set and ensure that only the first orphan was // removed. - harness.txPool.mtx.Lock() - harness.txPool.removeOrphan(chainedTxns[1], false) - harness.txPool.mtx.Unlock() + func() { + harness.txPool.mtx.Lock() + defer harness.txPool.mtx.Unlock() + harness.txPool.removeOrphan(chainedTxns[1], false) + }() testPoolMembership(tc, chainedTxns[1], false, false, false) for _, tx := range chainedTxns[2 : maxOrphans+1] { testPoolMembership(tc, tx, true, false, false) @@ -1329,9 +1330,11 @@ func TestOrphanChainRemoval(t *testing.T) { // Remove the first remaining orphan that starts the orphan chain with // the remove redeemer flag set and ensure they are all removed. 
- harness.txPool.mtx.Lock() - harness.txPool.removeOrphan(chainedTxns[2], true) - harness.txPool.mtx.Unlock() + func() { + harness.txPool.mtx.Lock() + defer harness.txPool.mtx.Unlock() + harness.txPool.removeOrphan(chainedTxns[2], true) + }() for _, tx := range chainedTxns[2 : maxOrphans+1] { testPoolMembership(tc, tx, false, false, false) } diff --git a/peer/mruinvmap.go b/peer/mruinvmap.go index a13d454e6..5f0e9889d 100644 --- a/peer/mruinvmap.go +++ b/peer/mruinvmap.go @@ -50,8 +50,8 @@ func (m *mruInventoryMap) String() string { // This function is safe for concurrent access. func (m *mruInventoryMap) Exists(iv *wire.InvVect) bool { m.invMtx.Lock() + defer m.invMtx.Unlock() _, exists := m.invMap[*iv] - m.invMtx.Unlock() return exists } @@ -106,11 +106,11 @@ func (m *mruInventoryMap) Add(iv *wire.InvVect) { // This function is safe for concurrent access. func (m *mruInventoryMap) Delete(iv *wire.InvVect) { m.invMtx.Lock() + defer m.invMtx.Unlock() if node, exists := m.invMap[*iv]; exists { m.invList.Remove(node) delete(m.invMap, *iv) } - m.invMtx.Unlock() } // newMruInventoryMap returns a new inventory map that is limited to the number diff --git a/peer/mrunoncemap.go b/peer/mrunoncemap.go index d8fc792b7..2931fb3bb 100644 --- a/peer/mrunoncemap.go +++ b/peer/mrunoncemap.go @@ -48,8 +48,8 @@ func (m *mruNonceMap) String() string { // This function is safe for concurrent access. func (m *mruNonceMap) Exists(nonce uint64) bool { m.mtx.Lock() + defer m.mtx.Unlock() _, exists := m.nonceMap[nonce] - m.mtx.Unlock() return exists } @@ -104,11 +104,11 @@ func (m *mruNonceMap) Add(nonce uint64) { // This function is safe for concurrent access. func (m *mruNonceMap) Delete(nonce uint64) { m.mtx.Lock() + defer m.mtx.Unlock() if node, exists := m.nonceMap[nonce]; exists { m.nonceList.Remove(node) delete(m.nonceMap, nonce) } - m.mtx.Unlock() } // newMruNonceMap returns a new nonce map that is limited to the number of diff --git a/peer/peer.go b/peer/peer.go index 1837c3da0..6617ee599 100644 --- a/peer/peer.go +++ b/peer/peer.go @@ -456,14 +456,16 @@ func (p *Peer) AddKnownInventory(invVect *wire.InvVect) { // This function is safe for concurrent access. func (p *Peer) StatsSnapshot() *StatsSnap { p.statsMtx.RLock() + defer p.statsMtx.RUnlock() p.flagsMtx.Lock() + defer p.flagsMtx.Unlock() + id := p.id addr := p.addr userAgent := p.userAgent services := p.services protocolVersion := p.advertisedProtoVer - p.flagsMtx.Unlock() // Get a copy of all relevant flags and stats. statsSnap := &StatsSnap{ @@ -485,7 +487,6 @@ func (p *Peer) StatsSnapshot() *StatsSnap { LastPingTime: p.lastPingTime, } - p.statsMtx.RUnlock() return statsSnap } @@ -494,10 +495,8 @@ func (p *Peer) StatsSnapshot() *StatsSnap { // This function is safe for concurrent access. func (p *Peer) ID() int32 { p.flagsMtx.Lock() - id := p.id - p.flagsMtx.Unlock() - - return id + defer p.flagsMtx.Unlock() + return p.id } // NA returns the peer network address. @@ -505,10 +504,8 @@ func (p *Peer) ID() int32 { // This function is safe for concurrent access. func (p *Peer) NA() *wire.NetAddress { p.flagsMtx.Lock() - na := p.na - p.flagsMtx.Unlock() - - return na + defer p.flagsMtx.Unlock() + return p.na } // Addr returns the peer address. @@ -532,10 +529,8 @@ func (p *Peer) Inbound() bool { // This function is safe for concurrent access. 
func (p *Peer) Services() wire.ServiceFlag { p.flagsMtx.Lock() - services := p.services - p.flagsMtx.Unlock() - - return services + defer p.flagsMtx.Unlock() + return p.services } // UserAgent returns the user agent of the remote peer. @@ -543,19 +538,15 @@ func (p *Peer) Services() wire.ServiceFlag { // This function is safe for concurrent access. func (p *Peer) UserAgent() string { p.flagsMtx.Lock() - userAgent := p.userAgent - p.flagsMtx.Unlock() - - return userAgent + defer p.flagsMtx.Unlock() + return p.userAgent } // SubnetworkID returns peer subnetwork ID func (p *Peer) SubnetworkID() *subnetworkid.SubnetworkID { p.flagsMtx.Lock() - subnetworkID := p.cfg.SubnetworkID - p.flagsMtx.Unlock() - - return subnetworkID + defer p.flagsMtx.Unlock() + return p.cfg.SubnetworkID } // LastPingNonce returns the last ping nonce of the remote peer. @@ -563,10 +554,8 @@ func (p *Peer) SubnetworkID() *subnetworkid.SubnetworkID { // This function is safe for concurrent access. func (p *Peer) LastPingNonce() uint64 { p.statsMtx.RLock() - lastPingNonce := p.lastPingNonce - p.statsMtx.RUnlock() - - return lastPingNonce + defer p.statsMtx.RUnlock() + return p.lastPingNonce } // LastPingTime returns the last ping time of the remote peer. @@ -574,10 +563,8 @@ func (p *Peer) LastPingNonce() uint64 { // This function is safe for concurrent access. func (p *Peer) LastPingTime() time.Time { p.statsMtx.RLock() - lastPingTime := p.lastPingTime - p.statsMtx.RUnlock() - - return lastPingTime + defer p.statsMtx.RUnlock() + return p.lastPingTime } // LastPingMicros returns the last ping micros of the remote peer. @@ -585,10 +572,8 @@ func (p *Peer) LastPingTime() time.Time { // This function is safe for concurrent access. func (p *Peer) LastPingMicros() int64 { p.statsMtx.RLock() - lastPingMicros := p.lastPingMicros - p.statsMtx.RUnlock() - - return lastPingMicros + defer p.statsMtx.RUnlock() + return p.lastPingMicros } // VersionKnown returns the whether or not the version of a peer is known @@ -597,10 +582,8 @@ func (p *Peer) LastPingMicros() int64 { // This function is safe for concurrent access. func (p *Peer) VersionKnown() bool { p.flagsMtx.Lock() - versionKnown := p.versionKnown - p.flagsMtx.Unlock() - - return versionKnown + defer p.flagsMtx.Unlock() + return p.versionKnown } // VerAckReceived returns whether or not a verack message was received by the @@ -609,10 +592,8 @@ func (p *Peer) VersionKnown() bool { // This function is safe for concurrent access. func (p *Peer) VerAckReceived() bool { p.flagsMtx.Lock() - verAckReceived := p.verAckReceived - p.flagsMtx.Unlock() - - return verAckReceived + defer p.flagsMtx.Unlock() + return p.verAckReceived } // ProtocolVersion returns the negotiated peer protocol version. @@ -620,10 +601,8 @@ func (p *Peer) VerAckReceived() bool { // This function is safe for concurrent access. func (p *Peer) ProtocolVersion() uint32 { p.flagsMtx.Lock() - protocolVersion := p.protocolVersion - p.flagsMtx.Unlock() - - return protocolVersion + defer p.flagsMtx.Unlock() + return p.protocolVersion } // SelectedTipHash returns the selected tip of the peer. @@ -631,14 +610,14 @@ func (p *Peer) ProtocolVersion() uint32 { // This function is safe for concurrent access. func (p *Peer) SelectedTipHash() *daghash.Hash { p.statsMtx.RLock() - selectedTipHash := p.selectedTipHash - p.statsMtx.RUnlock() - - return selectedTipHash + defer p.statsMtx.RUnlock() + return p.selectedTipHash } // SetSelectedTipHash sets the selected tip of the peer. 
func (p *Peer) SetSelectedTipHash(selectedTipHash *daghash.Hash) { + p.statsMtx.Lock() + defer p.statsMtx.Unlock() p.selectedTipHash = selectedTipHash } @@ -683,10 +662,8 @@ func (p *Peer) BytesReceived() uint64 { // This function is safe for concurrent access. func (p *Peer) TimeConnected() time.Time { p.statsMtx.RLock() - timeConnected := p.timeConnected - p.statsMtx.RUnlock() - - return timeConnected + defer p.statsMtx.RUnlock() + return p.timeConnected } // TimeOffset returns the number of seconds the local time was offset from the @@ -696,10 +673,8 @@ func (p *Peer) TimeConnected() time.Time { // This function is safe for concurrent access. func (p *Peer) TimeOffset() int64 { p.statsMtx.RLock() - timeOffset := p.timeOffset - p.statsMtx.RUnlock() - - return timeOffset + defer p.statsMtx.RUnlock() + return p.timeOffset } // localVersionMsg creates a version message that can be used to send to the @@ -804,19 +779,21 @@ func (p *Peer) PushGetBlockLocatorMsg(highHash, lowHash *daghash.Hash) { p.QueueMessage(msg, nil) } +func (p *Peer) isDuplicateGetBlockInvsMsg(lowHash, highHash *daghash.Hash) bool { + p.prevGetBlockInvsMtx.Lock() + defer p.prevGetBlockInvsMtx.Unlock() + return p.prevGetBlockInvsHigh != nil && p.prevGetBlockInvsLow != nil && + lowHash != nil && highHash.IsEqual(p.prevGetBlockInvsHigh) && + lowHash.IsEqual(p.prevGetBlockInvsLow) +} + // PushGetBlockInvsMsg sends a getblockinvs message for the provided block locator // and high hash. It will ignore back-to-back duplicate requests. // // This function is safe for concurrent access. func (p *Peer) PushGetBlockInvsMsg(lowHash, highHash *daghash.Hash) error { // Filter duplicate getblockinvs requests. - p.prevGetBlockInvsMtx.Lock() - isDuplicate := p.prevGetBlockInvsHigh != nil && p.prevGetBlockInvsLow != nil && - lowHash != nil && highHash.IsEqual(p.prevGetBlockInvsHigh) && - lowHash.IsEqual(p.prevGetBlockInvsLow) - p.prevGetBlockInvsMtx.Unlock() - - if isDuplicate { + if p.isDuplicateGetBlockInvsMsg(lowHash, highHash) { log.Tracef("Filtering duplicate [getblockinvs] with low "+ "hash %s, high hash %s", lowHash, highHash) return nil @@ -829,9 +806,9 @@ func (p *Peer) PushGetBlockInvsMsg(lowHash, highHash *daghash.Hash) error { // Update the previous getblockinvs request information for filtering // duplicates. p.prevGetBlockInvsMtx.Lock() + defer p.prevGetBlockInvsMtx.Unlock() p.prevGetBlockInvsLow = lowHash p.prevGetBlockInvsHigh = highHash - p.prevGetBlockInvsMtx.Unlock() return nil } @@ -913,15 +890,26 @@ func (p *Peer) handleRemoteVersionMsg(msg *wire.MsgVersion) error { return errors.New("incompatible subnetworks") } - // Updating a bunch of stats including block based stats, and the - // peer's time offset. + p.updateStatsFromVersionMsg(msg) + p.updateFlagsFromVersionMsg(msg) + + return nil +} + +// updateStatsFromVersionMsg updates a bunch of stats including block based stats, and the +// peer's time offset. +func (p *Peer) updateStatsFromVersionMsg(msg *wire.MsgVersion) { p.statsMtx.Lock() + defer p.statsMtx.Unlock() p.selectedTipHash = msg.SelectedTipHash p.timeOffset = msg.Timestamp.Unix() - time.Now().Unix() - p.statsMtx.Unlock() +} +func (p *Peer) updateFlagsFromVersionMsg(msg *wire.MsgVersion) { // Negotiate the protocol version. 
p.flagsMtx.Lock() + defer p.flagsMtx.Unlock() + p.advertisedProtoVer = uint32(msg.ProtocolVersion) p.protocolVersion = minUint32(p.protocolVersion, p.advertisedProtoVer) p.versionKnown = true @@ -937,10 +925,6 @@ func (p *Peer) handleRemoteVersionMsg(msg *wire.MsgVersion) error { // Set the remote peer's user agent. p.userAgent = msg.UserAgent - - p.flagsMtx.Unlock() - - return nil } // handlePingMsg is invoked when a peer receives a ping kaspa message. For @@ -963,12 +947,12 @@ func (p *Peer) handlePongMsg(msg *wire.MsgPong) { // without large usage of the ping rpc call since we ping infrequently // enough that if they overlap we would have timed out the peer. p.statsMtx.Lock() + defer p.statsMtx.Unlock() if p.lastPingNonce != 0 && msg.Nonce == p.lastPingNonce { p.lastPingMicros = time.Since(p.lastPingTime).Nanoseconds() p.lastPingMicros /= 1000 // convert to usec. p.lastPingNonce = 0 } - p.statsMtx.Unlock() } // readMessage reads the next kaspa message from the peer with logging. @@ -1346,9 +1330,7 @@ out: "disconnecting", p) break out } - p.flagsMtx.Lock() - p.verAckReceived = true - p.flagsMtx.Unlock() + p.markVerAckReceived() if p.cfg.Listeners.OnVerAck != nil { p.cfg.Listeners.OnVerAck(p, msg) } @@ -1475,6 +1457,12 @@ out: log.Tracef("Peer input handler done for %s", p) } +func (p *Peer) markVerAckReceived() { + p.flagsMtx.Lock() + defer p.flagsMtx.Unlock() + p.verAckReceived = true +} + // queueHandler handles the queuing of outgoing data for the peer. This runs as // a muxer for various sources of input so we can ensure that server and peer // handlers will not block on us sending a message. That data is then passed on @@ -1622,10 +1610,12 @@ out: case msg := <-p.sendQueue: switch m := msg.msg.(type) { case *wire.MsgPing: - p.statsMtx.Lock() - p.lastPingNonce = m.Nonce - p.lastPingTime = time.Now() - p.statsMtx.Unlock() + func() { + p.statsMtx.Lock() + defer p.statsMtx.Unlock() + p.lastPingNonce = m.Nonce + p.lastPingTime = time.Now() + }() } p.stallControl <- stallControlMsg{sccSendMessage, msg.msg} diff --git a/rpcclient/infrastructure.go b/rpcclient/infrastructure.go index 7d1abffd6..938498bcc 100644 --- a/rpcclient/infrastructure.go +++ b/rpcclient/infrastructure.go @@ -442,9 +442,8 @@ out: // is being reassigned during a reconnect. func (c *Client) disconnectChan() <-chan struct{} { c.mtx.Lock() - ch := c.disconnect - c.mtx.Unlock() - return ch + defer c.mtx.Unlock() + return c.disconnect } // wsOutHandler handles all outgoing messages for the websocket connection. It @@ -511,9 +510,11 @@ func (c *Client) reregisterNtfns() error { // the notification state (while not under the lock of course) which // also register it with the remote RPC server, so this prevents double // registrations. - c.ntfnStateLock.Lock() - stateCopy := c.ntfnState.Copy() - c.ntfnStateLock.Unlock() + stateCopy := func() *notificationState { + c.ntfnStateLock.Lock() + defer c.ntfnStateLock.Unlock() + return c.ntfnState.Copy() + }() // Reregister notifyblocks if needed. if stateCopy.notifyBlocks { @@ -550,23 +551,9 @@ var ignoreResends = map[string]struct{}{ "rescan": {}, } -// resendRequests resends any requests that had not completed when the client -// disconnected. It is intended to be called once the client has reconnected as -// a separate goroutine. -func (c *Client) resendRequests() { - // Set the notification state back up. If anything goes wrong, - // disconnect the client. 
- if err := c.reregisterNtfns(); err != nil { - log.Warnf("Unable to re-establish notification state: %s", err) - c.Disconnect() - return - } - - // Since it's possible to block on send and more requests might be - // added by the caller while resending, make a copy of all of the - // requests that need to be resent now and work from the copy. This - // also allows the lock to be released quickly. +func (c *Client) collectResendRequests() []*jsonRequest { c.requestLock.Lock() + defer c.requestLock.Unlock() resendReqs := make([]*jsonRequest, 0, c.requestList.Len()) var nextElem *list.Element for e := c.requestList.Front(); e != nil; e = nextElem { @@ -583,7 +570,26 @@ func (c *Client) resendRequests() { resendReqs = append(resendReqs, jReq) } } - c.requestLock.Unlock() + return resendReqs +} + +// resendRequests resends any requests that had not completed when the client +// disconnected. It is intended to be called once the client has reconnected as +// a separate goroutine. +func (c *Client) resendRequests() { + // Set the notification state back up. If anything goes wrong, + // disconnect the client. + if err := c.reregisterNtfns(); err != nil { + log.Warnf("Unable to re-establish notification state: %s", err) + c.Disconnect() + return + } + + // Since it's possible to block on send and more requests might be + // added by the caller while resending, make a copy of all of the + // requests that need to be resent now and work from the copy. This + // also allows the lock to be released quickly. + resendReqs := c.collectResendRequests() for _, jReq := range resendReqs { // Stop resending commands if the client disconnected again @@ -654,10 +660,12 @@ out: c.wsConn = wsConn c.retryCount = 0 - c.mtx.Lock() - c.disconnect = make(chan struct{}) - c.disconnected = false - c.mtx.Unlock() + func() { + c.mtx.Lock() + defer c.mtx.Unlock() + c.disconnect = make(chan struct{}) + c.disconnected = false + }() // Start processing input and output for the // new connection. @@ -689,11 +697,14 @@ func (c *Client) handleSendPostMessage(details *sendPostDetails) { } // Read the raw bytes and close the response. - respBytes, err := ioutil.ReadAll(httpResponse.Body) - httpResponse.Body.Close() + respBytes, err := func() ([]byte, error) { + defer httpResponse.Body.Close() + return ioutil.ReadAll(httpResponse.Body) + }() if err != nil { - err = errors.Errorf("error reading json reply: %s", err) - jReq.responseChan <- &response{err: err} + jReq.responseChan <- &response{ + err: errors.Wrap(err, "error reading json reply"), + } return } diff --git a/rpcmodel/command_info.go b/rpcmodel/command_info.go index 59a4f1b04..07d1ed24a 100644 --- a/rpcmodel/command_info.go +++ b/rpcmodel/command_info.go @@ -10,15 +10,38 @@ import ( "strings" ) +func concreteTypeToMethodWithRLock(rt reflect.Type) (string, bool) { + registerLock.RLock() + defer registerLock.RUnlock() + method, ok := concreteTypeToMethod[rt] + return method, ok +} + +func methodToInfoWithRLock(method string) (methodInfo, bool) { + registerLock.RLock() + defer registerLock.RUnlock() + info, ok := methodToInfo[method] + return info, ok +} + +func methodConcreteTypeAndInfoWithRLock(method string) (reflect.Type, methodInfo, bool) { + registerLock.RLock() + defer registerLock.RUnlock() + rtp, ok := methodToConcreteType[method] + if !ok { + return nil, methodInfo{}, false + } + info := methodToInfo[method] + return rtp, info, ok +} + // CommandMethod returns the method for the passed command. The provided command // type must be a registered type. 
All commands provided by this package are registered by default. func CommandMethod(cmd interface{}) (string, error) { // Look up the cmd type and error out if not registered. rt := reflect.TypeOf(cmd) - registerLock.RLock() - method, ok := concreteTypeToMethod[rt] - registerLock.RUnlock() + method, ok := concreteTypeToMethodWithRLock(rt) if !ok { str := fmt.Sprintf("%q is not registered", method) return "", makeError(ErrUnregisteredMethod, str) } @@ -33,9 +56,7 @@ func CommandMethod(cmd interface{}) (string, error) { func MethodUsageFlags(method string) (UsageFlag, error) { // Look up details about the provided method and error out if not // registered. - registerLock.RLock() - info, ok := methodToInfo[method] - registerLock.RUnlock() + info, ok := methodToInfoWithRLock(method) if !ok { str := fmt.Sprintf("%q is not registered", method) return 0, makeError(ErrUnregisteredMethod, str) } @@ -225,9 +246,9 @@ func MethodUsageText(method string) (string, error) { // Look up details about the provided method and error out if not // registered. - registerLock.RLock() + registerLock.Lock() + defer registerLock.Unlock() rtp, ok := methodToConcreteType[method] info := methodToInfo[method] - registerLock.RUnlock() if !ok { str := fmt.Sprintf("%q is not registered", method) return "", makeError(ErrUnregisteredMethod, str) } @@ -241,9 +262,7 @@ func MethodUsageText(method string) (string, error) { // Generate and store the usage string for future calls and return it. usage := methodUsageText(rtp, info.defaults, method) - registerLock.Lock() info.usage = usage methodToInfo[method] = info - registerLock.Unlock() return usage, nil } diff --git a/rpcmodel/command_parse.go b/rpcmodel/command_parse.go index 9fa04d04e..3c143f7d2 100644 --- a/rpcmodel/command_parse.go +++ b/rpcmodel/command_parse.go @@ -39,9 +39,7 @@ func makeParams(rt reflect.Type, rv reflect.Value) []interface{} { func MarshalCommand(id interface{}, cmd interface{}) ([]byte, error) { // Look up the cmd type and error out if not registered. rt := reflect.TypeOf(cmd) - registerLock.RLock() - method, ok := concreteTypeToMethod[rt] - registerLock.RUnlock() + method, ok := concreteTypeToMethodWithRLock(rt) if !ok { str := fmt.Sprintf("%q is not registered", method) return nil, makeError(ErrUnregisteredMethod, str) } @@ -109,10 +107,7 @@ func populateDefaults(numParams int, info *methodInfo, rv reflect.Value) { // so long as the method type contained within the marshalled request is // registered. func UnmarshalCommand(r *Request) (interface{}, error) { - registerLock.RLock() - rtp, ok := methodToConcreteType[r.Method] - info := methodToInfo[r.Method] - registerLock.RUnlock() + rtp, info, ok := methodConcreteTypeAndInfoWithRLock(r.Method) if !ok { str := fmt.Sprintf("%q is not registered", r.Method) return nil, makeError(ErrUnregisteredMethod, str) } @@ -513,10 +508,7 @@ func assignField(paramNum int, fieldName string, dest reflect.Value, src reflect func NewCommand(method string, args ...interface{}) (interface{}, error) { // Look up details about the provided method. Any methods that aren't // registered are an error. 
- registerLock.RLock() - rtp, ok := methodToConcreteType[method] - info := methodToInfo[method] - registerLock.RUnlock() + rtp, info, ok := methodConcreteTypeAndInfoWithRLock(method) if !ok { str := fmt.Sprintf("%q is not registered", method) return nil, makeError(ErrUnregisteredMethod, str) diff --git a/rpcmodel/help.go b/rpcmodel/help.go index c54ad33f5..2914d9e4e 100644 --- a/rpcmodel/help.go +++ b/rpcmodel/help.go @@ -507,10 +507,7 @@ func isValidResultType(kind reflect.Kind) bool { func GenerateHelp(method string, descs map[string]string, resultTypes ...interface{}) (string, error) { // Look up details about the provided method and error out if not // registered. - registerLock.RLock() - rtp, ok := methodToConcreteType[method] - info := methodToInfo[method] - registerLock.RUnlock() + rtp, info, ok := methodConcreteTypeAndInfoWithRLock(method) if !ok { str := fmt.Sprintf("%q is not registered", method) return "", makeError(ErrUnregisteredMethod, str) diff --git a/rpcmodel/rpc_results.go b/rpcmodel/rpc_results.go index b645d31c6..82e05e4ab 100644 --- a/rpcmodel/rpc_results.go +++ b/rpcmodel/rpc_results.go @@ -155,7 +155,6 @@ type GetBlockTemplateResult struct { // Optional long polling from BIP 0022. LongPollID string `json:"longPollId,omitempty"` LongPollURI string `json:"longPollUri,omitempty"` - SubmitOld *bool `json:"submitOld,omitempty"` // Basic pool extension from BIP 0023. Target string `json:"target,omitempty"` diff --git a/server/p2p/p2p.go b/server/p2p/p2p.go index 5b54ab936..a14e5588d 100644 --- a/server/p2p/p2p.go +++ b/server/p2p/p2p.go @@ -303,8 +303,8 @@ func (sp *Peer) addressKnown(na *wire.NetAddress) bool { // It is safe for concurrent access. func (sp *Peer) setDisableRelayTx(disable bool) { sp.relayMtx.Lock() + defer sp.relayMtx.Unlock() sp.DisableRelayTx = disable - sp.relayMtx.Unlock() } // relayTxDisabled returns whether or not relaying of transactions for the given @@ -312,10 +312,8 @@ func (sp *Peer) setDisableRelayTx(disable bool) { // It is safe for concurrent access. func (sp *Peer) relayTxDisabled() bool { sp.relayMtx.Lock() - isDisabled := sp.DisableRelayTx - sp.relayMtx.Unlock() - - return isDisabled + defer sp.relayMtx.Unlock() + return sp.DisableRelayTx } // pushAddrMsg sends an addr message to the connected peer using the provided diff --git a/server/rpc/handle_get_block_template.go b/server/rpc/handle_get_block_template.go index c9e96a7fb..c1be0ec08 100644 --- a/server/rpc/handle_get_block_template.go +++ b/server/rpc/handle_get_block_template.go @@ -204,7 +204,7 @@ func handleGetBlockTemplateRequest(s *Server, request *rpcmodel.TemplateRequest, if err := state.updateBlockTemplate(s, useCoinbaseValue); err != nil { return nil, err } - return state.blockTemplateResult(s.cfg.DAG, useCoinbaseValue, nil) + return state.blockTemplateResult(s.cfg.DAG, useCoinbaseValue) } // handleGetBlockTemplateLongPoll is a helper for handleGetBlockTemplateRequest @@ -217,66 +217,23 @@ func handleGetBlockTemplateRequest(s *Server, request *rpcmodel.TemplateRequest, // has passed without finding a solution. func handleGetBlockTemplateLongPoll(s *Server, longPollID string, useCoinbaseValue bool, closeChan <-chan struct{}) (interface{}, error) { state := s.gbtWorkState - state.Lock() - // The state unlock is intentionally not deferred here since it needs to - // be manually unlocked before waiting for a notification about block - // template changes. 
- if err := state.updateBlockTemplate(s, useCoinbaseValue); err != nil { - state.Unlock() + result, longPollChan, err := blockTemplateOrLongPollChan(s, longPollID, useCoinbaseValue) + if err != nil { return nil, err } - // Just return the current block template if the long poll ID provided by - // the caller is invalid. - parentHashes, lastGenerated, err := decodeLongPollID(longPollID) - if err != nil { - result, err := state.blockTemplateResult(s.cfg.DAG, useCoinbaseValue, nil) - if err != nil { - state.Unlock() - return nil, err - } - - state.Unlock() + if result != nil { return result, nil } - // Return the block template now if the specific block template - // identified by the long poll ID no longer matches the current block - // template as this means the provided template is stale. - areHashesEqual := daghash.AreEqual(state.template.Block.Header.ParentHashes, parentHashes) - if !areHashesEqual || - lastGenerated != state.lastGenerated.Unix() { - - // Include whether or not it is valid to submit work against the - // old block template depending on whether or not a solution has - // already been found and added to the block DAG. - submitOld := areHashesEqual - result, err := state.blockTemplateResult(s.cfg.DAG, useCoinbaseValue, - &submitOld) - if err != nil { - state.Unlock() - return nil, err - } - - state.Unlock() - return result, nil - } - - // Register the parent hashes and last generated time for notifications - // Get a channel that will be notified when the template associated with - // the provided ID is stale and a new block template should be returned to - // the caller. - longPollChan := state.templateUpdateChan(parentHashes, lastGenerated) - state.Unlock() - select { // When the client closes before it's time to send a reply, just return // now so the goroutine doesn't hang around. case <-closeChan: return nil, ErrClientQuit - // Wait until signal received to send the reply. + // Wait until signal received to send the reply. case <-longPollChan: // Fallthrough } @@ -292,8 +249,7 @@ func handleGetBlockTemplateLongPoll(s *Server, longPollID string, useCoinbaseVal // Include whether or not it is valid to submit work against the old // block template depending on whether or not a solution has already // been found and added to the block DAG. - submitOld := areHashesEqual - result, err := state.blockTemplateResult(s.cfg.DAG, useCoinbaseValue, &submitOld) + result, err = state.blockTemplateResult(s.cfg.DAG, useCoinbaseValue) if err != nil { return nil, err } @@ -301,6 +257,61 @@ func handleGetBlockTemplateLongPoll(s *Server, longPollID string, useCoinbaseVal return result, nil } +// blockTemplateOrLongPollChan returns a block template if the +// template identified by the provided long poll ID is stale or +// invalid. Otherwise, it returns a channel that will notify +// when there's a more current template. +func blockTemplateOrLongPollChan(s *Server, longPollID string, useCoinbaseValue bool) (*rpcmodel.GetBlockTemplateResult, chan struct{}, error) { + state := s.gbtWorkState + + state.Lock() + defer state.Unlock() + // The deferred unlock runs when this helper returns, which happens + // before the caller starts waiting for a notification about block + // template changes. + + if err := state.updateBlockTemplate(s, useCoinbaseValue); err != nil { + return nil, nil, err + } + + // Just return the current block template if the long poll ID provided by + // the caller is invalid. 
+ parentHashes, lastGenerated, err := decodeLongPollID(longPollID) + if err != nil { + result, err := state.blockTemplateResult(s.cfg.DAG, useCoinbaseValue) + if err != nil { + return nil, nil, err + } + + return result, nil, nil + } + + // Return the block template now if the specific block template + // identified by the long poll ID no longer matches the current block + // template as this means the provided template is stale. + areHashesEqual := daghash.AreEqual(state.template.Block.Header.ParentHashes, parentHashes) + if !areHashesEqual || + lastGenerated != state.lastGenerated.Unix() { + + // The template identified by the long poll ID is stale, so build + // and return the current block template right away instead of + // waiting for the next template update notification. + result, err := state.blockTemplateResult(s.cfg.DAG, useCoinbaseValue) + if err != nil { + return nil, nil, err + } + + return result, nil, nil + } + + // Register the parent hashes and last generated time for notifications + // Get a channel that will be notified when the template associated with + // the provided ID is stale and a new block template should be returned to + // the caller. + longPollChan := state.templateUpdateChan(parentHashes, lastGenerated) + return nil, longPollChan, nil +} + // handleGetBlockTemplateProposal is a helper for handleGetBlockTemplate which // deals with block proposals. func handleGetBlockTemplateProposal(s *Server, request *rpcmodel.TemplateRequest) (interface{}, error) { @@ -693,7 +704,7 @@ func (state *gbtWorkState) updateBlockTemplate(s *Server, useCoinbaseValue bool) // and returned to the caller. // // This function MUST be called with the state locked. -func (state *gbtWorkState) blockTemplateResult(dag *blockdag.BlockDAG, useCoinbaseValue bool, submitOld *bool) (*rpcmodel.GetBlockTemplateResult, error) { +func (state *gbtWorkState) blockTemplateResult(dag *blockdag.BlockDAG, useCoinbaseValue bool) (*rpcmodel.GetBlockTemplateResult, error) { // Ensure the timestamps are still in valid range for the template. 
// This should really only ever happen if the local clock is changed // after the template is generated, but it's important to avoid serving @@ -779,7 +790,6 @@ func (state *gbtWorkState) blockTemplateResult(dag *blockdag.BlockDAG, useCoinba UTXOCommitment: header.UTXOCommitment.String(), Version: header.Version, LongPollID: longPollID, - SubmitOld: submitOld, Target: targetDifficulty, MinTime: state.minTimestamp.Unix(), MaxTime: maxTime.Unix(), diff --git a/server/rpc/handle_load_tx_filter.go b/server/rpc/handle_load_tx_filter.go index 6a09acbd2..fabe9f926 100644 --- a/server/rpc/handle_load_tx_filter.go +++ b/server/rpc/handle_load_tx_filter.go @@ -30,22 +30,28 @@ func handleLoadTxFilter(wsc *wsClient, icmd interface{}) (interface{}, error) { params := wsc.server.cfg.DAGParams - wsc.Lock() - if cmd.Reload || wsc.filterData == nil { - wsc.filterData = newWSClientFilter(cmd.Addresses, outpoints, - params) - wsc.Unlock() - } else { - wsc.Unlock() + reloadedFilterData := func() bool { + wsc.Lock() + defer wsc.Unlock() + if cmd.Reload || wsc.filterData == nil { + wsc.filterData = newWSClientFilter(cmd.Addresses, outpoints, + params) + return true + } + return false + }() - wsc.filterData.mu.Lock() - for _, a := range cmd.Addresses { - wsc.filterData.addAddressStr(a, params) - } - for i := range outpoints { - wsc.filterData.addUnspentOutpoint(&outpoints[i]) - } - wsc.filterData.mu.Unlock() + if !reloadedFilterData { + func() { + wsc.filterData.mu.Lock() + defer wsc.filterData.mu.Unlock() + for _, a := range cmd.Addresses { + wsc.filterData.addAddressStr(a, params) + } + for i := range outpoints { + wsc.filterData.addUnspentOutpoint(&outpoints[i]) + } + }() } return nil, nil diff --git a/server/rpc/handle_rescan_block_filter.go b/server/rpc/handle_rescan_block_filter.go index ef4631f22..1a86b46bc 100644 --- a/server/rpc/handle_rescan_block_filter.go +++ b/server/rpc/handle_rescan_block_filter.go @@ -16,6 +16,7 @@ func rescanBlockFilter(filter *wsClientFilter, block *util.Block, params *dagcon var transactions []string filter.mu.Lock() + defer filter.mu.Unlock() for _, tx := range block.Transactions() { msgTx := tx.MsgTx() @@ -26,7 +27,7 @@ func rescanBlockFilter(filter *wsClientFilter, block *util.Block, params *dagcon // Scan inputs if not a coinbase transaction. if !msgTx.IsCoinBase() { for _, input := range msgTx.TxIn { - if !filter.existsUnspentOutpoint(&input.PreviousOutpoint) { + if !filter.existsUnspentOutpointNoLock(&input.PreviousOutpoint) { continue } if !added { @@ -65,7 +66,6 @@ func rescanBlockFilter(filter *wsClientFilter, block *util.Block, params *dagcon } } } - filter.mu.Unlock() return transactions } diff --git a/server/rpc/handle_rescan_blocks.go b/server/rpc/handle_rescan_blocks.go index a9da7b8d9..bb816d394 100644 --- a/server/rpc/handle_rescan_blocks.go +++ b/server/rpc/handle_rescan_blocks.go @@ -17,9 +17,7 @@ func handleRescanBlocks(wsc *wsClient, icmd interface{}) (interface{}, error) { } // Load client's transaction filter. Must exist in order to continue. 
- wsc.Lock() - filter := wsc.filterData - wsc.Unlock() + filter := wsc.FilterData() if filter == nil { return nil, &rpcmodel.RPCError{ Code: rpcmodel.ErrRPCMisc, diff --git a/server/rpc/rpcserver.go b/server/rpc/rpcserver.go index 4c30922ff..9a7f93c51 100644 --- a/server/rpc/rpcserver.go +++ b/server/rpc/rpcserver.go @@ -185,9 +185,12 @@ func (s *Server) httpStatusLine(req *http.Request, code int) string { if !proto11 { key = -key } - s.statusLock.RLock() - line, ok := s.statusLines[key] - s.statusLock.RUnlock() + line, ok := func() (string, bool) { + s.statusLock.RLock() + defer s.statusLock.RUnlock() + line, ok := s.statusLines[key] + return line, ok + }() if ok { return line } @@ -202,8 +205,8 @@ func (s *Server) httpStatusLine(req *http.Request, code int) string { if text != "" { line = proto + " " + codeStr + " " + text + "\r\n" s.statusLock.Lock() + defer s.statusLock.Unlock() s.statusLines[key] = line - s.statusLock.Unlock() } else { text = "status code " + codeStr line = proto + " " + codeStr + " " + text + "\r\n" diff --git a/server/rpc/rpcwebsocket.go b/server/rpc/rpcwebsocket.go index 4bac0a243..0ff003b38 100644 --- a/server/rpc/rpcwebsocket.go +++ b/server/rpc/rpcwebsocket.go @@ -352,15 +352,21 @@ func (f *wsClientFilter) addUnspentOutpoint(op *wire.Outpoint) { f.unspent[*op] = struct{}{} } -// existsUnspentOutpoint returns true if the passed outpoint has been added to +// existsUnspentOutpointNoLock returns true if the passed outpoint has been added to // the wsClientFilter. // // NOTE: This extension was ported from github.com/decred/dcrd -func (f *wsClientFilter) existsUnspentOutpoint(op *wire.Outpoint) bool { +func (f *wsClientFilter) existsUnspentOutpointNoLock(op *wire.Outpoint) bool { _, ok := f.unspent[*op] return ok } +func (f *wsClientFilter) existsUnspentOutpoint(op *wire.Outpoint) bool { + f.mu.Lock() + defer f.mu.Unlock() + return f.existsUnspentOutpointNoLock(op) +} + // Notification types type notificationBlockAdded util.Block type notificationChainChanged struct { @@ -568,17 +574,13 @@ func (m *wsNotificationManager) subscribedClients(tx *util.Tx, msgTx := tx.MsgTx() for _, input := range msgTx.TxIn { for quitChan, wsc := range clients { - wsc.Lock() - filter := wsc.filterData - wsc.Unlock() + filter := wsc.FilterData() if filter == nil { continue } - filter.mu.Lock() if filter.existsUnspentOutpoint(&input.PreviousOutpoint) { subscribed[quitChan] = struct{}{} } - filter.mu.Unlock() } } @@ -591,22 +593,22 @@ func (m *wsNotificationManager) subscribedClients(tx *util.Tx, continue } for quitChan, wsc := range clients { - wsc.Lock() - filter := wsc.filterData - wsc.Unlock() + filter := wsc.FilterData() if filter == nil { continue } - filter.mu.Lock() - if filter.existsAddress(addr) { - subscribed[quitChan] = struct{}{} - op := wire.Outpoint{ - TxID: *tx.ID(), - Index: uint32(i), + func() { + filter.mu.Lock() + defer filter.mu.Unlock() + if filter.existsAddress(addr) { + subscribed[quitChan] = struct{}{} + op := wire.Outpoint{ + TxID: *tx.ID(), + Index: uint32(i), + } + filter.addUnspentOutpoint(&op) } - filter.addUnspentOutpoint(&op) - } - filter.mu.Unlock() + }() } } @@ -1250,10 +1252,8 @@ func (c *wsClient) QueueNotification(marshalledJSON []byte) error { // Disconnected returns whether or not the websocket client is disconnected. func (c *wsClient) Disconnected() bool { c.Lock() - isDisconnected := c.disconnected - c.Unlock() - - return isDisconnected + defer c.Unlock() + return c.disconnected } // Disconnect disconnects the websocket client. 
@@ -1289,6 +1289,13 @@ func (c *wsClient) WaitForShutdown() { c.wg.Wait() } +// FilterData returns the websocket client filter data. +func (c *wsClient) FilterData() *wsClientFilter { + c.Lock() + defer c.Unlock() + return c.filterData +} + // newWebsocketClient returns a new websocket client given the notification // manager, websocket connection, remote address, and whether or not the client // has already been authenticated (via HTTP Basic access authentication). The diff --git a/txscript/sigcache.go b/txscript/sigcache.go index c3a7f0083..88e9dedbc 100644 --- a/txscript/sigcache.go +++ b/txscript/sigcache.go @@ -57,8 +57,8 @@ func NewSigCache(maxEntries uint) *SigCache { // unless there exists a writer, adding an entry to the SigCache. func (s *SigCache) Exists(sigHash daghash.Hash, sig *ecc.Signature, pubKey *ecc.PublicKey) bool { s.RLock() + defer s.RUnlock() entry, ok := s.validSigs[sigHash] - s.RUnlock() return ok && entry.pubKey.IsEqual(pubKey) && entry.sig.IsEqual(sig) } diff --git a/util/bloom/filter.go b/util/bloom/filter.go index 39c43c96d..ad2c5d157 100644 --- a/util/bloom/filter.go +++ b/util/bloom/filter.go @@ -89,8 +89,8 @@ func LoadFilter(filter *wire.MsgFilterLoad) *Filter { // This function is safe for concurrent access. func (bf *Filter) IsLoaded() bool { bf.mtx.Lock() + defer bf.mtx.Unlock() loaded := bf.msgFilterLoad != nil - bf.mtx.Unlock() return loaded } @@ -99,8 +99,8 @@ func (bf *Filter) IsLoaded() bool { // This function is safe for concurrent access. func (bf *Filter) Reload(filter *wire.MsgFilterLoad) { bf.mtx.Lock() + defer bf.mtx.Unlock() bf.msgFilterLoad = filter - bf.mtx.Unlock() } // Unload unloads the bloom filter. @@ -108,8 +108,8 @@ func (bf *Filter) Reload(filter *wire.MsgFilterLoad) { // This function is safe for concurrent access. func (bf *Filter) Unload() { bf.mtx.Lock() + defer bf.mtx.Unlock() bf.msgFilterLoad = nil - bf.mtx.Unlock() } // hash returns the bit offset in the bloom filter which corresponds to the @@ -156,9 +156,8 @@ func (bf *Filter) matches(data []byte) bool { // This function is safe for concurrent access. func (bf *Filter) Matches(data []byte) bool { bf.mtx.Lock() - match := bf.matches(data) - bf.mtx.Unlock() - return match + defer bf.mtx.Unlock() + return bf.matches(data) } // matchesOutpoint returns true if the bloom filter might contain the passed @@ -180,9 +179,8 @@ func (bf *Filter) matchesOutpoint(outpoint *wire.Outpoint) bool { // This function is safe for concurrent access. func (bf *Filter) MatchesOutpoint(outpoint *wire.Outpoint) bool { bf.mtx.Lock() - match := bf.matchesOutpoint(outpoint) - bf.mtx.Unlock() - return match + defer bf.mtx.Unlock() + return bf.matchesOutpoint(outpoint) } // add adds the passed byte slice to the bloom filter. @@ -211,8 +209,8 @@ func (bf *Filter) add(data []byte) { // This function is safe for concurrent access. func (bf *Filter) Add(data []byte) { bf.mtx.Lock() + defer bf.mtx.Unlock() bf.add(data) - bf.mtx.Unlock() } // AddHash adds the passed daghash.Hash to the Filter. @@ -220,8 +218,8 @@ func (bf *Filter) Add(data []byte) { // This function is safe for concurrent access. func (bf *Filter) AddHash(hash *daghash.Hash) { bf.mtx.Lock() + defer bf.mtx.Unlock() bf.add(hash[:]) - bf.mtx.Unlock() } // addOutpoint adds the passed transaction outpoint to the bloom filter. @@ -241,8 +239,8 @@ func (bf *Filter) addOutpoint(outpoint *wire.Outpoint) { // This function is safe for concurrent access. 
func (bf *Filter) AddOutpoint(outpoint *wire.Outpoint) { bf.mtx.Lock() + defer bf.mtx.Unlock() bf.addOutpoint(outpoint) - bf.mtx.Unlock() } // maybeAddOutpoint potentially adds the passed outpoint to the bloom filter @@ -330,9 +328,8 @@ func (bf *Filter) matchTxAndUpdate(tx *util.Tx) bool { // This function is safe for concurrent access. func (bf *Filter) MatchTxAndUpdate(tx *util.Tx) bool { bf.mtx.Lock() - match := bf.matchTxAndUpdate(tx) - bf.mtx.Unlock() - return match + defer bf.mtx.Unlock() + return bf.matchTxAndUpdate(tx) } // MsgFilterLoad returns the underlying wire.MsgFilterLoad for the bloom @@ -341,7 +338,6 @@ func (bf *Filter) MatchTxAndUpdate(tx *util.Tx) bool { // This function is safe for concurrent access. func (bf *Filter) MsgFilterLoad() *wire.MsgFilterLoad { bf.mtx.Lock() - msg := bf.msgFilterLoad - bf.mtx.Unlock() - return msg + defer bf.mtx.Unlock() + return bf.msgFilterLoad } diff --git a/util/locks/waitgroup.go b/util/locks/waitgroup.go index cad8a79d8..239d89eb5 100644 --- a/util/locks/waitgroup.go +++ b/util/locks/waitgroup.go @@ -84,11 +84,11 @@ func (wg *waitGroup) done() { // if there is a listener to wg.releaseWait. if atomic.LoadInt64(&wg.counter) == 0 { wg.isReleaseWaitWaitingLock.Lock() + defer wg.isReleaseWaitWaitingLock.Unlock() if atomic.LoadInt64(&wg.isReleaseWaitWaiting) == 1 { wg.releaseWait <- struct{}{} <-wg.releaseDone } - wg.isReleaseWaitWaitingLock.Unlock() } } @@ -96,6 +96,7 @@ func (wg *waitGroup) wait() { wg.mainWaitLock.Lock() defer wg.mainWaitLock.Unlock() wg.isReleaseWaitWaitingLock.Lock() + defer wg.isReleaseWaitWaitingLock.Unlock() for atomic.LoadInt64(&wg.counter) != 0 { atomic.StoreInt64(&wg.isReleaseWaitWaiting, 1) wg.isReleaseWaitWaitingLock.Unlock() @@ -104,5 +105,4 @@ func (wg *waitGroup) wait() { wg.releaseDone <- struct{}{} wg.isReleaseWaitWaitingLock.Lock() } - wg.isReleaseWaitWaitingLock.Unlock() }
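Most of the hunks above apply the same mechanical transformation: a manual `Unlock`/`RUnlock` at the end of a critical section becomes a `defer` placed immediately after the corresponding `Lock`/`RLock`. The stand-alone sketch below illustrates the idiom; the `counterStore` type and its fields are hypothetical and appear nowhere in the patch.

```go
package main

import (
	"fmt"
	"sync"
)

// counterStore is a hypothetical type used only to illustrate the
// lock/defer-unlock accessor pattern this patch applies throughout.
type counterStore struct {
	mtx    sync.RWMutex
	counts map[string]int
}

// Get is safe for concurrent access. The deferred RUnlock runs on every
// return path, so the value can be returned directly instead of being
// copied into a local variable before a manual RUnlock.
func (c *counterStore) Get(key string) int {
	c.mtx.RLock()
	defer c.mtx.RUnlock()
	return c.counts[key]
}

// Set is safe for concurrent access.
func (c *counterStore) Set(key string, value int) {
	c.mtx.Lock()
	defer c.mtx.Unlock()
	c.counts[key] = value
}

func main() {
	store := &counterStore{counts: make(map[string]int)}
	store.Set("blocks", 42)
	fmt.Println(store.Get("blocks")) // 42
}
```

Returning the guarded field directly (as `Peer.ID` or `TxPool.Count` now do) is safe because Go evaluates the return value before running deferred calls, and the deferred unlock also runs if the guarded code panics, which the manual-unlock form did not guarantee.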
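Where a lock or response body must be released before the enclosing function finishes (for example in `sendPostRequest`, `handleSendPostMessage`, `resendRequests`, and the ping case of the peer send queue), the patch instead wraps the guarded region in an immediately invoked function literal so the `defer` fires when that closure returns. A minimal sketch of that idiom, assuming a plain HTTP GET against a local test server rather than the RPC POST used in the real code:

```go
package main

import (
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
)

// fetchBody reads and closes an HTTP response body. The closure scopes the
// deferred Close to the read itself, so the body is closed before any
// further processing of the bytes, mirroring the shape of sendPostRequest.
func fetchBody(url string) ([]byte, error) {
	httpResponse, err := http.Get(url)
	if err != nil {
		return nil, err
	}
	respBytes, err := func() ([]byte, error) {
		defer httpResponse.Body.Close() // runs as soon as the closure returns
		return io.ReadAll(httpResponse.Body)
	}()
	if err != nil {
		return nil, fmt.Errorf("error reading reply: %w", err)
	}
	return respBytes, nil
}

func main() {
	server := httptest.NewServer(http.HandlerFunc(
		func(w http.ResponseWriter, _ *http.Request) {
			fmt.Fprint(w, `{"result":"ok"}`)
		}))
	defer server.Close()

	body, err := fetchBody(server.URL)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(body))
}
```

Scoping the `defer` to the closure keeps the resource (body or mutex) held only for the few statements that need it, while still keeping defer's guarantee that the release happens on every exit path, including errors.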