diff --git a/connmanager/connection_requests.go b/connmanager/connection_requests.go index 17073a130..728ed47c3 100644 --- a/connmanager/connection_requests.go +++ b/connmanager/connection_requests.go @@ -100,6 +100,8 @@ func (c *ConnectionManager) AddConnectionRequest(address string, isPermanent boo address: address, isPermanent: isPermanent, } + + c.run() }) } diff --git a/connmanager/connmanager.go b/connmanager/connmanager.go index b0a04c909..a6481f2f9 100644 --- a/connmanager/connmanager.go +++ b/connmanager/connmanager.go @@ -35,6 +35,9 @@ type ConnectionManager struct { stop uint32 connectionRequestsLock sync.Mutex + + resetLoopChan chan struct{} + loopTicker *time.Ticker } // New instantiates a new instance of a ConnectionManager @@ -47,6 +50,8 @@ func New(cfg *config.Config, netAdapter *netadapter.NetAdapter, addressManager * pendingRequested: map[string]*connectionRequest{}, activeOutgoing: map[string]struct{}{}, activeIncoming: map[string]struct{}{}, + resetLoopChan: make(chan struct{}), + loopTicker: time.NewTicker(connectionsLoopInterval), } connectPeers := cfg.AddPeers @@ -79,6 +84,12 @@ func (c *ConnectionManager) Stop() { for _, connection := range c.netAdapter.Connections() { _ = c.netAdapter.Disconnect(connection) // Ignore errors since connection might be in the midst of disconnecting } + + c.loopTicker.Stop() +} + +func (c *ConnectionManager) run() { + c.resetLoopChan <- struct{}{} } func (c *ConnectionManager) initiateConnection(address string) error { @@ -104,7 +115,7 @@ func (c *ConnectionManager) connectionsLoop() { c.checkIncomingConnections(connSet) - <-time.Tick(connectionsLoopInterval) + c.waitTillNextIteration() } } @@ -117,3 +128,12 @@ func (c *ConnectionManager) ConnectionCount() int { func (c *ConnectionManager) Ban(netConnection *netadapter.NetConnection) { c.netAdapter.Ban(netConnection) } + +func (c *ConnectionManager) waitTillNextIteration() { + select { + case <-c.resetLoopChan: + c.loopTicker.Stop() + c.loopTicker = time.NewTicker(connectionsLoopInterval) + case <-c.loopTicker.C: + } +} diff --git a/netadapter/netadapter.go b/netadapter/netadapter.go index ab2139bb0..b074c7675 100644 --- a/netadapter/netadapter.go +++ b/netadapter/netadapter.go @@ -101,7 +101,7 @@ func (na *NetAdapter) ConnectionCount() int { } func (na *NetAdapter) onConnectedHandler(connection server.Connection) error { - netConnection := newNetConnection(connection, nil) + netConnection := newNetConnection(connection) router, err := na.routerInitializer(netConnection) if err != nil { return err @@ -144,7 +144,7 @@ func (na *NetAdapter) Broadcast(netConnections []*NetConnection, message wire.Me defer na.RUnlock() for _, netConnection := range netConnections { router := na.connectionsToRouters[netConnection] - err := router.EnqueueIncomingMessage(message) + err := router.OutgoingRoute().Enqueue(message) if err != nil { if errors.Is(err, routerpkg.ErrRouteClosed) { log.Debugf("Cannot enqueue message to %s: router is closed", netConnection) diff --git a/netadapter/netconnection.go b/netadapter/netconnection.go index 3d8ae092d..2bcfdce3f 100644 --- a/netadapter/netconnection.go +++ b/netadapter/netconnection.go @@ -13,10 +13,9 @@ type NetConnection struct { id *id.ID } -func newNetConnection(connection server.Connection, id *id.ID) *NetConnection { +func newNetConnection(connection server.Connection) *NetConnection { return &NetConnection{ connection: connection, - id: id, } } @@ -29,6 +28,11 @@ func (c *NetConnection) ID() *id.ID { return c.id } +// SetID sets the ID associated with this connection +func (c *NetConnection) SetID(peerID *id.ID) { + c.id = peerID +} + // Address returns the address associated with this connection func (c *NetConnection) Address() string { return c.connection.Address().String() diff --git a/protocol/flows/blockrelay/handle_relay_invs.go b/protocol/flows/blockrelay/handle_relay_invs.go index ff9a25dd0..664fa5c8d 100644 --- a/protocol/flows/blockrelay/handle_relay_invs.go +++ b/protocol/flows/blockrelay/handle_relay_invs.go @@ -111,7 +111,7 @@ func (flow *handleRelayInvsFlow) requestBlocks(requestQueue *hashesQueueSet) err var filteredHashesToRequest []*daghash.Hash for _, hash := range hashesToRequest { exists := flow.SharedRequestedBlocks().addIfNotExists(hash) - if !exists { + if exists { continue } diff --git a/protocol/flows/blockrelay/hashes_queue_set.go b/protocol/flows/blockrelay/hashes_queue_set.go index cfcf44fdc..60a2c43fa 100644 --- a/protocol/flows/blockrelay/hashes_queue_set.go +++ b/protocol/flows/blockrelay/hashes_queue_set.go @@ -8,7 +8,7 @@ type hashesQueueSet struct { } func (r *hashesQueueSet) enqueueIfNotExists(hash *daghash.Hash) { - if _, ok := r.set[*hash]; !ok { + if _, ok := r.set[*hash]; ok { return } r.queue = append(r.queue, hash) diff --git a/protocol/flows/handshake/handshake.go b/protocol/flows/handshake/handshake.go index a737b6874..f948dafe8 100644 --- a/protocol/flows/handshake/handshake.go +++ b/protocol/flows/handshake/handshake.go @@ -1,13 +1,14 @@ package handshake import ( + "sync" + "sync/atomic" + "github.com/kaspanet/kaspad/addrmgr" "github.com/kaspanet/kaspad/blockdag" "github.com/kaspanet/kaspad/config" "github.com/kaspanet/kaspad/netadapter" "github.com/kaspanet/kaspad/protocol/common" - "sync" - "sync/atomic" routerpkg "github.com/kaspanet/kaspad/netadapter/router" peerpkg "github.com/kaspanet/kaspad/protocol/peer" diff --git a/protocol/flows/handshake/receiveversion.go b/protocol/flows/handshake/receiveversion.go index a17e5f69e..d762552ed 100644 --- a/protocol/flows/handshake/receiveversion.go +++ b/protocol/flows/handshake/receiveversion.go @@ -90,5 +90,8 @@ func (flow *receiveVersionFlow) start() (*wire.NetAddress, error) { if err != nil { return nil, err } + + flow.peer.Connection().SetID(msgVersion.ID) + return msgVersion.Address, nil } diff --git a/wire/message.go b/wire/message.go index e5642448d..66c324ba9 100644 --- a/wire/message.go +++ b/wire/message.go @@ -182,6 +182,12 @@ func MakeEmptyMessage(command MessageCommand) (Message, error) { case CmdIBDBlock: msg = &MsgIBDBlock{} + case CmdInvRelayBlock: + msg = &MsgInvRelayBlock{} + + case CmdGetRelayBlocks: + msg = &MsgGetRelayBlocks{} + default: return nil, errors.Errorf("unhandled command [%s]", command) } diff --git a/wire/msggetrelayblocks.go b/wire/msggetrelayblocks.go index daa4c284a..9eb564c35 100644 --- a/wire/msggetrelayblocks.go +++ b/wire/msggetrelayblocks.go @@ -1,8 +1,9 @@ package wire import ( - "github.com/kaspanet/kaspad/util/daghash" "io" + + "github.com/kaspanet/kaspad/util/daghash" ) // MsgGetRelayBlocksHashes is the maximum number of hashes that can @@ -19,13 +20,38 @@ type MsgGetRelayBlocks struct { // KaspaDecode decodes r using the kaspa protocol encoding into the receiver. // This is part of the Message interface implementation. func (msg *MsgGetRelayBlocks) KaspaDecode(r io.Reader, pver uint32) error { - return ReadElement(r, &msg.Hashes) + numHashes, err := ReadVarInt(r) + if err != nil { + return err + } + + msg.Hashes = make([]*daghash.Hash, numHashes) + for i := uint64(0); i < numHashes; i++ { + msg.Hashes[i] = &daghash.Hash{} + err := ReadElement(r, msg.Hashes[i]) + if err != nil { + return err + } + } + + return nil } // KaspaEncode encodes the receiver to w using the kaspa protocol encoding. // This is part of the Message interface implementation. func (msg *MsgGetRelayBlocks) KaspaEncode(w io.Writer, pver uint32) error { - return WriteElement(w, msg.Hashes) + err := WriteVarInt(w, uint64(len(msg.Hashes))) + if err != nil { + return err + } + for _, hash := range msg.Hashes { + err := WriteElement(w, hash) + if err != nil { + return err + } + } + + return nil } // Command returns the protocol command string for the message. This is part diff --git a/wire/msginvrelayblock.go b/wire/msginvrelayblock.go index 24732f2d1..152f8d5ee 100644 --- a/wire/msginvrelayblock.go +++ b/wire/msginvrelayblock.go @@ -1,8 +1,9 @@ package wire import ( - "github.com/kaspanet/kaspad/util/daghash" "io" + + "github.com/kaspanet/kaspad/util/daghash" ) // MsgInvRelayBlock implements the Message interface and represents a kaspa @@ -15,7 +16,8 @@ type MsgInvRelayBlock struct { // KaspaDecode decodes r using the kaspa protocol encoding into the receiver. // This is part of the Message interface implementation. func (msg *MsgInvRelayBlock) KaspaDecode(r io.Reader, pver uint32) error { - return ReadElement(r, &msg.Hash) + msg.Hash = &daghash.Hash{} + return ReadElement(r, msg.Hash) } // KaspaEncode encodes the receiver to w using the kaspa protocol encoding.