From 3e6c1792ef9a581cd82338f227b6d848707cedc9 Mon Sep 17 00:00:00 2001 From: Ori Newman Date: Tue, 21 Jul 2020 12:06:11 +0300 Subject: [PATCH] [NOD-1170] Return a custom error when a route is closed (#805) * [NOD-1170] Return a custom error when a route is closed * [NOD-1170] Return ErrRouteClosed directly from route methods * [NOD-1170] Fix comment location --- netadapter/netadapter.go | 7 +- netadapter/router/route.go | 35 ++++-- netadapter/router/router.go | 6 +- .../server/grpcserver/connection_loops.go | 17 +-- protocol/common/common.go | 4 +- .../flows/addressexchange/receiveaddresses.go | 25 ++-- .../flows/addressexchange/sendaddresses.go | 14 +-- .../blockrelay/handle_relay_block_requests.go | 12 +- .../flows/blockrelay/handle_relay_invs.go | 81 +++++-------- protocol/flows/handshake/handshake.go | 8 +- protocol/flows/handshake/receiveversion.go | 25 ++-- protocol/flows/handshake/sendversion.go | 18 ++- .../flows/ibd/handle_get_block_locator.go | 30 ++--- protocol/flows/ibd/handle_get_blocks.go | 29 ++--- protocol/flows/ibd/ibd.go | 113 +++++++----------- protocol/flows/ibd/selected_tip.go | 64 ++++------ protocol/flows/ping/ping.go | 25 ++-- .../relaytransactions/relaytransactions.go | 69 +++++------ protocol/protocol.go | 13 +- 19 files changed, 264 insertions(+), 331 deletions(-) diff --git a/netadapter/netadapter.go b/netadapter/netadapter.go index b4fe25163..d56071b73 100644 --- a/netadapter/netadapter.go +++ b/netadapter/netadapter.go @@ -176,8 +176,13 @@ func (na *NetAdapter) Broadcast(connectionIDs []*id.ID, message wire.Message) er log.Warnf("connectionID %s is not registered", connectionID) continue } - _, err := router.EnqueueIncomingMessage(message) + err := router.EnqueueIncomingMessage(message) if err != nil { + if errors.Is(err, routerpkg.ErrRouteClosed) { + connection := na.routersToConnections[router] + log.Debugf("Cannot enqueue message to %s: router is closed", connection) + continue + } return err } } diff --git a/netadapter/router/route.go b/netadapter/router/route.go index 5692ca3b2..af2788e85 100644 --- a/netadapter/router/route.go +++ b/netadapter/router/route.go @@ -12,8 +12,13 @@ const ( defaultMaxMessages = 100 ) -// ErrTimeout signifies that one of the router functions had a timeout. -var ErrTimeout = errors.New("timeout expired") +var ( + // ErrTimeout signifies that one of the router functions had a timeout. + ErrTimeout = errors.New("timeout expired") + + // ErrRouteClosed indicates that a route was closed while reading/writing. + ErrRouteClosed = errors.New("route is closed") +) // onCapacityReachedHandler is a function that is to be // called when a route reaches capacity. @@ -43,34 +48,40 @@ func newRouteWithCapacity(capacity int) *Route { } // Enqueue enqueues a message to the Route -func (r *Route) Enqueue(message wire.Message) (isOpen bool) { +func (r *Route) Enqueue(message wire.Message) error { r.closeLock.Lock() defer r.closeLock.Unlock() if r.closed { - return false + return errors.WithStack(ErrRouteClosed) } if len(r.channel) == defaultMaxMessages { r.onCapacityReachedHandler() } r.channel <- message - return true + return nil } // Dequeue dequeues a message from the Route -func (r *Route) Dequeue() (message wire.Message, isOpen bool) { - message, isOpen = <-r.channel - return message, isOpen +func (r *Route) Dequeue() (wire.Message, error) { + message, isOpen := <-r.channel + if !isOpen { + return nil, errors.WithStack(ErrRouteClosed) + } + return message, nil } // DequeueWithTimeout attempts to dequeue a message from the Route // and returns an error if the given timeout expires first. -func (r *Route) DequeueWithTimeout(timeout time.Duration) (message wire.Message, isOpen bool, err error) { +func (r *Route) DequeueWithTimeout(timeout time.Duration) (wire.Message, error) { select { case <-time.After(timeout): - return nil, false, errors.Wrapf(ErrTimeout, "got timeout after %s", timeout) - case message, isOpen = <-r.channel: - return message, isOpen, nil + return nil, errors.Wrapf(ErrTimeout, "got timeout after %s", timeout) + case message, isOpen := <-r.channel: + if !isOpen { + return nil, errors.WithStack(ErrRouteClosed) + } + return message, nil } } diff --git a/netadapter/router/router.go b/netadapter/router/router.go index 9a6a6063d..044796a8b 100644 --- a/netadapter/router/router.go +++ b/netadapter/router/router.go @@ -68,12 +68,12 @@ func (r *Router) RemoveRoute(messageTypes []wire.MessageCommand) error { // EnqueueIncomingMessage enqueues the given message to the // appropriate route -func (r *Router) EnqueueIncomingMessage(message wire.Message) (isOpen bool, err error) { +func (r *Router) EnqueueIncomingMessage(message wire.Message) error { route, ok := r.incomingRoutes[message.Command()] if !ok { - return false, errors.Errorf("a route for '%s' does not exist", message.Command()) + return errors.Errorf("a route for '%s' does not exist", message.Command()) } - return route.Enqueue(message), nil + return route.Enqueue(message) } // OutgoingRoute returns the outgoing route diff --git a/netadapter/server/grpcserver/connection_loops.go b/netadapter/server/grpcserver/connection_loops.go index 84353d421..ee8ddceef 100644 --- a/netadapter/server/grpcserver/connection_loops.go +++ b/netadapter/server/grpcserver/connection_loops.go @@ -1,6 +1,8 @@ package grpcserver import ( + routerpkg "github.com/kaspanet/kaspad/netadapter/router" + "github.com/pkg/errors" "io" "github.com/davecgh/go-spew/spew" @@ -32,9 +34,9 @@ func (c *gRPCConnection) connectionLoops() error { func (c *gRPCConnection) sendLoop() error { outgoingRoute := c.router.OutgoingRoute() for c.IsConnected() { - message, isOpen := outgoingRoute.Dequeue() - if !isOpen { - return nil + message, err := outgoingRoute.Dequeue() + if err != nil { + return err } log.Tracef("outgoing '%s' message to %s: %s", message.Command(), c, logger.NewLogClosure(func() string { @@ -73,13 +75,14 @@ func (c *gRPCConnection) receiveLoop() error { return spew.Sdump(message) })) - isOpen, err := c.router.EnqueueIncomingMessage(message) + err = c.router.EnqueueIncomingMessage(message) if err != nil { + if errors.Is(err, routerpkg.ErrRouteClosed) { + log.Debugf("Router for %s is closed. Exiting the receive loop", c) + return nil + } return err } - if !isOpen { - return nil - } } return nil } diff --git a/protocol/common/common.go b/protocol/common/common.go index 8a766b6b2..e5ed3812a 100644 --- a/protocol/common/common.go +++ b/protocol/common/common.go @@ -1,6 +1,8 @@ package common -import "time" +import ( + "time" +) // DefaultTimeout is the default duration to wait for enqueuing/dequeuing // to/from routes. diff --git a/protocol/flows/addressexchange/receiveaddresses.go b/protocol/flows/addressexchange/receiveaddresses.go index 55b61b39e..7e4a42d99 100644 --- a/protocol/flows/addressexchange/receiveaddresses.go +++ b/protocol/flows/addressexchange/receiveaddresses.go @@ -12,38 +12,35 @@ import ( // ReceiveAddresses asks a peer for more addresses if needed. func ReceiveAddresses(incomingRoute *router.Route, outgoingRoute *router.Route, cfg *config.Config, peer *peerpkg.Peer, - addressManager *addrmgr.AddrManager) (routeClosed bool, err error) { + addressManager *addrmgr.AddrManager) error { if !addressManager.NeedMoreAddresses() { - return false, nil + return nil } subnetworkID := peer.SubnetworkID() msgGetAddresses := wire.NewMsgGetAddresses(false, subnetworkID) - isOpen := outgoingRoute.Enqueue(msgGetAddresses) - if !isOpen { - return true, nil + err := outgoingRoute.Enqueue(msgGetAddresses) + if err != nil { + return err } - message, isOpen, err := incomingRoute.DequeueWithTimeout(common.DefaultTimeout) + message, err := incomingRoute.DequeueWithTimeout(common.DefaultTimeout) if err != nil { - return false, err - } - if !isOpen { - return true, nil + return err } msgAddresses := message.(*wire.MsgAddresses) if len(msgAddresses.AddrList) > addrmgr.GetAddressesMax { - return false, protocolerrors.Errorf(true, "address count excceeded %d", addrmgr.GetAddressesMax) + return protocolerrors.Errorf(true, "address count excceeded %d", addrmgr.GetAddressesMax) } if msgAddresses.IncludeAllSubnetworks { - return false, protocolerrors.Errorf(true, "got unexpected "+ + return protocolerrors.Errorf(true, "got unexpected "+ "IncludeAllSubnetworks=true in [%s] command", msgAddresses.Command()) } if !msgAddresses.SubnetworkID.IsEqual(cfg.SubnetworkID) && msgAddresses.SubnetworkID != nil { - return false, protocolerrors.Errorf(false, "only full nodes and %s subnetwork IDs "+ + return protocolerrors.Errorf(false, "only full nodes and %s subnetwork IDs "+ "are allowed in [%s] command, but got subnetwork ID %s", cfg.SubnetworkID, msgAddresses.Command(), msgAddresses.SubnetworkID) } @@ -52,5 +49,5 @@ func ReceiveAddresses(incomingRoute *router.Route, outgoingRoute *router.Route, // TODO(libp2p) Replace with real peer IP fakeSourceAddress := new(wire.NetAddress) addressManager.AddAddresses(msgAddresses.AddrList, fakeSourceAddress, msgAddresses.SubnetworkID) - return false, nil + return nil } diff --git a/protocol/flows/addressexchange/sendaddresses.go b/protocol/flows/addressexchange/sendaddresses.go index fd144b0e0..b2575b0d6 100644 --- a/protocol/flows/addressexchange/sendaddresses.go +++ b/protocol/flows/addressexchange/sendaddresses.go @@ -9,11 +9,11 @@ import ( // SendAddresses sends addresses to a peer that requests it. func SendAddresses(incomingRoute *router.Route, outgoingRoute *router.Route, - addressManager *addrmgr.AddrManager) (routeClosed bool, err error) { + addressManager *addrmgr.AddrManager) error { - message, isOpen := incomingRoute.Dequeue() - if !isOpen { - return true, nil + message, err := incomingRoute.Dequeue() + if err != nil { + return err } msgGetAddresses := message.(*wire.MsgGetAddresses) @@ -24,11 +24,7 @@ func SendAddresses(incomingRoute *router.Route, outgoingRoute *router.Route, panic(err) } - isOpen = outgoingRoute.Enqueue(msgAddresses) - if !isOpen { - return true, nil - } - return false, nil + return outgoingRoute.Enqueue(msgAddresses) } // shuffleAddresses randomizes the given addresses sent if there are more than the maximum allowed in one message. diff --git a/protocol/flows/blockrelay/handle_relay_block_requests.go b/protocol/flows/blockrelay/handle_relay_block_requests.go index 0cb254e0d..a7b160641 100644 --- a/protocol/flows/blockrelay/handle_relay_block_requests.go +++ b/protocol/flows/blockrelay/handle_relay_block_requests.go @@ -15,9 +15,9 @@ func HandleRelayBlockRequests(incomingRoute *router.Route, outgoingRoute *router peer *peerpkg.Peer, dag *blockdag.BlockDAG) error { for { - message, isOpen := incomingRoute.Dequeue() - if !isOpen { - return nil + message, err := incomingRoute.Dequeue() + if err != nil { + return err } getRelayBlocksMessage := message.(*wire.MsgGetRelayBlocks) for _, hash := range getRelayBlocksMessage.Hashes { @@ -41,9 +41,9 @@ func HandleRelayBlockRequests(incomingRoute *router.Route, outgoingRoute *router msgBlock.ConvertToPartial(peerSubnetworkID) } - isOpen = outgoingRoute.Enqueue(msgBlock) - if !isOpen { - return nil + err = outgoingRoute.Enqueue(msgBlock) + if err != nil { + return err } } } diff --git a/protocol/flows/blockrelay/handle_relay_invs.go b/protocol/flows/blockrelay/handle_relay_invs.go index dac32e09a..7435ace42 100644 --- a/protocol/flows/blockrelay/handle_relay_invs.go +++ b/protocol/flows/blockrelay/handle_relay_invs.go @@ -27,13 +27,10 @@ func HandleRelayInvs(incomingRoute *router.Route, outgoingRoute *router.Route, invsQueue := make([]*wire.MsgInvRelayBlock, 0) for { - inv, shouldStop, err := readInv(incomingRoute, &invsQueue) + inv, err := readInv(incomingRoute, &invsQueue) if err != nil { return err } - if shouldStop { - return nil - } if dag.IsKnownBlock(inv.Hash) { if dag.IsKnownInvalid(inv.Hash) { @@ -53,43 +50,40 @@ func HandleRelayInvs(incomingRoute *router.Route, outgoingRoute *router.Route, requestQueue.enqueueIfNotExists(inv.Hash) for requestQueue.len() > 0 { - shouldStop, err := requestBlocks(netAdapter, outgoingRoute, peer, incomingRoute, dag, &invsQueue, + err := requestBlocks(netAdapter, outgoingRoute, peer, incomingRoute, dag, &invsQueue, requestQueue, newBlockHandler) if err != nil { return err } - if shouldStop { - return nil - } } } } -func readInv(incomingRoute *router.Route, invsQueue *[]*wire.MsgInvRelayBlock) ( - inv *wire.MsgInvRelayBlock, shouldStop bool, err error) { +func readInv(incomingRoute *router.Route, invsQueue *[]*wire.MsgInvRelayBlock) (*wire.MsgInvRelayBlock, error) { if len(*invsQueue) > 0 { + var inv *wire.MsgInvRelayBlock inv, *invsQueue = (*invsQueue)[0], (*invsQueue)[1:] - return inv, false, nil + return inv, nil } - msg, isOpen := incomingRoute.Dequeue() - if !isOpen { - return nil, true, nil + msg, err := incomingRoute.Dequeue() + if err != nil { + return nil, err } inv, ok := msg.(*wire.MsgInvRelayBlock) if !ok { - return nil, false, protocolerrors.Errorf(true, "unexpected %s message in the block relay flow while "+ + return nil, protocolerrors.Errorf(true, "unexpected %s message in the block relay flow while "+ "expecting an inv message", msg.Command()) } - return inv, false, nil + return inv, nil } func requestBlocks(netAdapater *netadapter.NetAdapter, outgoingRoute *router.Route, peer *peerpkg.Peer, incomingRoute *router.Route, dag *blockdag.BlockDAG, invsQueue *[]*wire.MsgInvRelayBlock, requestQueue *hashesQueueSet, - newBlockHandler NewBlockHandler) (shouldStop bool, err error) { + newBlockHandler NewBlockHandler) error { numHashesToRequest := mathUtil.MinInt(wire.MsgGetRelayBlocksHashes, requestQueue.len()) hashesToRequest := requestQueue.dequeue(numHashesToRequest) @@ -111,59 +105,50 @@ func requestBlocks(netAdapater *netadapter.NetAdapter, outgoingRoute *router.Rou defer requestedBlocks.removeSet(pendingBlocks) getRelayBlocksMsg := wire.NewMsgGetRelayBlocks(filteredHashesToRequest) - isOpen := outgoingRoute.Enqueue(getRelayBlocksMsg) - if !isOpen { - return true, nil + err := outgoingRoute.Enqueue(getRelayBlocksMsg) + if err != nil { + return err } for len(pendingBlocks) > 0 { - msgBlock, shouldStop, err := readMsgBlock(incomingRoute, invsQueue) + msgBlock, err := readMsgBlock(incomingRoute, invsQueue) if err != nil { - return false, err - } - if shouldStop { - return true, nil + return err } block := util.NewBlock(msgBlock) blockHash := block.Hash() if _, ok := pendingBlocks[*blockHash]; !ok { - return false, protocolerrors.Errorf(true, "got unrequested block %s", block.Hash()) + return protocolerrors.Errorf(true, "got unrequested block %s", block.Hash()) } delete(pendingBlocks, *blockHash) requestedBlocks.remove(blockHash) - shouldStop, err = processAndRelayBlock(netAdapater, peer, dag, requestQueue, block, newBlockHandler) + err = processAndRelayBlock(netAdapater, peer, dag, requestQueue, block, newBlockHandler) if err != nil { - return false, err - } - if shouldStop { - return true, nil + return err } } - return false, nil + return nil } // readMsgBlock returns the next msgBlock in msgChan, and populates invsQueue with any inv messages that meanwhile arrive. // // Note: this function assumes msgChan can contain only wire.MsgInvRelayBlock and wire.MsgBlock messages. func readMsgBlock(incomingRoute *router.Route, invsQueue *[]*wire.MsgInvRelayBlock) ( - msgBlock *wire.MsgBlock, shouldStop bool, err error) { + msgBlock *wire.MsgBlock, err error) { for { - message, isOpen, err := incomingRoute.DequeueWithTimeout(common.DefaultTimeout) + message, err := incomingRoute.DequeueWithTimeout(common.DefaultTimeout) if err != nil { - return nil, false, err - } - if !isOpen { - return nil, true, nil + return nil, err } switch message := message.(type) { case *wire.MsgInvRelayBlock: *invsQueue = append(*invsQueue, message) case *wire.MsgBlock: - return message, false, nil + return message, nil default: panic(errors.Errorf("unexpected message %s", message.Command())) } @@ -172,7 +157,7 @@ func readMsgBlock(incomingRoute *router.Route, invsQueue *[]*wire.MsgInvRelayBlo func processAndRelayBlock(netAdapter *netadapter.NetAdapter, peer *peerpkg.Peer, dag *blockdag.BlockDAG, requestQueue *hashesQueueSet, block *util.Block, - newBlockHandler NewBlockHandler) (shouldStop bool, err error) { + newBlockHandler NewBlockHandler) error { blockHash := block.Hash() isOrphan, isDelayed, err := dag.ProcessBlock(block, blockdag.BFNone) @@ -187,17 +172,17 @@ func processAndRelayBlock(netAdapter *netadapter.NetAdapter, peer *peerpkg.Peer, log.Infof("Rejected block %s from %s: %s", blockHash, peer, err) - return false, protocolerrors.Wrap(true, err, "got invalid block") + return protocolerrors.Wrap(true, err, "got invalid block") } if isDelayed { - return false, nil + return nil } if isOrphan { blueScore, err := block.BlueScore() if err != nil { - return false, protocolerrors.Errorf(true, "received an orphan "+ + return protocolerrors.Errorf(true, "received an orphan "+ "block %s with malformed blue score", blockHash) } @@ -207,7 +192,7 @@ func processAndRelayBlock(netAdapter *netadapter.NetAdapter, peer *peerpkg.Peer, log.Infof("Orphan block %s has blue score %d and the selected tip blue score is "+ "%d. Ignoring orphans with a blue score difference from the selected tip greater than %d", blockHash, blueScore, selectedTipBlueScore, maxOrphanBlueScoreDiff) - return false, nil + return nil } // Request the parents for the orphan block from the peer that sent it. @@ -215,11 +200,11 @@ func processAndRelayBlock(netAdapter *netadapter.NetAdapter, peer *peerpkg.Peer, for _, missingAncestor := range missingAncestors { requestQueue.enqueueIfNotExists(missingAncestor) } - return false, nil + return nil } err = blocklogger.LogBlock(block) if err != nil { - return false, err + return err } //TODO(libp2p) //// When the block is not an orphan, log information about it and @@ -229,7 +214,7 @@ func processAndRelayBlock(netAdapter *netadapter.NetAdapter, peer *peerpkg.Peer, //sm.rejectedTxns = make(map[daghash.TxID]struct{}) err = netAdapter.Broadcast(peerpkg.ReadyPeerIDs(), wire.NewMsgInvBlock(blockHash)) if err != nil { - return false, err + return err } ibd.StartIBDIfRequired(dag) @@ -238,5 +223,5 @@ func processAndRelayBlock(netAdapter *netadapter.NetAdapter, peer *peerpkg.Peer, panic(err) } - return false, nil + return nil } diff --git a/protocol/flows/handshake/handshake.go b/protocol/flows/handshake/handshake.go index 0f99f8f28..814bc7522 100644 --- a/protocol/flows/handshake/handshake.go +++ b/protocol/flows/handshake/handshake.go @@ -45,11 +45,11 @@ func HandleHandshake(cfg *config.Config, router *routerpkg.Router, netAdapter *n var peerAddress *wire.NetAddress spawn("HandleHandshake-ReceiveVersion", func() { defer wg.Done() - address, closed, err := ReceiveVersion(receiveVersionRoute, router.OutgoingRoute(), netAdapter, peer, dag) + address, err := ReceiveVersion(receiveVersionRoute, router.OutgoingRoute(), netAdapter, peer, dag) if err != nil { log.Errorf("error from ReceiveVersion: %s", err) } - if err != nil || closed { + if err != nil { if atomic.AddUint32(&errChanUsed, 1) != 1 { errChan <- err } @@ -60,11 +60,11 @@ func HandleHandshake(cfg *config.Config, router *routerpkg.Router, netAdapter *n spawn("HandleHandshake-SendVersion", func() { defer wg.Done() - closed, err := SendVersion(cfg, sendVersionRoute, router.OutgoingRoute(), netAdapter, dag) + err := SendVersion(cfg, sendVersionRoute, router.OutgoingRoute(), netAdapter, dag) if err != nil { log.Errorf("error from SendVersion: %s", err) } - if err != nil || closed { + if err != nil { if atomic.AddUint32(&errChanUsed, 1) != 1 { errChan <- err } diff --git a/protocol/flows/handshake/receiveversion.go b/protocol/flows/handshake/receiveversion.go index 0994676d4..cc94783b0 100644 --- a/protocol/flows/handshake/receiveversion.go +++ b/protocol/flows/handshake/receiveversion.go @@ -24,23 +24,20 @@ var ( // ReceiveVersion waits for the peer to send a version message, sends a // verack in response, and updates its info accordingly. func ReceiveVersion(incomingRoute *router.Route, outgoingRoute *router.Route, netAdapter *netadapter.NetAdapter, - peer *peerpkg.Peer, dag *blockdag.BlockDAG) (address *wire.NetAddress, routeClosed bool, err error) { + peer *peerpkg.Peer, dag *blockdag.BlockDAG) (*wire.NetAddress, error) { - message, isOpen, err := incomingRoute.DequeueWithTimeout(common.DefaultTimeout) + message, err := incomingRoute.DequeueWithTimeout(common.DefaultTimeout) if err != nil { - return nil, false, err - } - if !isOpen { - return nil, true, nil + return nil, err } msgVersion, ok := message.(*wire.MsgVersion) if !ok { - return nil, false, protocolerrors.New(true, "a version message must precede all others") + return nil, protocolerrors.New(true, "a version message must precede all others") } if !allowSelfConnections && netAdapter.ID().IsEqual(msgVersion.ID) { - return nil, false, protocolerrors.New(true, "connected to self") + return nil, protocolerrors.New(true, "connected to self") } // Notify and disconnect clients that have a protocol version that is @@ -51,13 +48,13 @@ func ReceiveVersion(incomingRoute *router.Route, outgoingRoute *router.Route, ne // disconnecting. if msgVersion.ProtocolVersion < minAcceptableProtocolVersion { //TODO(libp2p) create error type for disconnect but don't ban - return nil, false, protocolerrors.Errorf(false, "protocol version must be %d or greater", + return nil, protocolerrors.Errorf(false, "protocol version must be %d or greater", minAcceptableProtocolVersion) } // Disconnect from partial nodes in networks that don't allow them if !dag.Params.EnableNonNativeSubnetworks && msgVersion.SubnetworkID != nil { - return nil, false, protocolerrors.New(true, "partial nodes are not allowed") + return nil, protocolerrors.New(true, "partial nodes are not allowed") } // TODO(libp2p) @@ -74,10 +71,10 @@ func ReceiveVersion(incomingRoute *router.Route, outgoingRoute *router.Route, ne //} peer.UpdateFieldsFromMsgVersion(msgVersion) - isOpen = outgoingRoute.Enqueue(wire.NewMsgVerAck()) - if !isOpen { - return nil, true, nil + err = outgoingRoute.Enqueue(wire.NewMsgVerAck()) + if err != nil { + return nil, err } // TODO(libp2p) Register peer ID - return msgVersion.Address, false, nil + return msgVersion.Address, nil } diff --git a/protocol/flows/handshake/sendversion.go b/protocol/flows/handshake/sendversion.go index 1b4ba7eb3..b35c4ef22 100644 --- a/protocol/flows/handshake/sendversion.go +++ b/protocol/flows/handshake/sendversion.go @@ -30,7 +30,7 @@ var ( // SendVersion sends a version to a peer and waits for verack. func SendVersion(cfg *config.Config, incomingRoute *router.Route, outgoingRoute *router.Route, - netAdapter *netadapter.NetAdapter, dag *blockdag.BlockDAG) (routeClosed bool, err error) { + netAdapter *netadapter.NetAdapter, dag *blockdag.BlockDAG) error { selectedTipHash := dag.SelectedTipHash() subnetworkID := cfg.SubnetworkID @@ -52,17 +52,15 @@ func SendVersion(cfg *config.Config, incomingRoute *router.Route, outgoingRoute // Advertise if inv messages for transactions are desired. msg.DisableRelayTx = cfg.BlocksOnly - isOpen := outgoingRoute.Enqueue(msg) - if !isOpen { - return true, nil + err = outgoingRoute.Enqueue(msg) + if err != nil { + return err } - _, isOpen, err = incomingRoute.DequeueWithTimeout(common.DefaultTimeout) + // Wait for verack + _, err = incomingRoute.DequeueWithTimeout(common.DefaultTimeout) if err != nil { - return false, err + return err } - if !isOpen { - return true, nil - } - return false, nil + return nil } diff --git a/protocol/flows/ibd/handle_get_block_locator.go b/protocol/flows/ibd/handle_get_block_locator.go index 6156b6bea..d1dcfd0a6 100644 --- a/protocol/flows/ibd/handle_get_block_locator.go +++ b/protocol/flows/ibd/handle_get_block_locator.go @@ -11,13 +11,10 @@ import ( // HandleGetBlockLocator handles getBlockLocator messages func HandleGetBlockLocator(incomingRoute *router.Route, outgoingRoute *router.Route, dag *blockdag.BlockDAG) error { for { - lowHash, highHash, shouldStop, err := receiveGetBlockLocator(incomingRoute) + lowHash, highHash, err := receiveGetBlockLocator(incomingRoute) if err != nil { return err } - if shouldStop { - return nil - } locator, err := dag.BlockLocatorFromHashes(highHash, lowHash) if err != nil || len(locator) == 0 { @@ -25,27 +22,30 @@ func HandleGetBlockLocator(incomingRoute *router.Route, outgoingRoute *router.Ro "locator between blocks %s and %s", lowHash, highHash) } - shouldStop = sendBlockLocator(outgoingRoute, locator) - if shouldStop { - return nil + err = sendBlockLocator(outgoingRoute, locator) + if err != nil { + return err } } } func receiveGetBlockLocator(incomingRoute *router.Route) (lowHash *daghash.Hash, - highHash *daghash.Hash, shouldStop bool, err error) { + highHash *daghash.Hash, err error) { - message, isOpen := incomingRoute.Dequeue() - if !isOpen { - return nil, nil, true, nil + message, err := incomingRoute.Dequeue() + if err != nil { + return nil, nil, err } msgGetBlockLocator := message.(*wire.MsgGetBlockLocator) - return msgGetBlockLocator.LowHash, msgGetBlockLocator.HighHash, false, nil + return msgGetBlockLocator.LowHash, msgGetBlockLocator.HighHash, nil } -func sendBlockLocator(outgoingRoute *router.Route, locator blockdag.BlockLocator) (shouldStop bool) { +func sendBlockLocator(outgoingRoute *router.Route, locator blockdag.BlockLocator) error { msgBlockLocator := wire.NewMsgBlockLocator(locator) - isOpen := outgoingRoute.Enqueue(msgBlockLocator) - return !isOpen + err := outgoingRoute.Enqueue(msgBlockLocator) + if err != nil { + return err + } + return nil } diff --git a/protocol/flows/ibd/handle_get_blocks.go b/protocol/flows/ibd/handle_get_blocks.go index 95da52cc7..144a1b187 100644 --- a/protocol/flows/ibd/handle_get_blocks.go +++ b/protocol/flows/ibd/handle_get_blocks.go @@ -10,36 +10,33 @@ import ( // HandleGetBlocks handles getBlocks messages func HandleGetBlocks(incomingRoute *router.Route, outgoingRoute *router.Route, dag *blockdag.BlockDAG) error { for { - lowHash, highHash, shouldStop, err := receiveGetBlocks(incomingRoute) + lowHash, highHash, err := receiveGetBlocks(incomingRoute) if err != nil { return err } - if shouldStop { - return nil - } msgIBDBlocks, err := buildMsgIBDBlocks(lowHash, highHash, dag) if err != nil { return err } - shouldStop = sendMsgIBDBlocks(outgoingRoute, msgIBDBlocks) - if shouldStop { + err = sendMsgIBDBlocks(outgoingRoute, msgIBDBlocks) + if err != nil { return nil } } } func receiveGetBlocks(incomingRoute *router.Route) (lowHash *daghash.Hash, - highHash *daghash.Hash, shouldStop bool, err error) { + highHash *daghash.Hash, err error) { - message, isOpen := incomingRoute.Dequeue() - if !isOpen { - return nil, nil, true, nil + message, err := incomingRoute.Dequeue() + if err != nil { + return nil, nil, err } msgGetBlocks := message.(*wire.MsgGetBlocks) - return msgGetBlocks.LowHash, msgGetBlocks.HighHash, false, nil + return msgGetBlocks.LowHash, msgGetBlocks.HighHash, nil } func buildMsgIBDBlocks(lowHash *daghash.Hash, highHash *daghash.Hash, @@ -63,12 +60,12 @@ func buildMsgIBDBlocks(lowHash *daghash.Hash, highHash *daghash.Hash, return msgIBDBlocks, nil } -func sendMsgIBDBlocks(outgoingRoute *router.Route, msgIBDBlocks []*wire.MsgIBDBlock) (shouldStop bool) { +func sendMsgIBDBlocks(outgoingRoute *router.Route, msgIBDBlocks []*wire.MsgIBDBlock) error { for _, msgIBDBlock := range msgIBDBlocks { - isOpen := outgoingRoute.Enqueue(msgIBDBlock) - if !isOpen { - return true + err := outgoingRoute.Enqueue(msgIBDBlock) + if err != nil { + return err } } - return false + return nil } diff --git a/protocol/flows/ibd/ibd.go b/protocol/flows/ibd/ibd.go index a912821a5..76de12d6b 100644 --- a/protocol/flows/ibd/ibd.go +++ b/protocol/flows/ibd/ibd.go @@ -64,62 +64,49 @@ func HandleIBD(incomingRoute *router.Route, outgoingRoute *router.Route, peer *peerpkg.Peer, dag *blockdag.BlockDAG, newBlockHandler NewBlockHandler) error { for { - shouldStop, err := runIBD(incomingRoute, outgoingRoute, peer, dag, newBlockHandler) + err := runIBD(incomingRoute, outgoingRoute, peer, dag, newBlockHandler) if err != nil { return err } - if shouldStop { - return nil - } } } func runIBD(incomingRoute *router.Route, outgoingRoute *router.Route, - peer *peerpkg.Peer, dag *blockdag.BlockDAG, newBlockHandler NewBlockHandler) (shouldStop bool, err error) { + peer *peerpkg.Peer, dag *blockdag.BlockDAG, newBlockHandler NewBlockHandler) error { peer.WaitForIBDStart() defer finishIBD(dag) peerSelectedTipHash := peer.SelectedTipHash() - highestSharedBlockHash, shouldStop, err := findHighestSharedBlockHash(incomingRoute, outgoingRoute, dag, peerSelectedTipHash) + highestSharedBlockHash, err := findHighestSharedBlockHash(incomingRoute, outgoingRoute, dag, peerSelectedTipHash) if err != nil { - return false, err - } - if shouldStop { - return true, nil + return err } if dag.IsKnownFinalizedBlock(highestSharedBlockHash) { - return false, protocolerrors.Errorf(false, "cannot initiate "+ + return protocolerrors.Errorf(false, "cannot initiate "+ "IBD with peer %s because the highest shared chain block (%s) is "+ "below the finality point", peer, highestSharedBlockHash) } - shouldStop, err = downloadBlocks(incomingRoute, outgoingRoute, dag, highestSharedBlockHash, peerSelectedTipHash, + return downloadBlocks(incomingRoute, outgoingRoute, dag, highestSharedBlockHash, peerSelectedTipHash, newBlockHandler) - if err != nil { - return false, err - } - return shouldStop, nil } func findHighestSharedBlockHash(incomingRoute *router.Route, outgoingRoute *router.Route, dag *blockdag.BlockDAG, - peerSelectedTipHash *daghash.Hash) (lowHash *daghash.Hash, shouldStop bool, err error) { + peerSelectedTipHash *daghash.Hash) (lowHash *daghash.Hash, err error) { lowHash = dag.Params.GenesisHash highHash := peerSelectedTipHash for { - shouldStop = sendGetBlockLocator(outgoingRoute, lowHash, highHash) - if shouldStop { - return nil, true, nil + err := sendGetBlockLocator(outgoingRoute, lowHash, highHash) + if err != nil { + return nil, err } - blockLocatorHashes, shouldStop, err := receiveBlockLocator(incomingRoute) + blockLocatorHashes, err := receiveBlockLocator(incomingRoute) if err != nil { - return nil, false, err - } - if shouldStop { - return nil, true, nil + return nil, err } // We check whether the locator's highest hash is in the local DAG. @@ -127,7 +114,7 @@ func findHighestSharedBlockHash(incomingRoute *router.Route, outgoingRoute *rout // getBlockLocator request and try again. locatorHighHash := blockLocatorHashes[0] if dag.IsInDAG(locatorHighHash) { - return locatorHighHash, false, nil + return locatorHighHash, nil } highHash, lowHash = dag.FindNextLocatorBoundaries(blockLocatorHashes) @@ -135,111 +122,95 @@ func findHighestSharedBlockHash(incomingRoute *router.Route, outgoingRoute *rout } func sendGetBlockLocator(outgoingRoute *router.Route, lowHash *daghash.Hash, - highHash *daghash.Hash) (shouldStop bool) { + highHash *daghash.Hash) error { msgGetBlockLocator := wire.NewMsgGetBlockLocator(highHash, lowHash) - isOpen := outgoingRoute.Enqueue(msgGetBlockLocator) - return !isOpen + return outgoingRoute.Enqueue(msgGetBlockLocator) } -func receiveBlockLocator(incomingRoute *router.Route) (blockLocatorHashes []*daghash.Hash, - shouldStop bool, err error) { - - message, isOpen, err := incomingRoute.DequeueWithTimeout(common.DefaultTimeout) +func receiveBlockLocator(incomingRoute *router.Route) (blockLocatorHashes []*daghash.Hash, err error) { + message, err := incomingRoute.DequeueWithTimeout(common.DefaultTimeout) if err != nil { - return nil, false, err - } - if !isOpen { - return nil, true, nil + return nil, err } msgBlockLocator, ok := message.(*wire.MsgBlockLocator) if !ok { - return nil, false, + return nil, protocolerrors.Errorf(true, "received unexpected message type. "+ "expected: %s, got: %s", wire.CmdBlockLocator, message.Command()) } - return msgBlockLocator.BlockLocatorHashes, false, nil + return msgBlockLocator.BlockLocatorHashes, nil } func downloadBlocks(incomingRoute *router.Route, outgoingRoute *router.Route, dag *blockdag.BlockDAG, highestSharedBlockHash *daghash.Hash, - peerSelectedTipHash *daghash.Hash, newBlockHandler NewBlockHandler) (shouldStop bool, err error) { + peerSelectedTipHash *daghash.Hash, newBlockHandler NewBlockHandler) error { - shouldStop = sendGetBlocks(outgoingRoute, highestSharedBlockHash, peerSelectedTipHash) - if shouldStop { - return true, nil + err := sendGetBlocks(outgoingRoute, highestSharedBlockHash, peerSelectedTipHash) + if err != nil { + return err } for { - msgIBDBlock, shouldStop, err := receiveIBDBlock(incomingRoute) + msgIBDBlock, err := receiveIBDBlock(incomingRoute) if err != nil { - return false, err + return err } - if shouldStop { - return true, nil - } - shouldStop, err = processIBDBlock(dag, msgIBDBlock, newBlockHandler) + err = processIBDBlock(dag, msgIBDBlock, newBlockHandler) if err != nil { - return false, err - } - if shouldStop { - return true, nil + return err } if msgIBDBlock.BlockHash().IsEqual(peerSelectedTipHash) { - return true, nil + return nil } } } func sendGetBlocks(outgoingRoute *router.Route, highestSharedBlockHash *daghash.Hash, - peerSelectedTipHash *daghash.Hash) (shouldStop bool) { + peerSelectedTipHash *daghash.Hash) error { msgGetBlockInvs := wire.NewMsgGetBlocks(highestSharedBlockHash, peerSelectedTipHash) - isOpen := outgoingRoute.Enqueue(msgGetBlockInvs) - return !isOpen + return outgoingRoute.Enqueue(msgGetBlockInvs) } -func receiveIBDBlock(incomingRoute *router.Route) (msgIBDBlock *wire.MsgIBDBlock, shouldStop bool, err error) { - message, isOpen, err := incomingRoute.DequeueWithTimeout(common.DefaultTimeout) +func receiveIBDBlock(incomingRoute *router.Route) (msgIBDBlock *wire.MsgIBDBlock, err error) { + message, err := incomingRoute.DequeueWithTimeout(common.DefaultTimeout) if err != nil { - return nil, false, err - } - if !isOpen { - return nil, true, nil + return nil, err } msgIBDBlock, ok := message.(*wire.MsgIBDBlock) if !ok { - return nil, false, + return nil, protocolerrors.Errorf(true, "received unexpected message type. "+ "expected: %s, got: %s", wire.CmdIBDBlock, message.Command()) } - return msgIBDBlock, false, nil + return msgIBDBlock, nil } func processIBDBlock(dag *blockdag.BlockDAG, msgIBDBlock *wire.MsgIBDBlock, - newBlockHandler NewBlockHandler) (shouldStop bool, err error) { + newBlockHandler NewBlockHandler) error { block := util.NewBlock(&msgIBDBlock.MsgBlock) if dag.IsInDAG(block.Hash()) { - return false, nil + return nil } isOrphan, isDelayed, err := dag.ProcessBlock(block, blockdag.BFNone) if err != nil { - return false, err + return err } if isOrphan { - return false, protocolerrors.Errorf(true, "received orphan block %s "+ + return protocolerrors.Errorf(true, "received orphan block %s "+ "during IBD", block.Hash()) } if isDelayed { - return false, protocolerrors.Errorf(false, "received delayed block %s "+ + return protocolerrors.Errorf(false, "received delayed block %s "+ "during IBD", block.Hash()) } err = newBlockHandler(block) if err != nil { panic(err) } - return false, nil + return nil } func finishIBD(dag *blockdag.BlockDAG) { diff --git a/protocol/flows/ibd/selected_tip.go b/protocol/flows/ibd/selected_tip.go index 363f4f9dc..3f2614074 100644 --- a/protocol/flows/ibd/selected_tip.go +++ b/protocol/flows/ibd/selected_tip.go @@ -34,82 +34,69 @@ func requestSelectedTips() { func RequestSelectedTip(incomingRoute *router.Route, outgoingRoute *router.Route, peer *peerpkg.Peer, dag *blockdag.BlockDAG) error { for { - shouldStop, err := runSelectedTipRequest(incomingRoute, outgoingRoute, peer, dag) + err := runSelectedTipRequest(incomingRoute, outgoingRoute, peer, dag) if err != nil { return err } - if shouldStop { - return nil - } } } func runSelectedTipRequest(incomingRoute *router.Route, outgoingRoute *router.Route, - peer *peerpkg.Peer, dag *blockdag.BlockDAG) (shouldStop bool, err error) { + peer *peerpkg.Peer, dag *blockdag.BlockDAG) error { peer.WaitForSelectedTipRequests() defer peer.FinishRequestingSelectedTip() - shouldStop = requestSelectedTip(outgoingRoute) - if shouldStop { - return true, nil + err := requestSelectedTip(outgoingRoute) + if err != nil { + return err } - peerSelectedTipHash, shouldStop, err := receiveSelectedTip(incomingRoute) + peerSelectedTipHash, err := receiveSelectedTip(incomingRoute) if err != nil { - return false, err - } - if shouldStop { - return true, nil + return err } peer.SetSelectedTipHash(peerSelectedTipHash) StartIBDIfRequired(dag) - return false, nil + return nil } -func requestSelectedTip(outgoingRoute *router.Route) (shouldStop bool) { +func requestSelectedTip(outgoingRoute *router.Route) error { msgGetSelectedTip := wire.NewMsgGetSelectedTip() - isOpen := outgoingRoute.Enqueue(msgGetSelectedTip) - return !isOpen + return outgoingRoute.Enqueue(msgGetSelectedTip) } -func receiveSelectedTip(incomingRoute *router.Route) (selectedTipHash *daghash.Hash, shouldStop bool, err error) { - message, isOpen, err := incomingRoute.DequeueWithTimeout(common.DefaultTimeout) +func receiveSelectedTip(incomingRoute *router.Route) (selectedTipHash *daghash.Hash, err error) { + message, err := incomingRoute.DequeueWithTimeout(common.DefaultTimeout) if err != nil { - return nil, false, err - } - if !isOpen { - return nil, true, nil + return nil, err } msgSelectedTip := message.(*wire.MsgSelectedTip) - return msgSelectedTip.SelectedTipHash, false, nil + return msgSelectedTip.SelectedTipHash, nil } // HandleGetSelectedTip handles getSelectedTip messages func HandleGetSelectedTip(incomingRoute *router.Route, outgoingRoute *router.Route, dag *blockdag.BlockDAG) error { for { - shouldStop, err := receiveGetSelectedTip(incomingRoute) + err := receiveGetSelectedTip(incomingRoute) if err != nil { return err } - if shouldStop { - return nil - } selectedTipHash := dag.SelectedTipHash() - shouldStop = sendSelectedTipHash(outgoingRoute, selectedTipHash) - if shouldStop { - return nil + err = sendSelectedTipHash(outgoingRoute, selectedTipHash) + if err != nil { + return err } } } -func receiveGetSelectedTip(incomingRoute *router.Route) (shouldStop bool, err error) { - message, isOpen := incomingRoute.Dequeue() - if !isOpen { - return true, nil +func receiveGetSelectedTip(incomingRoute *router.Route) error { + message, err := incomingRoute.Dequeue() + if err != nil { + return err } _, ok := message.(*wire.MsgGetSelectedTip) if !ok { @@ -117,11 +104,10 @@ func receiveGetSelectedTip(incomingRoute *router.Route) (shouldStop bool, err er "expected: %s, got: %s", wire.CmdGetSelectedTip, message.Command())) } - return false, nil + return nil } -func sendSelectedTipHash(outgoingRoute *router.Route, selectedTipHash *daghash.Hash) (shouldStop bool) { +func sendSelectedTipHash(outgoingRoute *router.Route, selectedTipHash *daghash.Hash) error { msgSelectedTip := wire.NewMsgSelectedTip(selectedTipHash) - isOpen := outgoingRoute.Enqueue(msgSelectedTip) - return !isOpen + return outgoingRoute.Enqueue(msgSelectedTip) } diff --git a/protocol/flows/ping/ping.go b/protocol/flows/ping/ping.go index 82cd61cbe..dffa0bda4 100644 --- a/protocol/flows/ping/ping.go +++ b/protocol/flows/ping/ping.go @@ -15,16 +15,16 @@ import ( // This function assumes that incomingRoute will only return MsgPing. func ReceivePings(incomingRoute *router.Route, outgoingRoute *router.Route) error { for { - message, isOpen := incomingRoute.Dequeue() - if !isOpen { - return nil + message, err := incomingRoute.Dequeue() + if err != nil { + return err } pingMessage := message.(*wire.MsgPing) pongMessage := wire.NewMsgPong(pingMessage.Nonce) - isOpen = outgoingRoute.Enqueue(pongMessage) - if !isOpen { - return nil + err = outgoingRoute.Enqueue(pongMessage) + if err != nil { + return err } } } @@ -45,17 +45,14 @@ func SendPings(incomingRoute *router.Route, outgoingRoute *router.Route, peer *p peer.SetPingPending(nonce) pingMessage := wire.NewMsgPing(nonce) - isOpen := outgoingRoute.Enqueue(pingMessage) - if !isOpen { - return nil - } - - message, isOpen, err := incomingRoute.DequeueWithTimeout(common.DefaultTimeout) + err = outgoingRoute.Enqueue(pingMessage) if err != nil { return err } - if !isOpen { - return nil + + message, err := incomingRoute.DequeueWithTimeout(common.DefaultTimeout) + if err != nil { + return err } pongMessage := message.(*wire.MsgPong) if pongMessage.Nonce != pingMessage.Nonce { diff --git a/protocol/flows/relaytransactions/relaytransactions.go b/protocol/flows/relaytransactions/relaytransactions.go index 20a391513..eefaebf4e 100644 --- a/protocol/flows/relaytransactions/relaytransactions.go +++ b/protocol/flows/relaytransactions/relaytransactions.go @@ -26,36 +26,27 @@ func HandleRelayedTransactions(incomingRoute *router.Route, outgoingRoute *route invsQueue := make([]*wire.MsgInvTransaction, 0) for { - inv, shouldStop, err := readInv(incomingRoute, &invsQueue) + inv, err := readInv(incomingRoute, &invsQueue) if err != nil { return err } - if shouldStop { - return nil - } - requestedIDs, shouldStop, err := requestInvTransactions(outgoingRoute, txPool, dag, sharedRequestedTransactions, inv) + requestedIDs, err := requestInvTransactions(outgoingRoute, txPool, dag, sharedRequestedTransactions, inv) if err != nil { return err } - if shouldStop { - return nil - } - shouldStop, err = receiveTransactions(requestedIDs, incomingRoute, &invsQueue, txPool, netAdapter, + err = receiveTransactions(requestedIDs, incomingRoute, &invsQueue, txPool, netAdapter, sharedRequestedTransactions) if err != nil { return err } - if shouldStop { - return nil - } } } func requestInvTransactions(outgoingRoute *router.Route, txPool *mempool.TxPool, dag *blockdag.BlockDAG, sharedRequestedTransactions *SharedRequestedTransactions, inv *wire.MsgInvTransaction) (requestedIDs []*daghash.TxID, - shouldStop bool, err error) { + err error) { idsToRequest := make([]*daghash.TxID, 0, len(inv.TxIDS)) for _, txID := range inv.TxIDS { @@ -70,16 +61,16 @@ func requestInvTransactions(outgoingRoute *router.Route, txPool *mempool.TxPool, } if len(idsToRequest) == 0 { - return idsToRequest, false, nil + return idsToRequest, nil } msgGetTransactions := wire.NewMsgGetTransactions(idsToRequest) - isOpen := outgoingRoute.Enqueue(msgGetTransactions) - if !isOpen { + err = outgoingRoute.Enqueue(msgGetTransactions) + if err != nil { sharedRequestedTransactions.removeMany(idsToRequest) - return nil, true, nil + return nil, err } - return idsToRequest, false, nil + return idsToRequest, nil } func isKnownTransaction(txPool *mempool.TxPool, dag *blockdag.BlockDAG, txID *daghash.TxID) bool { @@ -108,25 +99,25 @@ func isKnownTransaction(txPool *mempool.TxPool, dag *blockdag.BlockDAG, txID *da return false } -func readInv(incomingRoute *router.Route, invsQueue *[]*wire.MsgInvTransaction) ( - inv *wire.MsgInvTransaction, shouldStop bool, err error) { +func readInv(incomingRoute *router.Route, invsQueue *[]*wire.MsgInvTransaction) (*wire.MsgInvTransaction, error) { if len(*invsQueue) > 0 { + var inv *wire.MsgInvTransaction inv, *invsQueue = (*invsQueue)[0], (*invsQueue)[1:] - return inv, false, nil + return inv, nil } - msg, isOpen := incomingRoute.Dequeue() - if !isOpen { - return nil, true, nil + msg, err := incomingRoute.Dequeue() + if err != nil { + return nil, err } inv, ok := msg.(*wire.MsgInvTransaction) if !ok { - return nil, false, protocolerrors.Errorf(true, "unexpected %s message in the block relay flow while "+ + return nil, protocolerrors.Errorf(true, "unexpected %s message in the block relay flow while "+ "expecting an inv message", msg.Command()) } - return inv, false, nil + return inv, nil } func broadcastAcceptedTransactions(netAdapter *netadapter.NetAdapter, acceptedTxs []*mempool.TxDesc) error { @@ -144,22 +135,19 @@ func broadcastAcceptedTransactions(netAdapter *netadapter.NetAdapter, acceptedTx // // Note: this function assumes msgChan can contain only wire.MsgInvTransaction and wire.MsgBlock messages. func readMsgTx(incomingRoute *router.Route, invsQueue *[]*wire.MsgInvTransaction) ( - msgTx *wire.MsgTx, shouldStop bool, err error) { + msgTx *wire.MsgTx, err error) { for { - message, isOpen, err := incomingRoute.DequeueWithTimeout(common.DefaultTimeout) + message, err := incomingRoute.DequeueWithTimeout(common.DefaultTimeout) if err != nil { - return nil, false, err - } - if !isOpen { - return nil, true, nil + return nil, err } switch message := message.(type) { case *wire.MsgInvTransaction: *invsQueue = append(*invsQueue, message) case *wire.MsgTx: - return message, false, nil + return message, nil default: panic(errors.Errorf("unexpected message %s", message.Command())) } @@ -168,22 +156,19 @@ func readMsgTx(incomingRoute *router.Route, invsQueue *[]*wire.MsgInvTransaction func receiveTransactions(requestedTransactions []*daghash.TxID, incomingRoute *router.Route, invsQueue *[]*wire.MsgInvTransaction, txPool *mempool.TxPool, netAdapter *netadapter.NetAdapter, - sharedRequestedTransactions *SharedRequestedTransactions) (shouldStop bool, err error) { + sharedRequestedTransactions *SharedRequestedTransactions) error { // In case the function returns earlier than expected, we want to make sure sharedRequestedTransactions is // clean from any pending transactions. defer sharedRequestedTransactions.removeMany(requestedTransactions) for _, expectedID := range requestedTransactions { - msgTx, shouldStop, err := readMsgTx(incomingRoute, invsQueue) + msgTx, err := readMsgTx(incomingRoute, invsQueue) if err != nil { - return false, err - } - if shouldStop { - return true, nil + return err } tx := util.NewTx(msgTx) if !tx.ID().IsEqual(expectedID) { - return false, protocolerrors.Errorf(true, "expected transaction %s", expectedID) + return protocolerrors.Errorf(true, "expected transaction %s", expectedID) } acceptedTxs, err := txPool.ProcessTransaction(tx, true, 0) // TODO(libp2p) Use the peer ID for the mempool tag @@ -210,7 +195,7 @@ func receiveTransactions(requestedTransactions []*daghash.TxID, incomingRoute *r continue } - return false, protocolerrors.Errorf(true, "rejected transaction %s", tx.ID()) + return protocolerrors.Errorf(true, "rejected transaction %s", tx.ID()) } err = broadcastAcceptedTransactions(netAdapter, acceptedTxs) if err != nil { @@ -218,5 +203,5 @@ func receiveTransactions(requestedTransactions []*daghash.TxID, incomingRoute *r } // TODO(libp2p) Notify transactionsAcceptedToMempool to RPC } - return false, nil + return nil } diff --git a/protocol/protocol.go b/protocol/protocol.go index 1527b8d3e..1f6e8da7c 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -43,6 +43,9 @@ func (m *Manager) routerInitializer() (*routerpkg.Router, error) { } return } + if errors.Is(err, routerpkg.ErrRouteClosed) { + return + } panic(err) } }) @@ -78,13 +81,13 @@ func (m *Manager) addAddressFlows(router *routerpkg.Router, stopped *uint32, sto outgoingRoute := router.OutgoingRoute() addOneTimeFlow("SendAddresses", router, []wire.MessageCommand{wire.CmdGetAddresses}, stopped, stop, - func(incomingRoute *routerpkg.Route) (routeClosed bool, err error) { + func(incomingRoute *routerpkg.Route) error { return addressexchange.SendAddresses(incomingRoute, outgoingRoute, m.addressManager) }, ) addOneTimeFlow("ReceiveAddresses", router, []wire.MessageCommand{wire.CmdAddress}, stopped, stop, - func(incomingRoute *routerpkg.Route) (routeClosed bool, err error) { + func(incomingRoute *routerpkg.Route) error { return addressexchange.ReceiveAddresses(incomingRoute, outgoingRoute, m.cfg, peer, m.addressManager) }, ) @@ -192,7 +195,7 @@ func addFlow(name string, router *routerpkg.Router, messageTypes []wire.MessageC } func addOneTimeFlow(name string, router *routerpkg.Router, messageTypes []wire.MessageCommand, stopped *uint32, - stopChan chan error, flow func(route *routerpkg.Route) (routeClosed bool, err error)) { + stopChan chan error, flow func(route *routerpkg.Route) error) { route, err := router.AddIncomingRoute(messageTypes) if err != nil { @@ -207,11 +210,11 @@ func addOneTimeFlow(name string, router *routerpkg.Router, messageTypes []wire.M } }() - closed, err := flow(route) + err := flow(route) if err != nil { log.Errorf("error from %s flow: %s", name, err) } - if (err != nil || closed) && atomic.AddUint32(stopped, 1) == 1 { + if err != nil && atomic.AddUint32(stopped, 1) == 1 { stopChan <- err } })