diff --git a/netadapter/netconnection.go b/netadapter/netconnection.go index b47f7706e..5ea4b381e 100644 --- a/netadapter/netconnection.go +++ b/netadapter/netconnection.go @@ -2,9 +2,9 @@ package netadapter import ( "fmt" - "github.com/kaspanet/kaspad/domainmessage" routerpkg "github.com/kaspanet/kaspad/netadapter/router" + "github.com/pkg/errors" "github.com/kaspanet/kaspad/netadapter/id" "github.com/kaspanet/kaspad/netadapter/server" @@ -15,23 +15,28 @@ type NetConnection struct { connection server.Connection id *id.ID router *routerpkg.Router + invalidMessageChan chan error onDisconnectedHandler server.OnDisconnectedHandler + isConnected uint32 } func newNetConnection(connection server.Connection, routerInitializer RouterInitializer) *NetConnection { router := routerpkg.NewRouter() netConnection := &NetConnection{ - connection: connection, - router: router, + connection: connection, + router: router, + invalidMessageChan: make(chan error), } netConnection.connection.SetOnDisconnectedHandler(func() { router.Close() + close(netConnection.invalidMessageChan) + netConnection.onDisconnectedHandler() + }) - if netConnection.onDisconnectedHandler != nil { - netConnection.onDisconnectedHandler() - } + netConnection.connection.SetOnInvalidMessageHandler(func(err error) { + netConnection.invalidMessageChan <- err }) router.SetOnRouteCapacityReachedHandler(func() { @@ -44,6 +49,10 @@ func newNetConnection(connection server.Connection, routerInitializer RouterInit } func (c *NetConnection) start() { + if c.onDisconnectedHandler == nil { + panic(errors.New("onDisconnectedHandler is nil")) + } + c.connection.Start(c.router) } @@ -76,12 +85,6 @@ func (c *NetConnection) NetAddress() *domainmessage.NetAddress { return domainmessage.NewNetAddress(c.connection.Address(), 0) } -// SetOnInvalidMessageHandler sets a handler function -// for invalid messages -func (c *NetConnection) SetOnInvalidMessageHandler(onInvalidMessageHandler server.OnInvalidMessageHandler) { - c.connection.SetOnInvalidMessageHandler(onInvalidMessageHandler) -} - func (c *NetConnection) setOnDisconnectedHandler(onDisconnectedHandler server.OnDisconnectedHandler) { c.onDisconnectedHandler = onDisconnectedHandler } @@ -90,3 +93,9 @@ func (c *NetConnection) setOnDisconnectedHandler(onDisconnectedHandler server.On func (c *NetConnection) Disconnect() { c.connection.Disconnect() } + +// DequeueInvalidMessage dequeues the next invalid message +func (c *NetConnection) DequeueInvalidMessage() (isOpen bool, err error) { + err, isOpen = <-c.invalidMessageChan + return isOpen, err +} diff --git a/netadapter/server/grpcserver/grpc_connection.go b/netadapter/server/grpcserver/grpc_connection.go index 299e141aa..b3d959206 100644 --- a/netadapter/server/grpcserver/grpc_connection.go +++ b/netadapter/server/grpcserver/grpc_connection.go @@ -1,6 +1,7 @@ package grpcserver import ( + "github.com/pkg/errors" "net" "sync/atomic" @@ -40,6 +41,14 @@ func newConnection(server *gRPCServer, address *net.TCPAddr, isOutbound bool, st } func (c *gRPCConnection) Start(router *router.Router) { + if c.onDisconnectedHandler == nil { + panic(errors.New("onDisconnectedHandler is nil")) + } + + if c.onInvalidMessageHandler == nil { + panic(errors.New("onInvalidMessageHandler is nil")) + } + c.router = router spawn("gRPCConnection.Start-connectionLoops", func() { diff --git a/netadapter/server/grpcserver/grpc_server.go b/netadapter/server/grpcserver/grpc_server.go index 5a2713413..f9cf141d3 100644 --- a/netadapter/server/grpcserver/grpc_server.go +++ b/netadapter/server/grpcserver/grpc_server.go @@ -37,6 +37,10 @@ func NewGRPCServer(listeningAddrs []string) (server.Server, error) { } func (s *gRPCServer) Start() error { + if s.onConnectedHandler == nil { + return errors.New("onConnectedHandler is nil") + } + for _, listenAddr := range s.listeningAddrs { err := s.listenOn(listenAddr) if err != nil { diff --git a/protocol/protocol.go b/protocol/protocol.go index 45923f55b..7d6ada3e0 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -49,9 +49,15 @@ func (m *Manager) routerInitializer(router *routerpkg.Router, netConnection *net return } - netConnection.SetOnInvalidMessageHandler(func(err error) { - if atomic.AddUint32(&isStopping, 1) == 1 { - errChan <- protocolerrors.Wrap(true, err, "received bad message") + spawn("Manager.routerInitializer-netConnection.DequeueInvalidMessage", func() { + for { + isOpen, err := netConnection.DequeueInvalidMessage() + if !isOpen { + return + } + if atomic.AddUint32(&isStopping, 1) == 1 { + errChan <- protocolerrors.Wrap(true, err, "received bad message") + } } }) diff --git a/protocol/protocolerrors/protocolerrors.go b/protocol/protocolerrors/protocolerrors.go index 05a1b843d..562fe4313 100644 --- a/protocol/protocolerrors/protocolerrors.go +++ b/protocol/protocolerrors/protocolerrors.go @@ -18,8 +18,7 @@ func (e *ProtocolError) Unwrap() error { } // Errorf formats according to a format specifier and returns the string -// as a value that satisfies error. -// Errorf also records the stack trace at the point it was called. +// as a ProtocolError. func Errorf(shouldBan bool, format string, args ...interface{}) error { return &ProtocolError{ ShouldBan: shouldBan, @@ -27,7 +26,7 @@ func Errorf(shouldBan bool, format string, args ...interface{}) error { } } -// New returns an error with the supplied message. +// New returns a ProtocolError with the supplied message. // New also records the stack trace at the point it was called. func New(shouldBan bool, message string) error { return &ProtocolError{ @@ -36,8 +35,7 @@ func New(shouldBan bool, message string) error { } } -// Wrap returns an error annotating err with a stack trace -// at the point Wrap is called, and the supplied message. +// Wrap wraps the given error and returns it as a ProtocolError. func Wrap(shouldBan bool, err error, message string) error { return &ProtocolError{ ShouldBan: shouldBan, @@ -45,8 +43,7 @@ func Wrap(shouldBan bool, err error, message string) error { } } -// Wrapf returns an error annotating err with a stack trace -// at the point Wrapf is called, and the format specifier. +// Wrapf wraps the given error with the given format and returns it as a ProtocolError. func Wrapf(shouldBan bool, err error, format string, args ...interface{}) error { return &ProtocolError{ ShouldBan: shouldBan,