diff --git a/app/rpc/rpc.go b/app/rpc/rpc.go index 6eb9f3059..7a6932491 100644 --- a/app/rpc/rpc.go +++ b/app/rpc/rpc.go @@ -40,16 +40,12 @@ func (m *Manager) routerInitializer(router *router.Router, netConnection *netada if err != nil { panic(err) } - spawn("routerInitializer-handleIncomingMessages", func() { - err := m.handleIncomingMessages(router, incomingRoute) - m.handleError(err, netConnection) - }) + m.context.NotificationManager.AddListener(router) - notificationListener := m.context.NotificationManager.AddListener(router) - spawn("routerInitializer-handleOutgoingNotifications", func() { + spawn("routerInitializer-handleIncomingMessages", func() { defer m.context.NotificationManager.RemoveListener(router) - err := m.handleOutgoingNotifications(notificationListener) + err := m.handleIncomingMessages(router, incomingRoute) m.handleError(err, netConnection) }) } @@ -76,15 +72,6 @@ func (m *Manager) handleIncomingMessages(router *router.Router, incomingRoute *r } } -func (m *Manager) handleOutgoingNotifications(notificationListener *rpccontext.NotificationListener) error { - for { - err := notificationListener.ProcessNextNotification() - if err != nil { - return err - } - } -} - func (m *Manager) handleError(err error, netConnection *netadapter.NetConnection) { if errors.Is(err, router.ErrTimeout) { log.Warnf("Got timeout from %s. Disconnecting...", netConnection) diff --git a/app/rpc/rpccontext/notificationmanager.go b/app/rpc/rpccontext/notificationmanager.go index bb24dd45a..1e5886878 100644 --- a/app/rpc/rpccontext/notificationmanager.go +++ b/app/rpc/rpccontext/notificationmanager.go @@ -2,7 +2,7 @@ package rpccontext import ( "github.com/kaspanet/kaspad/app/appmessage" - "github.com/kaspanet/kaspad/infrastructure/network/netadapter/router" + routerpkg "github.com/kaspanet/kaspad/infrastructure/network/netadapter/router" "github.com/pkg/errors" "sync" ) @@ -10,58 +10,41 @@ import ( // NotificationManager manages notifications for the RPC type NotificationManager struct { sync.RWMutex - listeners map[*router.Router]*NotificationListener + listeners map[*routerpkg.Router]*NotificationListener } -// OnBlockAddedListener is a listener function for when a block is added to the DAG -type OnBlockAddedListener func(notification *appmessage.BlockAddedNotificationMessage) error - -// OnChainChangedListener is a listener function for when the DAG's selected parent chain changes -type OnChainChangedListener func(notification *appmessage.ChainChangedNotificationMessage) error - // NotificationListener represents a registered RPC notification listener type NotificationListener struct { - onBlockAddedListener OnBlockAddedListener - onBlockAddedNotificationChan chan *appmessage.BlockAddedNotificationMessage - onChainChangedListener OnChainChangedListener - onChainChangedNotificationChan chan *appmessage.ChainChangedNotificationMessage - - closeChan chan struct{} + propagateBlockAddedNotifications bool + propagateChainChangedNotifications bool } // NewNotificationManager creates a new NotificationManager func NewNotificationManager() *NotificationManager { return &NotificationManager{ - listeners: make(map[*router.Router]*NotificationListener), + listeners: make(map[*routerpkg.Router]*NotificationListener), } } // AddListener registers a listener with the given router -func (nm *NotificationManager) AddListener(router *router.Router) *NotificationListener { +func (nm *NotificationManager) AddListener(router *routerpkg.Router) { nm.Lock() defer nm.Unlock() listener := newNotificationListener() nm.listeners[router] = listener - return listener } // RemoveListener unregisters the given router -func (nm *NotificationManager) RemoveListener(router *router.Router) { +func (nm *NotificationManager) RemoveListener(router *routerpkg.Router) { nm.Lock() defer nm.Unlock() - listener, ok := nm.listeners[router] - if !ok { - return - } - listener.close() - delete(nm.listeners, router) } // Listener retrieves the listener registered with the given router -func (nm *NotificationManager) Listener(router *router.Router) (*NotificationListener, error) { +func (nm *NotificationManager) Listener(router *routerpkg.Router) (*NotificationListener, error) { nm.RLock() defer nm.RUnlock() @@ -73,67 +56,52 @@ func (nm *NotificationManager) Listener(router *router.Router) (*NotificationLis } // NotifyBlockAdded notifies the notification manager that a block has been added to the DAG -func (nm *NotificationManager) NotifyBlockAdded(notification *appmessage.BlockAddedNotificationMessage) { +func (nm *NotificationManager) NotifyBlockAdded(notification *appmessage.BlockAddedNotificationMessage) error { nm.RLock() defer nm.RUnlock() - for _, listener := range nm.listeners { - if listener.onBlockAddedListener != nil { - select { - case listener.onBlockAddedNotificationChan <- notification: - case <-listener.closeChan: - continue + for router, listener := range nm.listeners { + if listener.propagateBlockAddedNotifications { + err := router.OutgoingRoute().Enqueue(notification) + if err != nil { + return err } } } + return nil } // NotifyChainChanged notifies the notification manager that the DAG's selected parent chain has changed -func (nm *NotificationManager) NotifyChainChanged(message *appmessage.ChainChangedNotificationMessage) { +func (nm *NotificationManager) NotifyChainChanged(notification *appmessage.ChainChangedNotificationMessage) error { nm.RLock() defer nm.RUnlock() - for _, listener := range nm.listeners { - if listener.onChainChangedListener != nil { - select { - case listener.onChainChangedNotificationChan <- message: - case <-listener.closeChan: - continue + for router, listener := range nm.listeners { + if listener.propagateChainChangedNotifications { + err := router.OutgoingRoute().Enqueue(notification) + if err != nil { + return err } } } + return nil } func newNotificationListener() *NotificationListener { return &NotificationListener{ - onBlockAddedNotificationChan: make(chan *appmessage.BlockAddedNotificationMessage), - onChainChangedNotificationChan: make(chan *appmessage.ChainChangedNotificationMessage), - closeChan: make(chan struct{}, 1), + propagateBlockAddedNotifications: false, + propagateChainChangedNotifications: false, } } -// SetOnBlockAddedListener sets the onBlockAddedListener handler for this listener -func (nl *NotificationListener) SetOnBlockAddedListener(onBlockAddedListener OnBlockAddedListener) { - nl.onBlockAddedListener = onBlockAddedListener +// PropagateBlockAddedNotifications instructs the listener to send block added notifications +// to the remote listener +func (nl *NotificationListener) PropagateBlockAddedNotifications() { + nl.propagateBlockAddedNotifications = true } -// SetOnChainChangedListener sets the onChainChangedListener handler for this listener -func (nl *NotificationListener) SetOnChainChangedListener(onChainChangedListener OnChainChangedListener) { - nl.onChainChangedListener = onChainChangedListener -} - -// ProcessNextNotification waits until a notification arrives and processes it -func (nl *NotificationListener) ProcessNextNotification() error { - select { - case block := <-nl.onBlockAddedNotificationChan: - return nl.onBlockAddedListener(block) - case notification := <-nl.onChainChangedNotificationChan: - return nl.onChainChangedListener(notification) - case <-nl.closeChan: - return nil - } -} - -func (nl *NotificationListener) close() { - nl.closeChan <- struct{}{} +// PropagateChainChangedNotifications instructs the listener to send chain changed notifications +// to the remote listener +func (nl *NotificationListener) PropagateChainChangedNotifications() { + nl.propagateChainChangedNotifications = true } diff --git a/app/rpc/rpchandlers/notify_block_added.go b/app/rpc/rpchandlers/notify_block_added.go index b86ef03b8..2929aa662 100644 --- a/app/rpc/rpchandlers/notify_block_added.go +++ b/app/rpc/rpchandlers/notify_block_added.go @@ -12,9 +12,7 @@ func HandleNotifyBlockAdded(context *rpccontext.Context, router *router.Router, if err != nil { return nil, err } - listener.SetOnBlockAddedListener(func(notification *appmessage.BlockAddedNotificationMessage) error { - return router.OutgoingRoute().Enqueue(notification) - }) + listener.PropagateBlockAddedNotifications() response := appmessage.NewNotifyBlockAddedResponseMessage() return response, nil diff --git a/app/rpc/rpchandlers/notify_chain_changed.go b/app/rpc/rpchandlers/notify_chain_changed.go index 872c9dced..a7f18c0b0 100644 --- a/app/rpc/rpchandlers/notify_chain_changed.go +++ b/app/rpc/rpchandlers/notify_chain_changed.go @@ -18,9 +18,7 @@ func HandleNotifyChainChanged(context *rpccontext.Context, router *router.Router if err != nil { return nil, err } - listener.SetOnChainChangedListener(func(message *appmessage.ChainChangedNotificationMessage) error { - return router.OutgoingRoute().Enqueue(message) - }) + listener.PropagateChainChangedNotifications() response := appmessage.NewNotifyChainChangedResponseMessage() return response, nil