diff --git a/kaspad.go b/kaspad.go index 7e75522f4..1fd5c8c96 100644 --- a/kaspad.go +++ b/kaspad.go @@ -32,7 +32,7 @@ type kaspad struct { cfg *config.Config rpcServer *rpc.Server addressManager *addrmgr.AddrManager - networkAdapter *netadapter.NetAdapter + protocolManager *protocol.Manager connectionManager *connmanager.ConnectionManager started, shutdown int32 @@ -47,7 +47,7 @@ func (k *kaspad) start() { log.Trace("Starting kaspad") - err := k.networkAdapter.Start() + err := k.protocolManager.Start() if err != nil { panics.Exit(log, fmt.Sprintf("Error starting the p2p protocol: %+v", err)) } @@ -85,7 +85,7 @@ func (k *kaspad) stop() error { k.connectionManager.Stop() - err := k.networkAdapter.Stop() + err := k.protocolManager.Stop() if err != nil { log.Errorf("Error stopping the p2p protocol: %+v", err) } @@ -123,7 +123,10 @@ func newKaspad(cfg *config.Config, interrupt <-chan struct{}) (*kaspad, error) { } addressManager := addrmgr.New(cfg) - protocol.Init(cfg, netAdapter, addressManager, dag) + protocolManager, err := protocol.NewManager(cfg, dag, addressManager, txMempool) + if err != nil { + return nil, err + } connectionManager, err := connmanager.New(cfg, netAdapter, addressManager) if err != nil { @@ -138,7 +141,7 @@ func newKaspad(cfg *config.Config, interrupt <-chan struct{}) (*kaspad, error) { return &kaspad{ cfg: cfg, rpcServer: rpcServer, - networkAdapter: netAdapter, + protocolManager: protocolManager, connectionManager: connectionManager, }, nil } diff --git a/mempool/mempool.go b/mempool/mempool.go index c742aa929..386471e18 100644 --- a/mempool/mempool.go +++ b/mempool/mempool.go @@ -347,11 +347,11 @@ func (mp *TxPool) removeOrphanDoubleSpends(tx *util.Tx) { // exists in the main pool. // // This function MUST be called with the mempool lock held (for reads). -func (mp *TxPool) isTransactionInPool(hash *daghash.TxID) bool { - if _, exists := mp.pool[*hash]; exists { +func (mp *TxPool) isTransactionInPool(txID *daghash.TxID) bool { + if _, exists := mp.pool[*txID]; exists { return true } - return mp.isInDependPool(hash) + return mp.isInDependPool(txID) } // IsTransactionInPool returns whether or not the passed transaction already @@ -394,8 +394,8 @@ func (mp *TxPool) IsInDependPool(hash *daghash.TxID) bool { // in the orphan pool. // // This function MUST be called with the mempool lock held (for reads). -func (mp *TxPool) isOrphanInPool(hash *daghash.TxID) bool { - if _, exists := mp.orphans[*hash]; exists { +func (mp *TxPool) isOrphanInPool(txID *daghash.TxID) bool { + if _, exists := mp.orphans[*txID]; exists { return true } @@ -419,19 +419,19 @@ func (mp *TxPool) IsOrphanInPool(hash *daghash.TxID) bool { // in the main pool or in the orphan pool. // // This function MUST be called with the mempool lock held (for reads). -func (mp *TxPool) haveTransaction(hash *daghash.TxID) bool { - return mp.isTransactionInPool(hash) || mp.isOrphanInPool(hash) +func (mp *TxPool) haveTransaction(txID *daghash.TxID) bool { + return mp.isTransactionInPool(txID) || mp.isOrphanInPool(txID) } // HaveTransaction returns whether or not the passed transaction already exists // in the main pool or in the orphan pool. // // This function is safe for concurrent access. -func (mp *TxPool) HaveTransaction(hash *daghash.TxID) bool { +func (mp *TxPool) HaveTransaction(txID *daghash.TxID) bool { // Protect concurrent access. mp.mtx.RLock() defer mp.mtx.RUnlock() - haveTx := mp.haveTransaction(hash) + haveTx := mp.haveTransaction(txID) return haveTx } @@ -1338,11 +1338,12 @@ func (mp *TxPool) LastUpdated() mstime.Time { return mstime.UnixMilliseconds(atomic.LoadInt64(&mp.lastUpdated)) } -// HandleNewBlock removes all the transactions in the new block +// HandleNewBlockOld removes all the transactions in the new block // from the mempool and the orphan pool, and it also removes // from the mempool transactions that double spend a // transaction that is already in the DAG -func (mp *TxPool) HandleNewBlock(block *util.Block, txChan chan NewBlockMsg) error { +func (mp *TxPool) HandleNewBlockOld(block *util.Block, txChan chan NewBlockMsg) error { + // TODO(libp2p) Remove this function oldUTXOSet := mp.mpUTXOSet // Remove all of the transactions (except the coinbase) in the @@ -1369,6 +1370,43 @@ func (mp *TxPool) HandleNewBlock(block *util.Block, txChan chan NewBlockMsg) err return nil } +// HandleNewBlock removes all the transactions in the new block +// from the mempool and the orphan pool, and it also removes +// from the mempool transactions that double spend a +// transaction that is already in the DAG +func (mp *TxPool) HandleNewBlock(block *util.Block) ([]*util.Tx, error) { + // Protect concurrent access. + mp.cfg.DAG.RLock() + defer mp.cfg.DAG.RUnlock() + mp.mtx.Lock() + defer mp.mtx.Unlock() + + oldUTXOSet := mp.mpUTXOSet + + // Remove all of the transactions (except the coinbase) in the + // connected block from the transaction pool. Secondly, remove any + // transactions which are now double spends as a result of these + // new transactions. Finally, remove any transaction that is + // no longer an orphan. Transactions which depend on a confirmed + // transaction are NOT removed recursively because they are still + // valid. + err := mp.RemoveTransactions(block.Transactions()[util.CoinbaseTransactionIndex+1:]) + if err != nil { + mp.mpUTXOSet = oldUTXOSet + return nil, err + } + acceptedTxs := make([]*util.Tx, 0) + for _, tx := range block.Transactions()[util.CoinbaseTransactionIndex+1:] { + mp.RemoveDoubleSpends(tx) + mp.RemoveOrphan(tx) + acceptedOrphans := mp.ProcessOrphans(tx) + for _, acceptedOrphan := range acceptedOrphans { + acceptedTxs = append(acceptedTxs, acceptedOrphan.Tx) + } + } + return acceptedTxs, nil +} + // New returns a new memory pool for validating and storing standalone // transactions until they are mined into a block. func New(cfg *Config) *TxPool { diff --git a/mempool/mempool_test.go b/mempool/mempool_test.go index 0f3c5ba91..bf84bb414 100644 --- a/mempool/mempool_test.go +++ b/mempool/mempool_test.go @@ -276,16 +276,16 @@ func (tc *testContext) mineTransactions(transactions []*util.Tx, numberOfBlocks // Handle new block by pool ch := make(chan NewBlockMsg) go func() { - err = tc.harness.txPool.HandleNewBlock(utilBlock, ch) + err = tc.harness.txPool.HandleNewBlockOld(utilBlock, ch) close(ch) }() - // process messages pushed by HandleNewBlock + // process messages pushed by HandleNewBlockOld for range ch { } - // ensure that HandleNewBlock has not failed + // ensure that HandleNewBlockOld has not failed if err != nil { - tc.t.Fatalf("HandleNewBlock failed to handle block %s", err) + tc.t.Fatalf("HandleNewBlockOld failed to handle block %s", err) } coinbaseTx := block.Transactions[util.CoinbaseTransactionIndex] @@ -1713,11 +1713,11 @@ func TestHandleNewBlock(t *testing.T) { // Handle new block by pool ch := make(chan NewBlockMsg) go func() { - err = harness.txPool.HandleNewBlock(block, ch) + err = harness.txPool.HandleNewBlockOld(block, ch) close(ch) }() - // process messages pushed by HandleNewBlock + // process messages pushed by HandleNewBlockOld blockTransactions := make(map[daghash.TxID]int) for msg := range ch { blockTransactions[*msg.Tx.ID()] = 1 @@ -1734,12 +1734,12 @@ func TestHandleNewBlock(t *testing.T) { } } } - // ensure that HandleNewBlock has not failed + // ensure that HandleNewBlockOld has not failed if err != nil { - t.Fatalf("HandleNewBlock failed to handle block %v", err) + t.Fatalf("HandleNewBlockOld failed to handle block %v", err) } - // Validate messages pushed by HandleNewBlock into the channel + // Validate messages pushed by HandleNewBlockOld into the channel if len(blockTransactions) != 2 { t.Fatalf("Wrong size of blockTransactions after new block handling") } diff --git a/netsync/manager.go b/netsync/manager.go index 8f3bdbf30..281faa4e2 100644 --- a/netsync/manager.go +++ b/netsync/manager.go @@ -1045,10 +1045,10 @@ func (sm *SyncManager) handleBlockDAGNotification(notification *blockdag.Notific // Update mempool ch := make(chan mempool.NewBlockMsg) spawn("SPAWN_PLACEHOLDER_NAME", func() { - err := sm.txMemPool.HandleNewBlock(block, ch) + err := sm.txMemPool.HandleNewBlockOld(block, ch) close(ch) if err != nil { - panic(fmt.Sprintf("HandleNewBlock failed to handle block %s", block.Hash())) + panic(fmt.Sprintf("HandleNewBlockOld failed to handle block %s", block.Hash())) } }) diff --git a/protocol/blocks.go b/protocol/blocks.go new file mode 100644 index 000000000..f2dfae521 --- /dev/null +++ b/protocol/blocks.go @@ -0,0 +1,43 @@ +package protocol + +import ( + peerpkg "github.com/kaspanet/kaspad/protocol/peer" + "github.com/kaspanet/kaspad/util" + "github.com/kaspanet/kaspad/util/daghash" + "github.com/kaspanet/kaspad/wire" + "sync/atomic" +) + +// OnNewBlock updates the mempool after a new block arrival, and +// relays newly unorphaned transactions and possibly rebroadcast +// manually added transactions when not in IBD. +// TODO(libp2p) Call this function from IBD as well. +func (m *Manager) OnNewBlock(block *util.Block) error { + transactionsAcceptedToMempool, err := m.txPool.HandleNewBlock(block) + if err != nil { + return err + } + // TODO(libp2p) Notify transactionsAcceptedToMempool to RPC + + m.updateTransactionsToRebroadcast(block) + + // Don't relay transactions when in IBD. + if atomic.LoadUint32(&m.isInIBD) != 0 { + return nil + } + + var txIDsToRebroadcast []*daghash.TxID + if m.shouldRebroadcastTransactions() { + txIDsToRebroadcast = m.txIDsToRebroadcast() + } + + txIDsToBroadcast := make([]*daghash.TxID, len(transactionsAcceptedToMempool)+len(txIDsToRebroadcast)) + for i, tx := range transactionsAcceptedToMempool { + txIDsToBroadcast[i] = tx.ID() + } + + copy(txIDsToBroadcast[len(transactionsAcceptedToMempool):], txIDsToBroadcast) + txIDsToBroadcast = txIDsToBroadcast[:wire.MaxInvPerTxInvMsg] + inv := wire.NewMsgTxInv(txIDsToBroadcast) + return m.netAdapter.Broadcast(peerpkg.ReadyPeerIDs(), inv) +} diff --git a/protocol/flows/addressexchange/receiveaddresses.go b/protocol/flows/addressexchange/receiveaddresses.go index 34bf6eade..55b61b39e 100644 --- a/protocol/flows/addressexchange/receiveaddresses.go +++ b/protocol/flows/addressexchange/receiveaddresses.go @@ -1,18 +1,15 @@ package addressexchange import ( - "time" - "github.com/kaspanet/kaspad/addrmgr" "github.com/kaspanet/kaspad/config" "github.com/kaspanet/kaspad/netadapter/router" + "github.com/kaspanet/kaspad/protocol/common" peerpkg "github.com/kaspanet/kaspad/protocol/peer" "github.com/kaspanet/kaspad/protocol/protocolerrors" "github.com/kaspanet/kaspad/wire" ) -const timeout = 30 * time.Second - // ReceiveAddresses asks a peer for more addresses if needed. func ReceiveAddresses(incomingRoute *router.Route, outgoingRoute *router.Route, cfg *config.Config, peer *peerpkg.Peer, addressManager *addrmgr.AddrManager) (routeClosed bool, err error) { @@ -28,7 +25,7 @@ func ReceiveAddresses(incomingRoute *router.Route, outgoingRoute *router.Route, return true, nil } - message, isOpen, err := incomingRoute.DequeueWithTimeout(timeout) + message, isOpen, err := incomingRoute.DequeueWithTimeout(common.DefaultTimeout) if err != nil { return false, err } diff --git a/protocol/flows/blockrelay/handle_relay_invs.go b/protocol/flows/blockrelay/handle_relay_invs.go index e9cc33063..dac32e09a 100644 --- a/protocol/flows/blockrelay/handle_relay_invs.go +++ b/protocol/flows/blockrelay/handle_relay_invs.go @@ -1,12 +1,11 @@ package blockrelay import ( - "time" - "github.com/kaspanet/kaspad/blockdag" "github.com/kaspanet/kaspad/netadapter" "github.com/kaspanet/kaspad/netadapter/router" "github.com/kaspanet/kaspad/protocol/blocklogger" + "github.com/kaspanet/kaspad/protocol/common" "github.com/kaspanet/kaspad/protocol/flows/ibd" peerpkg "github.com/kaspanet/kaspad/protocol/peer" "github.com/kaspanet/kaspad/protocol/protocolerrors" @@ -17,12 +16,14 @@ import ( "github.com/pkg/errors" ) -const timeout = 30 * time.Second +// NewBlockHandler is a function that is to be +// called when a new block is successfully processed. +type NewBlockHandler func(block *util.Block) error // HandleRelayInvs listens to wire.MsgInvRelayBlock messages, requests their corresponding blocks if they // are missing, adds them to the DAG and propagates them to the rest of the network. func HandleRelayInvs(incomingRoute *router.Route, outgoingRoute *router.Route, - peer *peerpkg.Peer, netAdapter *netadapter.NetAdapter, dag *blockdag.BlockDAG) error { + peer *peerpkg.Peer, netAdapter *netadapter.NetAdapter, dag *blockdag.BlockDAG, newBlockHandler NewBlockHandler) error { invsQueue := make([]*wire.MsgInvRelayBlock, 0) for { @@ -53,7 +54,7 @@ func HandleRelayInvs(incomingRoute *router.Route, outgoingRoute *router.Route, for requestQueue.len() > 0 { shouldStop, err := requestBlocks(netAdapter, outgoingRoute, peer, incomingRoute, dag, &invsQueue, - requestQueue) + requestQueue, newBlockHandler) if err != nil { return err } @@ -87,7 +88,8 @@ func readInv(incomingRoute *router.Route, invsQueue *[]*wire.MsgInvRelayBlock) ( func requestBlocks(netAdapater *netadapter.NetAdapter, outgoingRoute *router.Route, peer *peerpkg.Peer, incomingRoute *router.Route, dag *blockdag.BlockDAG, - invsQueue *[]*wire.MsgInvRelayBlock, requestQueue *hashesQueueSet) (shouldStop bool, err error) { + invsQueue *[]*wire.MsgInvRelayBlock, requestQueue *hashesQueueSet, + newBlockHandler NewBlockHandler) (shouldStop bool, err error) { numHashesToRequest := mathUtil.MinInt(wire.MsgGetRelayBlocksHashes, requestQueue.len()) hashesToRequest := requestQueue.dequeue(numHashesToRequest) @@ -131,7 +133,7 @@ func requestBlocks(netAdapater *netadapter.NetAdapter, outgoingRoute *router.Rou delete(pendingBlocks, *blockHash) requestedBlocks.remove(blockHash) - shouldStop, err = processAndRelayBlock(netAdapater, peer, dag, requestQueue, block) + shouldStop, err = processAndRelayBlock(netAdapater, peer, dag, requestQueue, block, newBlockHandler) if err != nil { return false, err } @@ -149,7 +151,7 @@ func readMsgBlock(incomingRoute *router.Route, invsQueue *[]*wire.MsgInvRelayBlo msgBlock *wire.MsgBlock, shouldStop bool, err error) { for { - message, isOpen, err := incomingRoute.DequeueWithTimeout(timeout) + message, isOpen, err := incomingRoute.DequeueWithTimeout(common.DefaultTimeout) if err != nil { return nil, false, err } @@ -169,7 +171,8 @@ func readMsgBlock(incomingRoute *router.Route, invsQueue *[]*wire.MsgInvRelayBlo } func processAndRelayBlock(netAdapter *netadapter.NetAdapter, peer *peerpkg.Peer, - dag *blockdag.BlockDAG, requestQueue *hashesQueueSet, block *util.Block) (shouldStop bool, err error) { + dag *blockdag.BlockDAG, requestQueue *hashesQueueSet, block *util.Block, + newBlockHandler NewBlockHandler) (shouldStop bool, err error) { blockHash := block.Hash() isOrphan, isDelayed, err := dag.ProcessBlock(block, blockdag.BFNone) @@ -224,12 +227,16 @@ func processAndRelayBlock(netAdapter *netadapter.NetAdapter, peer *peerpkg.Peer, // sm.restartSyncIfNeeded() //// Clear the rejected transactions. //sm.rejectedTxns = make(map[daghash.TxID]struct{}) - err = netAdapter.Broadcast(peerpkg.GetReadyPeerIDs(), block.MsgBlock()) + err = netAdapter.Broadcast(peerpkg.ReadyPeerIDs(), wire.NewMsgInvBlock(blockHash)) if err != nil { return false, err } ibd.StartIBDIfRequired(dag) + err = newBlockHandler(block) + if err != nil { + panic(err) + } return false, nil } diff --git a/protocol/flows/blockrelay/shared_requested_blocks.go b/protocol/flows/blockrelay/shared_requested_blocks.go index 18a8253a6..c32c0d953 100644 --- a/protocol/flows/blockrelay/shared_requested_blocks.go +++ b/protocol/flows/blockrelay/shared_requested_blocks.go @@ -18,6 +18,8 @@ func (s *sharedRequestedBlocks) remove(hash *daghash.Hash) { } func (s *sharedRequestedBlocks) removeSet(blockHashes map[daghash.Hash]struct{}) { + s.Lock() + defer s.Unlock() for hash := range blockHashes { delete(s.blocks, hash) } diff --git a/protocol/flows/handshake/handshake.go b/protocol/flows/handshake/handshake.go index e431a06df..0f99f8f28 100644 --- a/protocol/flows/handshake/handshake.go +++ b/protocol/flows/handshake/handshake.go @@ -1,13 +1,13 @@ package handshake import ( - "sync" - "sync/atomic" - "github.com/kaspanet/kaspad/addrmgr" "github.com/kaspanet/kaspad/blockdag" "github.com/kaspanet/kaspad/config" "github.com/kaspanet/kaspad/netadapter" + "sync" + "sync/atomic" + routerpkg "github.com/kaspanet/kaspad/netadapter/router" "github.com/kaspanet/kaspad/protocol/flows/ibd" peerpkg "github.com/kaspanet/kaspad/protocol/peer" diff --git a/protocol/flows/handshake/receiveversion.go b/protocol/flows/handshake/receiveversion.go index af1edf2f1..0994676d4 100644 --- a/protocol/flows/handshake/receiveversion.go +++ b/protocol/flows/handshake/receiveversion.go @@ -1,11 +1,10 @@ package handshake import ( - "time" - "github.com/kaspanet/kaspad/blockdag" "github.com/kaspanet/kaspad/netadapter" "github.com/kaspanet/kaspad/netadapter/router" + "github.com/kaspanet/kaspad/protocol/common" peerpkg "github.com/kaspanet/kaspad/protocol/peer" "github.com/kaspanet/kaspad/protocol/protocolerrors" "github.com/kaspanet/kaspad/wire" @@ -22,14 +21,12 @@ var ( minAcceptableProtocolVersion = wire.ProtocolVersion ) -const timeout = 30 * time.Second - // ReceiveVersion waits for the peer to send a version message, sends a // verack in response, and updates its info accordingly. func ReceiveVersion(incomingRoute *router.Route, outgoingRoute *router.Route, netAdapter *netadapter.NetAdapter, peer *peerpkg.Peer, dag *blockdag.BlockDAG) (address *wire.NetAddress, routeClosed bool, err error) { - message, isOpen, err := incomingRoute.DequeueWithTimeout(timeout) + message, isOpen, err := incomingRoute.DequeueWithTimeout(common.DefaultTimeout) if err != nil { return nil, false, err } diff --git a/protocol/flows/handshake/sendversion.go b/protocol/flows/handshake/sendversion.go index 240ab0b7f..1b4ba7eb3 100644 --- a/protocol/flows/handshake/sendversion.go +++ b/protocol/flows/handshake/sendversion.go @@ -5,6 +5,7 @@ import ( "github.com/kaspanet/kaspad/config" "github.com/kaspanet/kaspad/netadapter" "github.com/kaspanet/kaspad/netadapter/router" + "github.com/kaspanet/kaspad/protocol/common" "github.com/kaspanet/kaspad/version" "github.com/kaspanet/kaspad/wire" ) @@ -56,7 +57,7 @@ func SendVersion(cfg *config.Config, incomingRoute *router.Route, outgoingRoute return true, nil } - _, isOpen, err = incomingRoute.DequeueWithTimeout(timeout) + _, isOpen, err = incomingRoute.DequeueWithTimeout(common.DefaultTimeout) if err != nil { return false, err } diff --git a/protocol/flows/ibd/ibd.go b/protocol/flows/ibd/ibd.go index f859e2039..a912821a5 100644 --- a/protocol/flows/ibd/ibd.go +++ b/protocol/flows/ibd/ibd.go @@ -18,6 +18,10 @@ var ( startIBDMutex sync.Mutex ) +// NewBlockHandler is a function that is to be +// called when a new block is successfully processed. +type NewBlockHandler func(block *util.Block) error + // StartIBDIfRequired selects a peer and starts IBD against it // if required func StartIBDIfRequired(dag *blockdag.BlockDAG) { @@ -57,10 +61,10 @@ func selectPeerForIBD(dag *blockdag.BlockDAG) *peerpkg.Peer { // HandleIBD waits for IBD start and handles it when IBD is triggered for this peer func HandleIBD(incomingRoute *router.Route, outgoingRoute *router.Route, - peer *peerpkg.Peer, dag *blockdag.BlockDAG) error { + peer *peerpkg.Peer, dag *blockdag.BlockDAG, newBlockHandler NewBlockHandler) error { for { - shouldStop, err := runIBD(incomingRoute, outgoingRoute, peer, dag) + shouldStop, err := runIBD(incomingRoute, outgoingRoute, peer, dag, newBlockHandler) if err != nil { return err } @@ -71,7 +75,7 @@ func HandleIBD(incomingRoute *router.Route, outgoingRoute *router.Route, } func runIBD(incomingRoute *router.Route, outgoingRoute *router.Route, - peer *peerpkg.Peer, dag *blockdag.BlockDAG) (shouldStop bool, err error) { + peer *peerpkg.Peer, dag *blockdag.BlockDAG, newBlockHandler NewBlockHandler) (shouldStop bool, err error) { peer.WaitForIBDStart() defer finishIBD(dag) @@ -90,7 +94,8 @@ func runIBD(incomingRoute *router.Route, outgoingRoute *router.Route, "below the finality point", peer, highestSharedBlockHash) } - shouldStop, err = downloadBlocks(incomingRoute, outgoingRoute, dag, highestSharedBlockHash, peerSelectedTipHash) + shouldStop, err = downloadBlocks(incomingRoute, outgoingRoute, dag, highestSharedBlockHash, peerSelectedTipHash, + newBlockHandler) if err != nil { return false, err } @@ -157,7 +162,8 @@ func receiveBlockLocator(incomingRoute *router.Route) (blockLocatorHashes []*dag } func downloadBlocks(incomingRoute *router.Route, outgoingRoute *router.Route, - dag *blockdag.BlockDAG, highestSharedBlockHash *daghash.Hash, peerSelectedTipHash *daghash.Hash) (shouldStop bool, err error) { + dag *blockdag.BlockDAG, highestSharedBlockHash *daghash.Hash, + peerSelectedTipHash *daghash.Hash, newBlockHandler NewBlockHandler) (shouldStop bool, err error) { shouldStop = sendGetBlocks(outgoingRoute, highestSharedBlockHash, peerSelectedTipHash) if shouldStop { @@ -172,7 +178,7 @@ func downloadBlocks(incomingRoute *router.Route, outgoingRoute *router.Route, if shouldStop { return true, nil } - shouldStop, err = processIBDBlock(dag, msgIBDBlock) + shouldStop, err = processIBDBlock(dag, msgIBDBlock, newBlockHandler) if err != nil { return false, err } @@ -210,7 +216,9 @@ func receiveIBDBlock(incomingRoute *router.Route) (msgIBDBlock *wire.MsgIBDBlock return msgIBDBlock, false, nil } -func processIBDBlock(dag *blockdag.BlockDAG, msgIBDBlock *wire.MsgIBDBlock) (shouldStop bool, err error) { +func processIBDBlock(dag *blockdag.BlockDAG, msgIBDBlock *wire.MsgIBDBlock, + newBlockHandler NewBlockHandler) (shouldStop bool, err error) { + block := util.NewBlock(&msgIBDBlock.MsgBlock) if dag.IsInDAG(block.Hash()) { return false, nil @@ -227,6 +235,10 @@ func processIBDBlock(dag *blockdag.BlockDAG, msgIBDBlock *wire.MsgIBDBlock) (sho return false, protocolerrors.Errorf(false, "received delayed block %s "+ "during IBD", block.Hash()) } + err = newBlockHandler(block) + if err != nil { + panic(err) + } return false, nil } diff --git a/protocol/flows/ping/ping.go b/protocol/flows/ping/ping.go index 35824182c..82cd61cbe 100644 --- a/protocol/flows/ping/ping.go +++ b/protocol/flows/ping/ping.go @@ -1,6 +1,7 @@ package ping import ( + "github.com/kaspanet/kaspad/protocol/common" "time" "github.com/kaspanet/kaspad/netadapter/router" @@ -10,8 +11,6 @@ import ( "github.com/kaspanet/kaspad/wire" ) -const pingTimeout = 30 * time.Second - // ReceivePings handles all ping messages coming through incomingRoute. // This function assumes that incomingRoute will only return MsgPing. func ReceivePings(incomingRoute *router.Route, outgoingRoute *router.Route) error { @@ -51,7 +50,7 @@ func SendPings(incomingRoute *router.Route, outgoingRoute *router.Route, peer *p return nil } - message, isOpen, err := incomingRoute.DequeueWithTimeout(pingTimeout) + message, isOpen, err := incomingRoute.DequeueWithTimeout(common.DefaultTimeout) if err != nil { return err } diff --git a/protocol/flows/relaytransactions/relaytransactions.go b/protocol/flows/relaytransactions/relaytransactions.go new file mode 100644 index 000000000..20a391513 --- /dev/null +++ b/protocol/flows/relaytransactions/relaytransactions.go @@ -0,0 +1,222 @@ +package relaytransactions + +import ( + "github.com/kaspanet/kaspad/blockdag" + "github.com/kaspanet/kaspad/mempool" + "github.com/kaspanet/kaspad/netadapter" + "github.com/kaspanet/kaspad/netadapter/router" + "github.com/kaspanet/kaspad/protocol/common" + peerpkg "github.com/kaspanet/kaspad/protocol/peer" + "github.com/kaspanet/kaspad/protocol/protocolerrors" + "github.com/kaspanet/kaspad/util" + "github.com/kaspanet/kaspad/util/daghash" + "github.com/kaspanet/kaspad/wire" + "github.com/pkg/errors" +) + +// NewBlockHandler is a function that is to be +// called when a new block is successfully processed. +type NewBlockHandler func(block *util.Block) error + +// HandleRelayedTransactions listens to wire.MsgInvTransaction messages, requests their corresponding transactions if they +// are missing, adds them to the mempool and propagates them to the rest of the network. +func HandleRelayedTransactions(incomingRoute *router.Route, outgoingRoute *router.Route, + netAdapter *netadapter.NetAdapter, dag *blockdag.BlockDAG, txPool *mempool.TxPool, + sharedRequestedTransactions *SharedRequestedTransactions) error { + + invsQueue := make([]*wire.MsgInvTransaction, 0) + for { + inv, shouldStop, err := readInv(incomingRoute, &invsQueue) + if err != nil { + return err + } + if shouldStop { + return nil + } + + requestedIDs, shouldStop, err := requestInvTransactions(outgoingRoute, txPool, dag, sharedRequestedTransactions, inv) + if err != nil { + return err + } + if shouldStop { + return nil + } + + shouldStop, err = receiveTransactions(requestedIDs, incomingRoute, &invsQueue, txPool, netAdapter, + sharedRequestedTransactions) + if err != nil { + return err + } + if shouldStop { + return nil + } + } +} + +func requestInvTransactions(outgoingRoute *router.Route, txPool *mempool.TxPool, dag *blockdag.BlockDAG, + sharedRequestedTransactions *SharedRequestedTransactions, inv *wire.MsgInvTransaction) (requestedIDs []*daghash.TxID, + shouldStop bool, err error) { + + idsToRequest := make([]*daghash.TxID, 0, len(inv.TxIDS)) + for _, txID := range inv.TxIDS { + if isKnownTransaction(txPool, dag, txID) { + continue + } + exists := sharedRequestedTransactions.addIfNotExists(txID) + if exists { + continue + } + idsToRequest = append(idsToRequest, txID) + } + + if len(idsToRequest) == 0 { + return idsToRequest, false, nil + } + + msgGetTransactions := wire.NewMsgGetTransactions(idsToRequest) + isOpen := outgoingRoute.Enqueue(msgGetTransactions) + if !isOpen { + sharedRequestedTransactions.removeMany(idsToRequest) + return nil, true, nil + } + return idsToRequest, false, nil +} + +func isKnownTransaction(txPool *mempool.TxPool, dag *blockdag.BlockDAG, txID *daghash.TxID) bool { + // Ask the transaction memory pool if the transaction is known + // to it in any form (main pool or orphan). + if txPool.HaveTransaction(txID) { + return true + } + + // Check if the transaction exists from the point of view of the + // DAG's virtual block. Note that this is only a best effort + // since it is expensive to check existence of every output and + // the only purpose of this check is to avoid downloading + // already known transactions. Only the first two outputs are + // checked because the vast majority of transactions consist of + // two outputs where one is some form of "pay-to-somebody-else" + // and the other is a change output. + prevOut := wire.Outpoint{TxID: *txID} + for i := uint32(0); i < 2; i++ { + prevOut.Index = i + _, ok := dag.GetUTXOEntry(prevOut) + if ok { + return true + } + } + return false +} + +func readInv(incomingRoute *router.Route, invsQueue *[]*wire.MsgInvTransaction) ( + inv *wire.MsgInvTransaction, shouldStop bool, err error) { + + if len(*invsQueue) > 0 { + inv, *invsQueue = (*invsQueue)[0], (*invsQueue)[1:] + return inv, false, nil + } + + msg, isOpen := incomingRoute.Dequeue() + if !isOpen { + return nil, true, nil + } + + inv, ok := msg.(*wire.MsgInvTransaction) + if !ok { + return nil, false, protocolerrors.Errorf(true, "unexpected %s message in the block relay flow while "+ + "expecting an inv message", msg.Command()) + } + return inv, false, nil +} + +func broadcastAcceptedTransactions(netAdapter *netadapter.NetAdapter, acceptedTxs []*mempool.TxDesc) error { + // TODO(libp2p) Add mechanism to avoid sending to other peers invs that are known to them (e.g. mruinvmap) + // TODO(libp2p) Consider broadcasting in bulks + idsToBroadcast := make([]*daghash.TxID, len(acceptedTxs)) + for i, tx := range acceptedTxs { + idsToBroadcast[i] = tx.Tx.ID() + } + inv := wire.NewMsgTxInv(idsToBroadcast) + return netAdapter.Broadcast(peerpkg.ReadyPeerIDs(), inv) +} + +// readMsgTx returns the next msgTx in incomingRoute, and populates invsQueue with any inv messages that meanwhile arrive. +// +// Note: this function assumes msgChan can contain only wire.MsgInvTransaction and wire.MsgBlock messages. +func readMsgTx(incomingRoute *router.Route, invsQueue *[]*wire.MsgInvTransaction) ( + msgTx *wire.MsgTx, shouldStop bool, err error) { + + for { + message, isOpen, err := incomingRoute.DequeueWithTimeout(common.DefaultTimeout) + if err != nil { + return nil, false, err + } + if !isOpen { + return nil, true, nil + } + + switch message := message.(type) { + case *wire.MsgInvTransaction: + *invsQueue = append(*invsQueue, message) + case *wire.MsgTx: + return message, false, nil + default: + panic(errors.Errorf("unexpected message %s", message.Command())) + } + } +} + +func receiveTransactions(requestedTransactions []*daghash.TxID, incomingRoute *router.Route, + invsQueue *[]*wire.MsgInvTransaction, txPool *mempool.TxPool, netAdapter *netadapter.NetAdapter, + sharedRequestedTransactions *SharedRequestedTransactions) (shouldStop bool, err error) { + + // In case the function returns earlier than expected, we want to make sure sharedRequestedTransactions is + // clean from any pending transactions. + defer sharedRequestedTransactions.removeMany(requestedTransactions) + for _, expectedID := range requestedTransactions { + msgTx, shouldStop, err := readMsgTx(incomingRoute, invsQueue) + if err != nil { + return false, err + } + if shouldStop { + return true, nil + } + tx := util.NewTx(msgTx) + if !tx.ID().IsEqual(expectedID) { + return false, protocolerrors.Errorf(true, "expected transaction %s", expectedID) + } + + acceptedTxs, err := txPool.ProcessTransaction(tx, true, 0) // TODO(libp2p) Use the peer ID for the mempool tag + if err != nil { + // When the error is a rule error, it means the transaction was + // simply rejected as opposed to something actually going wrong, + // so log it as such. Otherwise, something really did go wrong, + // so panic. + ruleErr := &mempool.RuleError{} + if !errors.As(err, ruleErr) { + panic(errors.Wrapf(err, "failed to process transaction %s", tx.ID())) + } + + shouldBan := false + if txRuleErr := (&mempool.TxRuleError{}); errors.As(ruleErr.Err, txRuleErr) { + if txRuleErr.RejectCode == wire.RejectInvalid { + shouldBan = true + } + } else if dagRuleErr := (&blockdag.RuleError{}); errors.As(ruleErr.Err, dagRuleErr) { + shouldBan = true + } + + if !shouldBan { + continue + } + + return false, protocolerrors.Errorf(true, "rejected transaction %s", tx.ID()) + } + err = broadcastAcceptedTransactions(netAdapter, acceptedTxs) + if err != nil { + panic(err) + } + // TODO(libp2p) Notify transactionsAcceptedToMempool to RPC + } + return false, nil +} diff --git a/protocol/flows/relaytransactions/shared_requested_transactions.go b/protocol/flows/relaytransactions/shared_requested_transactions.go new file mode 100644 index 000000000..ea9531963 --- /dev/null +++ b/protocol/flows/relaytransactions/shared_requested_transactions.go @@ -0,0 +1,45 @@ +package relaytransactions + +import ( + "github.com/kaspanet/kaspad/util/daghash" + "sync" +) + +// SharedRequestedTransactions is a data structure that is shared between peers that +// holds the IDs of all the requested transactions to prevent redundant requests. +type SharedRequestedTransactions struct { + transactions map[daghash.TxID]struct{} + sync.Mutex +} + +func (s *SharedRequestedTransactions) remove(txID *daghash.TxID) { + s.Lock() + defer s.Unlock() + delete(s.transactions, *txID) +} + +func (s *SharedRequestedTransactions) removeMany(txIDs []*daghash.TxID) { + s.Lock() + defer s.Unlock() + for _, txID := range txIDs { + delete(s.transactions, *txID) + } +} + +func (s *SharedRequestedTransactions) addIfNotExists(txID *daghash.TxID) (exists bool) { + s.Lock() + defer s.Unlock() + _, ok := s.transactions[*txID] + if ok { + return true + } + s.transactions[*txID] = struct{}{} + return false +} + +// NewSharedRequestedTransactions returns a new instance of SharedRequestedTransactions. +func NewSharedRequestedTransactions() *SharedRequestedTransactions { + return &SharedRequestedTransactions{ + transactions: make(map[daghash.TxID]struct{}), + } +} diff --git a/protocol/manager.go b/protocol/manager.go new file mode 100644 index 000000000..aec18de74 --- /dev/null +++ b/protocol/manager.go @@ -0,0 +1,61 @@ +package protocol + +import ( + "github.com/kaspanet/kaspad/addrmgr" + "github.com/kaspanet/kaspad/blockdag" + "github.com/kaspanet/kaspad/config" + "github.com/kaspanet/kaspad/mempool" + "github.com/kaspanet/kaspad/netadapter" + "github.com/kaspanet/kaspad/protocol/flows/relaytransactions" + "github.com/kaspanet/kaspad/util" + "github.com/kaspanet/kaspad/util/daghash" + "sync" + "time" +) + +// Manager manages the p2p protocol +type Manager struct { + cfg *config.Config + netAdapter *netadapter.NetAdapter + txPool *mempool.TxPool + addedTransactions []*util.Tx + dag *blockdag.BlockDAG + addressManager *addrmgr.AddrManager + + transactionsToRebroadcastLock sync.Mutex + transactionsToRebroadcast map[daghash.TxID]*util.Tx + lastRebroadcastTime time.Time + sharedRequestedTransactions *relaytransactions.SharedRequestedTransactions + + isInIBD uint32 // TODO(libp2p) populate this var +} + +// NewManager creates a new instance of the p2p protocol manager +func NewManager(cfg *config.Config, dag *blockdag.BlockDAG, + addressManager *addrmgr.AddrManager, txPool *mempool.TxPool) (*Manager, error) { + + netAdapter, err := netadapter.NewNetAdapter(cfg) + if err != nil { + return nil, err + } + + manager := Manager{ + netAdapter: netAdapter, + dag: dag, + addressManager: addressManager, + txPool: txPool, + sharedRequestedTransactions: relaytransactions.NewSharedRequestedTransactions(), + } + netAdapter.SetRouterInitializer(manager.routerInitializer) + return &manager, nil +} + +// Start starts the p2p protocol +func (m *Manager) Start() error { + return m.netAdapter.Start() +} + +// Stop stops the p2p protocol +func (m *Manager) Stop() error { + return m.netAdapter.Stop() +} diff --git a/protocol/peer/peer.go b/protocol/peer/peer.go index 134cf7de3..ed85256a7 100644 --- a/protocol/peer/peer.go +++ b/protocol/peer/peer.go @@ -139,8 +139,8 @@ func AddToReadyPeers(peer *Peer) error { return nil } -// GetReadyPeerIDs returns the peer IDs of all the ready peers. -func GetReadyPeerIDs() []*id.ID { +// ReadyPeerIDs returns the peer IDs of all the ready peers. +func ReadyPeerIDs() []*id.ID { readyPeersMutex.RLock() defer readyPeersMutex.RUnlock() peerIDs := make([]*id.ID, len(readyPeers)) diff --git a/protocol/protocol.go b/protocol/protocol.go index 016f3ad4c..1527b8d3e 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -4,73 +4,57 @@ import ( "fmt" "sync/atomic" - "github.com/kaspanet/kaspad/config" - "github.com/kaspanet/kaspad/protocol/flows/handshake" "github.com/kaspanet/kaspad/protocol/flows/addressexchange" "github.com/kaspanet/kaspad/protocol/flows/blockrelay" - "github.com/kaspanet/kaspad/addrmgr" - "github.com/kaspanet/kaspad/blockdag" - "github.com/kaspanet/kaspad/netadapter" routerpkg "github.com/kaspanet/kaspad/netadapter/router" "github.com/kaspanet/kaspad/protocol/flows/ibd" "github.com/kaspanet/kaspad/protocol/flows/ping" + "github.com/kaspanet/kaspad/protocol/flows/relaytransactions" peerpkg "github.com/kaspanet/kaspad/protocol/peer" "github.com/kaspanet/kaspad/protocol/protocolerrors" "github.com/kaspanet/kaspad/wire" "github.com/pkg/errors" ) -// Init initializes the p2p protocol -func Init(cfg *config.Config, netAdapter *netadapter.NetAdapter, - addressManager *addrmgr.AddrManager, dag *blockdag.BlockDAG) { +func (m *Manager) routerInitializer() (*routerpkg.Router, error) { - routerInitializer := newRouterInitializer(cfg, netAdapter, addressManager, dag) - netAdapter.SetRouterInitializer(routerInitializer) -} - -func newRouterInitializer(cfg *config.Config, netAdapter *netadapter.NetAdapter, addressManager *addrmgr.AddrManager, - dag *blockdag.BlockDAG) netadapter.RouterInitializer { - - return func() (*routerpkg.Router, error) { - router := routerpkg.NewRouter() - spawn("newRouterInitializer-startFlows", func() { - err := startFlows(cfg, netAdapter, router, dag, addressManager) - if err != nil { - if protocolErr := &(protocolerrors.ProtocolError{}); errors.As(err, &protocolErr) { - if protocolErr.ShouldBan { - // TODO(libp2p) Ban peer - panic("unimplemented") - } - err = netAdapter.DisconnectAssociatedConnection(router) - if err != nil { - panic(err) - } - return + router := routerpkg.NewRouter() + spawn("newRouterInitializer-startFlows", func() { + err := m.startFlows(router) + if err != nil { + if protocolErr := &(protocolerrors.ProtocolError{}); errors.As(err, &protocolErr) { + if protocolErr.ShouldBan { + // TODO(libp2p) Ban peer + panic("unimplemented") } - if errors.Is(err, routerpkg.ErrTimeout) { - err = netAdapter.DisconnectAssociatedConnection(router) - if err != nil { - panic(err) - } - return + err = m.netAdapter.DisconnectAssociatedConnection(router) + if err != nil { + panic(err) } - panic(err) + return } - }) - return router, nil - } + if errors.Is(err, routerpkg.ErrTimeout) { + err = m.netAdapter.DisconnectAssociatedConnection(router) + if err != nil { + panic(err) + } + return + } + panic(err) + } + }) + return router, nil + } -func startFlows(cfg *config.Config, netAdapter *netadapter.NetAdapter, router *routerpkg.Router, dag *blockdag.BlockDAG, - addressManager *addrmgr.AddrManager) error { - +func (m *Manager) startFlows(router *routerpkg.Router) error { stop := make(chan error) stopped := uint32(0) - peer, closed, err := handshake.HandleHandshake(cfg, router, netAdapter, dag, addressManager) + peer, closed, err := handshake.HandleHandshake(m.cfg, router, m.netAdapter, m.dag, m.addressManager) if err != nil { return err } @@ -78,52 +62,53 @@ func startFlows(cfg *config.Config, netAdapter *netadapter.NetAdapter, router *r return nil } - addAddressFlows(cfg, router, &stopped, stop, peer, addressManager) - addBlockRelayFlows(netAdapter, router, &stopped, stop, peer, dag) - addPingFlows(router, &stopped, stop, peer) - addIBDFlows(router, &stopped, stop, peer, dag) + m.addAddressFlows(router, &stopped, stop, peer) + m.addBlockRelayFlows(router, &stopped, stop, peer) + m.addPingFlows(router, &stopped, stop, peer) + m.addIBDFlows(router, &stopped, stop, peer) + m.addTransactionRelayFlow(router, &stopped, stop) err = <-stop return err } -func addAddressFlows(cfg *config.Config, router *routerpkg.Router, stopped *uint32, stop chan error, - peer *peerpkg.Peer, addressManager *addrmgr.AddrManager) { +func (m *Manager) addAddressFlows(router *routerpkg.Router, stopped *uint32, stop chan error, + peer *peerpkg.Peer) { outgoingRoute := router.OutgoingRoute() addOneTimeFlow("SendAddresses", router, []wire.MessageCommand{wire.CmdGetAddresses}, stopped, stop, func(incomingRoute *routerpkg.Route) (routeClosed bool, err error) { - return addressexchange.SendAddresses(incomingRoute, outgoingRoute, addressManager) + return addressexchange.SendAddresses(incomingRoute, outgoingRoute, m.addressManager) }, ) addOneTimeFlow("ReceiveAddresses", router, []wire.MessageCommand{wire.CmdAddress}, stopped, stop, func(incomingRoute *routerpkg.Route) (routeClosed bool, err error) { - return addressexchange.ReceiveAddresses(incomingRoute, outgoingRoute, cfg, peer, addressManager) + return addressexchange.ReceiveAddresses(incomingRoute, outgoingRoute, m.cfg, peer, m.addressManager) }, ) } -func addBlockRelayFlows(netAdapter *netadapter.NetAdapter, router *routerpkg.Router, - stopped *uint32, stop chan error, peer *peerpkg.Peer, dag *blockdag.BlockDAG) { +func (m *Manager) addBlockRelayFlows(router *routerpkg.Router, stopped *uint32, stop chan error, peer *peerpkg.Peer) { outgoingRoute := router.OutgoingRoute() addFlow("HandleRelayInvs", router, []wire.MessageCommand{wire.CmdInvRelayBlock, wire.CmdBlock}, stopped, stop, func(incomingRoute *routerpkg.Route) error { - return blockrelay.HandleRelayInvs(incomingRoute, outgoingRoute, peer, netAdapter, dag) + return blockrelay.HandleRelayInvs(incomingRoute, + outgoingRoute, peer, m.netAdapter, m.dag, m.OnNewBlock) }, ) addFlow("HandleRelayBlockRequests", router, []wire.MessageCommand{wire.CmdGetRelayBlocks}, stopped, stop, func(incomingRoute *routerpkg.Route) error { - return blockrelay.HandleRelayBlockRequests(incomingRoute, outgoingRoute, peer, dag) + return blockrelay.HandleRelayBlockRequests(incomingRoute, outgoingRoute, peer, m.dag) }, ) } -func addPingFlows(router *routerpkg.Router, stopped *uint32, stop chan error, peer *peerpkg.Peer) { +func (m *Manager) addPingFlows(router *routerpkg.Router, stopped *uint32, stop chan error, peer *peerpkg.Peer) { outgoingRoute := router.OutgoingRoute() addFlow("ReceivePings", router, []wire.MessageCommand{wire.CmdPing}, stopped, stop, @@ -139,38 +124,50 @@ func addPingFlows(router *routerpkg.Router, stopped *uint32, stop chan error, pe ) } -func addIBDFlows(router *routerpkg.Router, stopped *uint32, stop chan error, - peer *peerpkg.Peer, dag *blockdag.BlockDAG) { +func (m *Manager) addIBDFlows(router *routerpkg.Router, stopped *uint32, stop chan error, + peer *peerpkg.Peer) { outgoingRoute := router.OutgoingRoute() addFlow("HandleIBD", router, []wire.MessageCommand{wire.CmdBlockLocator, wire.CmdIBDBlock}, stopped, stop, func(incomingRoute *routerpkg.Route) error { - return ibd.HandleIBD(incomingRoute, outgoingRoute, peer, dag) + return ibd.HandleIBD(incomingRoute, outgoingRoute, peer, m.dag, m.OnNewBlock) }, ) addFlow("RequestSelectedTip", router, []wire.MessageCommand{wire.CmdSelectedTip}, stopped, stop, func(incomingRoute *routerpkg.Route) error { - return ibd.RequestSelectedTip(incomingRoute, outgoingRoute, peer, dag) + return ibd.RequestSelectedTip(incomingRoute, outgoingRoute, peer, m.dag) }, ) addFlow("HandleGetSelectedTip", router, []wire.MessageCommand{wire.CmdGetSelectedTip}, stopped, stop, func(incomingRoute *routerpkg.Route) error { - return ibd.HandleGetSelectedTip(incomingRoute, outgoingRoute, dag) + return ibd.HandleGetSelectedTip(incomingRoute, outgoingRoute, m.dag) }, ) addFlow("HandleGetBlockLocator", router, []wire.MessageCommand{wire.CmdGetBlockLocator}, stopped, stop, func(incomingRoute *routerpkg.Route) error { - return ibd.HandleGetBlockLocator(incomingRoute, outgoingRoute, dag) + return ibd.HandleGetBlockLocator(incomingRoute, outgoingRoute, m.dag) }, ) addFlow("HandleGetBlocks", router, []wire.MessageCommand{wire.CmdGetBlocks}, stopped, stop, func(incomingRoute *routerpkg.Route) error { - return ibd.HandleGetBlocks(incomingRoute, outgoingRoute, dag) + return ibd.HandleGetBlocks(incomingRoute, outgoingRoute, m.dag) + }, + ) +} + +func (m *Manager) addTransactionRelayFlow(router *routerpkg.Router, stopped *uint32, stop chan error) { + + outgoingRoute := router.OutgoingRoute() + + addFlow("HandleRelayedTransactions", router, []wire.MessageCommand{wire.CmdInv, wire.CmdTx}, stopped, stop, + func(incomingRoute *routerpkg.Route) error { + return relaytransactions.HandleRelayedTransactions(incomingRoute, outgoingRoute, m.netAdapter, m.dag, + m.txPool, m.sharedRequestedTransactions) }, ) } diff --git a/protocol/transactions.go b/protocol/transactions.go new file mode 100644 index 000000000..99f69aaa2 --- /dev/null +++ b/protocol/transactions.go @@ -0,0 +1,58 @@ +package protocol + +import ( + "github.com/kaspanet/kaspad/protocol/peer" + "github.com/kaspanet/kaspad/util" + "github.com/kaspanet/kaspad/util/daghash" + "github.com/kaspanet/kaspad/wire" + "github.com/pkg/errors" + "time" +) + +// AddTransaction adds transaction to the mempool and propagates it. +func (m *Manager) AddTransaction(tx *util.Tx) error { + m.transactionsToRebroadcastLock.Lock() + defer m.transactionsToRebroadcastLock.Unlock() + + transactionsAcceptedToMempool, err := m.txPool.ProcessTransaction(tx, false, 0) + if err != nil { + return err + } + + if len(transactionsAcceptedToMempool) > 1 { + panic(errors.New("got more than one accepted transactions when no orphans were allowed")) + } + + m.transactionsToRebroadcast[*tx.ID()] = tx + inv := wire.NewMsgTxInv([]*daghash.TxID{tx.ID()}) + return m.netAdapter.Broadcast(peer.ReadyPeerIDs(), inv) +} + +func (m *Manager) updateTransactionsToRebroadcast(block *util.Block) { + m.transactionsToRebroadcastLock.Lock() + defer m.transactionsToRebroadcastLock.Unlock() + // Note: if the block is red, its transactions won't be rebroadcasted + // anymore, although they are not included in the UTXO set. + // This is probably ok, since red blocks are quite rare. + for _, tx := range block.Transactions() { + delete(m.transactionsToRebroadcast, *tx.ID()) + } +} + +func (m *Manager) shouldRebroadcastTransactions() bool { + const rebroadcastInterval = 30 * time.Second + return time.Since(m.lastRebroadcastTime) > rebroadcastInterval +} + +func (m *Manager) txIDsToRebroadcast() []*daghash.TxID { + m.transactionsToRebroadcastLock.Lock() + defer m.transactionsToRebroadcastLock.Unlock() + + txIDs := make([]*daghash.TxID, len(m.transactionsToRebroadcast)) + i := 0 + for _, tx := range m.transactionsToRebroadcast { + txIDs[i] = tx.ID() + i++ + } + return txIDs +} diff --git a/wire/message.go b/wire/message.go index 592eee18d..e5642448d 100644 --- a/wire/message.go +++ b/wire/message.go @@ -61,7 +61,9 @@ const ( CmdInvRelayBlock MessageCommand = 22 CmdGetRelayBlocks MessageCommand = 23 CmdRejectMalformed MessageCommand = 24 // Used only for reject message - CmdIBDBlock MessageCommand = 25 + CmdInvTransaction MessageCommand = 25 + CmdGetTransactions MessageCommand = 26 + CmdIBDBlock MessageCommand = 27 ) var messageCommandToString = map[MessageCommand]string{ @@ -90,6 +92,8 @@ var messageCommandToString = map[MessageCommand]string{ CmdInvRelayBlock: "InvRelayBlock", CmdGetRelayBlocks: "GetRelayBlocks", CmdRejectMalformed: "RejectMalformed", + CmdInvTransaction: "InvTransaction", + CmdGetTransactions: "GetTransactions", CmdIBDBlock: "IBDBlock", } diff --git a/wire/msggetrelayblocks.go b/wire/msggetrelayblocks.go index 4f07aab41..daa4c284a 100644 --- a/wire/msggetrelayblocks.go +++ b/wire/msggetrelayblocks.go @@ -37,7 +37,7 @@ func (msg *MsgGetRelayBlocks) Command() MessageCommand { // MaxPayloadLength returns the maximum length the payload can be for the // receiver. This is part of the Message interface implementation. func (msg *MsgGetRelayBlocks) MaxPayloadLength(pver uint32) uint32 { - return daghash.HashSize + return daghash.HashSize*MsgGetRelayBlocksHashes + uint32(VarIntSerializeSize(MsgGetRelayBlocksHashes)) } // NewMsgGetRelayBlocks returns a new kaspa getrelblks message that conforms to diff --git a/wire/msggettransactions.go b/wire/msggettransactions.go new file mode 100644 index 000000000..03e77d20a --- /dev/null +++ b/wire/msggettransactions.go @@ -0,0 +1,49 @@ +package wire + +import ( + "github.com/kaspanet/kaspad/util/daghash" + "io" +) + +// MaxInvPerGetTransactionsMsg is the maximum number of hashes that can +// be in a single CmdInvTransaction message. +const MaxInvPerGetTransactionsMsg = MaxInvPerMsg + +// MsgGetTransactions implements the Message interface and represents a kaspa +// GetTransactions message. It is used to request transactions as part of the +// transactions relay protocol. +type MsgGetTransactions struct { + IDs []*daghash.TxID +} + +// KaspaDecode decodes r using the kaspa protocol encoding into the receiver. +// This is part of the Message interface implementation. +func (msg *MsgGetTransactions) KaspaDecode(r io.Reader, pver uint32) error { + return ReadElement(r, &msg.IDs) +} + +// KaspaEncode encodes the receiver to w using the kaspa protocol encoding. +// This is part of the Message interface implementation. +func (msg *MsgGetTransactions) KaspaEncode(w io.Writer, pver uint32) error { + return WriteElement(w, msg.IDs) +} + +// Command returns the protocol command string for the message. This is part +// of the Message interface implementation. +func (msg *MsgGetTransactions) Command() MessageCommand { + return CmdGetTransactions +} + +// MaxPayloadLength returns the maximum length the payload can be for the +// receiver. This is part of the Message interface implementation. +func (msg *MsgGetTransactions) MaxPayloadLength(pver uint32) uint32 { + return daghash.TxIDSize*MaxInvPerGetTransactionsMsg + uint32(VarIntSerializeSize(MaxInvPerGetTransactionsMsg)) +} + +// NewMsgGetTransactions returns a new kaspa GetTransactions message that conforms to +// the Message interface. See MsgGetTransactions for details. +func NewMsgGetTransactions(ids []*daghash.TxID) *MsgGetTransactions { + return &MsgGetTransactions{ + IDs: ids, + } +} diff --git a/wire/msgibdblock_test.go b/wire/msgibdblock_test.go index b09153706..18bb89f0a 100644 --- a/wire/msgibdblock_test.go +++ b/wire/msgibdblock_test.go @@ -26,7 +26,7 @@ func TestIBDBlock(t *testing.T) { bh := NewBlockHeader(1, parentHashes, hashMerkleRoot, acceptedIDMerkleRoot, utxoCommitment, bits, nonce) // Ensure the command is expected value. - wantCmd := MessageCommand(25) + wantCmd := MessageCommand(27) msg := NewMsgIBDBlock(NewMsgBlock(bh)) if cmd := msg.Command(); cmd != wantCmd { t.Errorf("NewMsgIBDBlock: wrong command - got %v want %v", diff --git a/wire/msginvtransaction.go b/wire/msginvtransaction.go new file mode 100644 index 000000000..f23d1df02 --- /dev/null +++ b/wire/msginvtransaction.go @@ -0,0 +1,49 @@ +package wire + +import ( + "github.com/kaspanet/kaspad/util/daghash" + "io" +) + +// MaxInvPerTxInvMsg is the maximum number of hashes that can +// be in a single CmdInvTransaction message. +const MaxInvPerTxInvMsg = MaxInvPerMsg + +// MsgInvTransaction implements the Message interface and represents a kaspa +// TxInv message. It is used to notify the network about new transactions +// by sending their ID, and let the receiving node decide if it needs it. +type MsgInvTransaction struct { + TxIDS []*daghash.TxID +} + +// KaspaDecode decodes r using the kaspa protocol encoding into the receiver. +// This is part of the Message interface implementation. +func (msg *MsgInvTransaction) KaspaDecode(r io.Reader, pver uint32) error { + return ReadElement(r, &msg.TxIDS) +} + +// KaspaEncode encodes the receiver to w using the kaspa protocol encoding. +// This is part of the Message interface implementation. +func (msg *MsgInvTransaction) KaspaEncode(w io.Writer, pver uint32) error { + return WriteElement(w, msg.TxIDS) +} + +// Command returns the protocol command string for the message. This is part +// of the Message interface implementation. +func (msg *MsgInvTransaction) Command() MessageCommand { + return CmdInvTransaction +} + +// MaxPayloadLength returns the maximum length the payload can be for the +// receiver. This is part of the Message interface implementation. +func (msg *MsgInvTransaction) MaxPayloadLength(pver uint32) uint32 { + return daghash.TxIDSize*MaxInvPerTxInvMsg + uint32(VarIntSerializeSize(MaxInvPerTxInvMsg)) +} + +// NewMsgTxInv returns a new kaspa TxInv message that conforms to +// the Message interface. See MsgInvTransaction for details. +func NewMsgTxInv(ids []*daghash.TxID) *MsgInvTransaction { + return &MsgInvTransaction{ + TxIDS: ids, + } +}