diff --git a/netadapter/netadapter.go b/netadapter/netadapter.go index f1ace7277..d296eb5f7 100644 --- a/netadapter/netadapter.go +++ b/netadapter/netadapter.go @@ -31,8 +31,8 @@ type NetAdapter struct { routerInitializer RouterInitializer stop uint32 - connectionsToRouters map[*NetConnection]*routerpkg.Router - sync.RWMutex + connectionsToRouters map[*NetConnection]*routerpkg.Router + connectionsToRoutersLock sync.RWMutex } // NewNetAdapter creates and starts a new NetAdapter on the @@ -86,6 +86,9 @@ func (na *NetAdapter) Connect(address string) error { // Connections returns a list of connections currently connected and active func (na *NetAdapter) Connections() []*NetConnection { + na.connectionsToRoutersLock.RLock() + defer na.connectionsToRoutersLock.RUnlock() + netConnections := make([]*NetConnection, 0, len(na.connectionsToRouters)) for netConnection := range na.connectionsToRouters { @@ -97,6 +100,9 @@ func (na *NetAdapter) Connections() []*NetConnection { // ConnectionCount returns the count of the connected connections func (na *NetAdapter) ConnectionCount() int { + na.connectionsToRoutersLock.RLock() + defer na.connectionsToRoutersLock.RUnlock() + return len(na.connectionsToRouters) } @@ -105,6 +111,9 @@ func (na *NetAdapter) onConnectedHandler(connection server.Connection) error { router := na.routerInitializer(netConnection) connection.Start(router) + na.connectionsToRoutersLock.Lock() + defer na.connectionsToRoutersLock.Unlock() + na.connectionsToRouters[netConnection] = router router.SetOnRouteCapacityReachedHandler(func() { @@ -117,6 +126,9 @@ func (na *NetAdapter) onConnectedHandler(connection server.Connection) error { } }) connection.SetOnDisconnectedHandler(func() error { + na.connectionsToRoutersLock.Lock() + defer na.connectionsToRoutersLock.Unlock() + delete(na.connectionsToRouters, netConnection) return router.Close() }) @@ -137,10 +149,15 @@ func (na *NetAdapter) ID() *id.ID { // Broadcast sends the given `message` to every peer corresponding // to each NetConnection in the given netConnections func (na *NetAdapter) Broadcast(netConnections []*NetConnection, message wire.Message) error { - na.RLock() - defer na.RUnlock() + na.connectionsToRoutersLock.RLock() + defer na.connectionsToRoutersLock.RUnlock() + for _, netConnection := range netConnections { - router := na.connectionsToRouters[netConnection] + router, ok := na.connectionsToRouters[netConnection] + if !ok { // skip connections that were removed + continue + } + err := router.OutgoingRoute().Enqueue(message) if err != nil { if errors.Is(err, routerpkg.ErrRouteClosed) {