mirror of
https://github.com/kaspanet/kaspad.git
synced 2025-06-10 16:16:47 +00:00
[NOD-1153] Remove redundant maps from NetAdapter (#817)
* [NOD-1153] Remove redundant maps from NetAdapter. * [NOD-1153] Fix a comment. * [NOD-1153] Fix a comment.
This commit is contained in:
parent
428f16ffef
commit
5d5a0ef335
@ -31,9 +31,7 @@ type NetAdapter struct {
|
||||
routerInitializer RouterInitializer
|
||||
stop uint32
|
||||
|
||||
routersToConnections map[*routerpkg.Router]*NetConnection
|
||||
connectionsToIDs map[*NetConnection]*id.ID
|
||||
idsToRouters map[*id.ID]*routerpkg.Router
|
||||
connectionsToRouters map[*NetConnection]*routerpkg.Router
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
@ -53,9 +51,7 @@ func NewNetAdapter(cfg *config.Config) (*NetAdapter, error) {
|
||||
id: netAdapterID,
|
||||
server: s,
|
||||
|
||||
routersToConnections: make(map[*routerpkg.Router]*NetConnection),
|
||||
connectionsToIDs: make(map[*NetConnection]*id.ID),
|
||||
idsToRouters: make(map[*id.ID]*routerpkg.Router),
|
||||
connectionsToRouters: make(map[*NetConnection]*routerpkg.Router),
|
||||
}
|
||||
|
||||
adapter.server.SetOnConnectedHandler(adapter.onConnectedHandler)
|
||||
@ -90,9 +86,9 @@ func (na *NetAdapter) Connect(address string) error {
|
||||
|
||||
// Connections returns a list of connections currently connected and active
|
||||
func (na *NetAdapter) Connections() []*NetConnection {
|
||||
netConnections := make([]*NetConnection, 0, len(na.connectionsToIDs))
|
||||
netConnections := make([]*NetConnection, 0, len(na.connectionsToRouters))
|
||||
|
||||
for netConnection := range na.connectionsToIDs {
|
||||
for netConnection := range na.connectionsToRouters {
|
||||
netConnections = append(netConnections, netConnection)
|
||||
}
|
||||
|
||||
@ -101,7 +97,7 @@ func (na *NetAdapter) Connections() []*NetConnection {
|
||||
|
||||
// ConnectionCount returns the count of the connected connections
|
||||
func (na *NetAdapter) ConnectionCount() int {
|
||||
return len(na.connectionsToIDs)
|
||||
return len(na.connectionsToRouters)
|
||||
}
|
||||
|
||||
func (na *NetAdapter) onConnectedHandler(connection server.Connection) error {
|
||||
@ -112,9 +108,7 @@ func (na *NetAdapter) onConnectedHandler(connection server.Connection) error {
|
||||
}
|
||||
connection.Start(router)
|
||||
|
||||
na.routersToConnections[router] = netConnection
|
||||
|
||||
na.connectionsToIDs[netConnection] = nil
|
||||
na.connectionsToRouters[netConnection] = router
|
||||
|
||||
router.SetOnRouteCapacityReachedHandler(func() {
|
||||
err := connection.Disconnect()
|
||||
@ -126,38 +120,12 @@ func (na *NetAdapter) onConnectedHandler(connection server.Connection) error {
|
||||
}
|
||||
})
|
||||
connection.SetOnDisconnectedHandler(func() error {
|
||||
na.cleanupConnection(netConnection, router)
|
||||
delete(na.connectionsToRouters, netConnection)
|
||||
return router.Close()
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// AssociateRouterID associates the connection for the given router
|
||||
// with the given ID
|
||||
func (na *NetAdapter) AssociateRouterID(router *routerpkg.Router, id *id.ID) error {
|
||||
netConnection, ok := na.routersToConnections[router]
|
||||
if !ok {
|
||||
return errors.Errorf("router not registered for id %s", id)
|
||||
}
|
||||
|
||||
netConnection.id = id
|
||||
|
||||
na.connectionsToIDs[netConnection] = id
|
||||
na.idsToRouters[id] = router
|
||||
return nil
|
||||
}
|
||||
|
||||
func (na *NetAdapter) cleanupConnection(netConnection *NetConnection, router *routerpkg.Router) {
|
||||
connectionID, ok := na.connectionsToIDs[netConnection]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
delete(na.routersToConnections, router)
|
||||
delete(na.connectionsToIDs, netConnection)
|
||||
delete(na.idsToRouters, connectionID)
|
||||
}
|
||||
|
||||
// SetRouterInitializer sets the routerInitializer function
|
||||
// for the net adapter
|
||||
func (na *NetAdapter) SetRouterInitializer(routerInitializer RouterInitializer) {
|
||||
@ -170,21 +138,16 @@ func (na *NetAdapter) ID() *id.ID {
|
||||
}
|
||||
|
||||
// Broadcast sends the given `message` to every peer corresponding
|
||||
// to each ID in `ids`
|
||||
func (na *NetAdapter) Broadcast(connectionIDs []*id.ID, message wire.Message) error {
|
||||
// to each NetConnection in the given netConnections
|
||||
func (na *NetAdapter) Broadcast(netConnections []*NetConnection, message wire.Message) error {
|
||||
na.RLock()
|
||||
defer na.RUnlock()
|
||||
for _, connectionID := range connectionIDs {
|
||||
router, ok := na.idsToRouters[connectionID]
|
||||
if !ok {
|
||||
log.Warnf("connectionID %s is not registered", connectionID)
|
||||
continue
|
||||
}
|
||||
for _, netConnection := range netConnections {
|
||||
router := na.connectionsToRouters[netConnection]
|
||||
err := router.EnqueueIncomingMessage(message)
|
||||
if err != nil {
|
||||
if errors.Is(err, routerpkg.ErrRouteClosed) {
|
||||
connection := na.routersToConnections[router]
|
||||
log.Debugf("Cannot enqueue message to %s: router is closed", connection)
|
||||
log.Debugf("Cannot enqueue message to %s: router is closed", netConnection)
|
||||
continue
|
||||
}
|
||||
return err
|
||||
|
@ -3,7 +3,6 @@ package flowcontext
|
||||
import (
|
||||
"github.com/kaspanet/kaspad/connmanager"
|
||||
"github.com/kaspanet/kaspad/netadapter"
|
||||
"github.com/kaspanet/kaspad/netadapter/id"
|
||||
"github.com/kaspanet/kaspad/protocol/common"
|
||||
peerpkg "github.com/kaspanet/kaspad/protocol/peer"
|
||||
"github.com/kaspanet/kaspad/wire"
|
||||
@ -33,22 +32,22 @@ func (f *FlowContext) AddToPeers(peer *peerpkg.Peer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// readyPeerIDs returns the peer IDs of all the ready peers.
|
||||
func (f *FlowContext) readyPeerIDs() []*id.ID {
|
||||
// readyPeerConnections returns the NetConnections of all the ready peers.
|
||||
func (f *FlowContext) readyPeerConnections() []*netadapter.NetConnection {
|
||||
f.peersMutex.RLock()
|
||||
defer f.peersMutex.RUnlock()
|
||||
peerIDs := make([]*id.ID, len(f.peers))
|
||||
peerConnections := make([]*netadapter.NetConnection, len(f.peers))
|
||||
i := 0
|
||||
for peerID := range f.peers {
|
||||
peerIDs[i] = peerID
|
||||
for _, peer := range f.peers {
|
||||
peerConnections[i] = peer.Connection()
|
||||
i++
|
||||
}
|
||||
return peerIDs
|
||||
return peerConnections
|
||||
}
|
||||
|
||||
// Broadcast broadcast the given message to all the ready peers.
|
||||
func (f *FlowContext) Broadcast(message wire.Message) error {
|
||||
return f.netAdapter.Broadcast(f.readyPeerIDs(), message)
|
||||
return f.netAdapter.Broadcast(f.readyPeerConnections(), message)
|
||||
}
|
||||
|
||||
// Peers returns the currently active peers
|
||||
|
@ -99,12 +99,6 @@ func HandleHandshake(context HandleHandshakeContext, router *routerpkg.Router,
|
||||
panic(err)
|
||||
}
|
||||
|
||||
peerID := peer.ID()
|
||||
err = context.NetAdapter().AssociateRouterID(router, peerID)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if peerAddress != nil {
|
||||
subnetworkID := peer.SubnetworkID()
|
||||
context.AddressManager().AddAddress(peerAddress, peerAddress, subnetworkID)
|
||||
|
@ -21,7 +21,6 @@ type Peer struct {
|
||||
selectedTipHashMtx sync.RWMutex
|
||||
selectedTipHash *daghash.Hash
|
||||
|
||||
id *id.ID
|
||||
userAgent string
|
||||
services wire.ServiceFlag
|
||||
advertisedProtocolVer uint32 // protocol version advertised by remote
|
||||
@ -50,6 +49,11 @@ func New(connection *netadapter.NetConnection) *Peer {
|
||||
}
|
||||
}
|
||||
|
||||
// Connection returns the NetConnection associated with this peer
|
||||
func (p *Peer) Connection() *netadapter.NetConnection {
|
||||
return p.connection
|
||||
}
|
||||
|
||||
// SelectedTipHash returns the selected tip of the peer.
|
||||
func (p *Peer) SelectedTipHash() *daghash.Hash {
|
||||
p.selectedTipHashMtx.RLock()
|
||||
@ -72,7 +76,7 @@ func (p *Peer) SubnetworkID() *subnetworkid.SubnetworkID {
|
||||
|
||||
// ID returns the peer ID.
|
||||
func (p *Peer) ID() *id.ID {
|
||||
return p.id
|
||||
return p.connection.ID()
|
||||
}
|
||||
|
||||
// UpdateFieldsFromMsgVersion updates the peer with the data from the version message.
|
||||
@ -81,10 +85,7 @@ func (p *Peer) UpdateFieldsFromMsgVersion(msg *wire.MsgVersion) {
|
||||
p.advertisedProtocolVer = msg.ProtocolVersion
|
||||
p.protocolVersion = mathUtil.MinUint32(p.protocolVersion, p.advertisedProtocolVer)
|
||||
log.Debugf("Negotiated protocol version %d for peer %s",
|
||||
p.protocolVersion, p.id)
|
||||
|
||||
// Set the peer's ID.
|
||||
p.id = msg.ID
|
||||
p.protocolVersion, p.ID())
|
||||
|
||||
// Set the supported services for the peer to what the remote peer
|
||||
// advertised.
|
||||
|
Loading…
x
Reference in New Issue
Block a user