diff --git a/connmanager/connmanager.go b/connmanager/connmanager.go index 5b77dcd62..b0a04c909 100644 --- a/connmanager/connmanager.go +++ b/connmanager/connmanager.go @@ -112,3 +112,8 @@ func (c *ConnectionManager) connectionsLoop() { func (c *ConnectionManager) ConnectionCount() int { return c.netAdapter.ConnectionCount() } + +// Ban prevents the given netConnection from connecting again +func (c *ConnectionManager) Ban(netConnection *netadapter.NetConnection) { + c.netAdapter.Ban(netConnection) +} diff --git a/connmanager/outgoing_connections.go b/connmanager/outgoing_connections.go index d757ac80c..594dc8d66 100644 --- a/connmanager/outgoing_connections.go +++ b/connmanager/outgoing_connections.go @@ -22,7 +22,8 @@ func (c *ConnectionManager) checkOutgoingConnections(connSet connectionSet) { log.Debugf("Have got %d outgoing connections out of target %d, adding %d more", liveConnections, c.targetOutgoing, c.targetOutgoing-liveConnections) - for len(c.activeOutgoing) < c.targetOutgoing { + connectionsNeededCount := c.targetOutgoing - len(c.activeOutgoing) + for i := 0; i < connectionsNeededCount; i++ { address := c.addressManager.GetAddress() if address == nil { log.Warnf("No more addresses available") @@ -30,15 +31,20 @@ func (c *ConnectionManager) checkOutgoingConnections(connSet connectionSet) { } netAddress := address.NetAddress() + tcpAddress := netAddress.TCPAddress() + if c.netAdapter.IsBanned(tcpAddress) { + continue + } + c.addressManager.Attempt(netAddress) - addressString := netAddress.TCPAddress().String() + addressString := tcpAddress.String() err := c.initiateConnection(addressString) if err != nil { log.Infof("Couldn't connect to %s: %s", addressString, err) continue } - c.addressManager.Connected(address.NetAddress()) - c.activeOutgoing[address.NetAddress().TCPAddress().String()] = struct{}{} + c.addressManager.Connected(netAddress) + c.activeOutgoing[addressString] = struct{}{} } } diff --git a/kaspad.go b/kaspad.go index b6805969c..1f399ecbd 100644 --- a/kaspad.go +++ b/kaspad.go @@ -125,12 +125,12 @@ func newKaspad(cfg *config.Config, databaseContext *dbaccess.DatabaseContext, in } addressManager := addrmgr.New(cfg, databaseContext) - protocolManager, err := protocol.NewManager(cfg, dag, addressManager, txMempool) + connectionManager, err := connmanager.New(cfg, netAdapter, addressManager) if err != nil { return nil, err } - connectionManager, err := connmanager.New(cfg, netAdapter, addressManager) + protocolManager, err := protocol.NewManager(cfg, dag, addressManager, txMempool, connectionManager) if err != nil { return nil, err } diff --git a/netadapter/netadapter.go b/netadapter/netadapter.go index f5163238b..31b8e439f 100644 --- a/netadapter/netadapter.go +++ b/netadapter/netadapter.go @@ -259,3 +259,14 @@ func (na *NetAdapter) Disconnect(netConnection *NetConnection) error { } return nil } + +// IsBanned checks whether the given address had previously +// been banned +func (na *NetAdapter) IsBanned(address *net.TCPAddr) bool { + return na.server.IsBanned(address) +} + +// Ban prevents the given netConnection from connecting again +func (na *NetAdapter) Ban(netConnection *NetConnection) { + na.server.Ban(netConnection.connection.Address()) +} diff --git a/netadapter/netconnection.go b/netadapter/netconnection.go index fae5ce6af..7b5565b8e 100644 --- a/netadapter/netconnection.go +++ b/netadapter/netconnection.go @@ -33,3 +33,9 @@ func (c *NetConnection) ID() *id.ID { func (c *NetConnection) Address() string { return c.connection.Address().String() } + +// SetOnInvalidMessageHandler sets a handler function +// for invalid messages +func (c *NetConnection) SetOnInvalidMessageHandler(onInvalidMessageHandler server.OnInvalidMessageHandler) { + c.connection.SetOnInvalidMessageHandler(onInvalidMessageHandler) +} diff --git a/netadapter/server/grpcserver/connection_loops.go b/netadapter/server/grpcserver/connection_loops.go index ee8ddceef..27dda71ab 100644 --- a/netadapter/server/grpcserver/connection_loops.go +++ b/netadapter/server/grpcserver/connection_loops.go @@ -68,6 +68,7 @@ func (c *gRPCConnection) receiveLoop() error { } message, err := protoMessage.ToWireMessage() if err != nil { + c.onInvalidMessageHandler(err) return err } @@ -81,6 +82,7 @@ func (c *gRPCConnection) receiveLoop() error { log.Debugf("Router for %s is closed. Exiting the receive loop", c) return nil } + c.onInvalidMessageHandler(err) return err } } diff --git a/netadapter/server/grpcserver/grpc_connection.go b/netadapter/server/grpcserver/grpc_connection.go index 11ecc37d6..753cbe4c4 100644 --- a/netadapter/server/grpcserver/grpc_connection.go +++ b/netadapter/server/grpcserver/grpc_connection.go @@ -13,19 +13,20 @@ import ( type gRPCConnection struct { server *gRPCServer - address net.Addr + address *net.TCPAddr isOutbound bool stream grpcStream router *router.Router - stopChan chan struct{} - clientConn grpc.ClientConn - onDisconnectedHandler server.OnDisconnectedHandler + stopChan chan struct{} + clientConn grpc.ClientConn + onDisconnectedHandler server.OnDisconnectedHandler + onInvalidMessageHandler server.OnInvalidMessageHandler isConnected uint32 } -func newConnection(server *gRPCServer, address net.Addr, isOutbound bool, stream grpcStream) *gRPCConnection { +func newConnection(server *gRPCServer, address *net.TCPAddr, isOutbound bool, stream grpcStream) *gRPCConnection { connection := &gRPCConnection{ server: server, address: address, @@ -61,6 +62,10 @@ func (c *gRPCConnection) SetOnDisconnectedHandler(onDisconnectedHandler server.O c.onDisconnectedHandler = onDisconnectedHandler } +func (c *gRPCConnection) SetOnInvalidMessageHandler(onInvalidMessageHandler server.OnInvalidMessageHandler) { + c.onInvalidMessageHandler = onInvalidMessageHandler +} + // Disconnect disconnects the connection // Calling this function a second time doesn't do anything // @@ -83,6 +88,6 @@ func (c *gRPCConnection) Disconnect() error { return c.onDisconnectedHandler() } -func (c *gRPCConnection) Address() net.Addr { +func (c *gRPCConnection) Address() *net.TCPAddr { return c.address } diff --git a/netadapter/server/grpcserver/grpc_server.go b/netadapter/server/grpcserver/grpc_server.go index 07bf8bf7f..253d3c68c 100644 --- a/netadapter/server/grpcserver/grpc_server.go +++ b/netadapter/server/grpcserver/grpc_server.go @@ -19,14 +19,16 @@ type gRPCServer struct { onConnectedHandler server.OnConnectedHandler listeningAddrs []string server *grpc.Server + bannedAddresses map[string]struct{} } // NewGRPCServer creates and starts a gRPC server, listening on the // provided addresses/ports func NewGRPCServer(listeningAddrs []string) (server.Server, error) { s := &gRPCServer{ - server: grpc.NewServer(), - listeningAddrs: listeningAddrs, + server: grpc.NewServer(), + listeningAddrs: listeningAddrs, + bannedAddresses: make(map[string]struct{}), } protowire.RegisterP2PServer(s.server, newP2PServer(s)) @@ -96,8 +98,12 @@ func (s *gRPCServer) Connect(address string) (server.Connection, error) { if !ok { return nil, errors.Errorf("error getting stream peer info from context for %s", address) } + tcpAddress, ok := peerInfo.Addr.(*net.TCPAddr) + if !ok { + return nil, errors.Errorf("non-tcp addresses are not supported") + } - connection := newConnection(s, peerInfo.Addr, true, stream) + connection := newConnection(s, tcpAddress, true, stream) err = s.onConnectedHandler(connection) if err != nil { @@ -108,3 +114,15 @@ func (s *gRPCServer) Connect(address string) (server.Connection, error) { return connection, nil } + +// IsBanned checks whether the given address had previously +// been banned +func (s *gRPCServer) IsBanned(address *net.TCPAddr) bool { + _, ok := s.bannedAddresses[address.IP.String()] + return ok +} + +// Ban prevents the given address from connecting +func (s *gRPCServer) Ban(address *net.TCPAddr) { + s.bannedAddresses[address.IP.String()] = struct{}{} +} diff --git a/netadapter/server/grpcserver/p2pserver.go b/netadapter/server/grpcserver/p2pserver.go index c49d0cced..34b7a356d 100644 --- a/netadapter/server/grpcserver/p2pserver.go +++ b/netadapter/server/grpcserver/p2pserver.go @@ -5,6 +5,7 @@ import ( "github.com/kaspanet/kaspad/util/panics" "github.com/pkg/errors" "google.golang.org/grpc/peer" + "net" ) type p2pServer struct { @@ -23,7 +24,17 @@ func (p *p2pServer) MessageStream(stream protowire.P2P_MessageStreamServer) erro if !ok { return errors.Errorf("Error getting stream peer info from context") } - connection := newConnection(p.server, peerInfo.Addr, false, stream) + tcpAddress, ok := peerInfo.Addr.(*net.TCPAddr) + if !ok { + return errors.Errorf("non-tcp connections are not supported") + } + + if p.server.IsBanned(tcpAddress) { + log.Debugf("received connection attempt from banned peer %s", peerInfo.Addr) + return nil + } + + connection := newConnection(p.server, tcpAddress, false, stream) err := p.server.onConnectedHandler(connection) if err != nil { diff --git a/netadapter/server/server.go b/netadapter/server/server.go index a683a170b..1e139bbbf 100644 --- a/netadapter/server/server.go +++ b/netadapter/server/server.go @@ -17,12 +17,19 @@ type OnConnectedHandler func(connection Connection) error // called once a Connection has been disconnected. type OnDisconnectedHandler func() error +// OnInvalidMessageHandler is a function that is to be called when +// an invalid message (cannot be parsed/doesn't have a route) +// was received from a connection. +type OnInvalidMessageHandler func(err error) + // Server represents a p2p server. type Server interface { Connect(address string) (Connection, error) Start() error Stop() error SetOnConnectedHandler(onConnectedHandler OnConnectedHandler) + IsBanned(address *net.TCPAddr) bool + Ban(address *net.TCPAddr) } // Connection represents a p2p server connection. @@ -32,7 +39,8 @@ type Connection interface { Disconnect() error IsConnected() bool SetOnDisconnectedHandler(onDisconnectedHandler OnDisconnectedHandler) - Address() net.Addr + SetOnInvalidMessageHandler(onInvalidMessageHandler OnInvalidMessageHandler) + Address() *net.TCPAddr } // ErrNetwork is an error related to the internals of the connection, and not an error that diff --git a/protocol/flowcontext/flow_context.go b/protocol/flowcontext/flow_context.go index 99f7cc904..4d84cddf5 100644 --- a/protocol/flowcontext/flow_context.go +++ b/protocol/flowcontext/flow_context.go @@ -4,6 +4,7 @@ import ( "github.com/kaspanet/kaspad/addrmgr" "github.com/kaspanet/kaspad/blockdag" "github.com/kaspanet/kaspad/config" + "github.com/kaspanet/kaspad/connmanager" "github.com/kaspanet/kaspad/mempool" "github.com/kaspanet/kaspad/netadapter" "github.com/kaspanet/kaspad/netadapter/id" @@ -25,6 +26,7 @@ type FlowContext struct { addedTransactions []*util.Tx dag *blockdag.BlockDAG addressManager *addrmgr.AddrManager + connectionManager *connmanager.ConnectionManager transactionsToRebroadcastLock sync.Mutex transactionsToRebroadcast map[daghash.TxID]*util.Tx @@ -42,13 +44,16 @@ type FlowContext struct { } // New returns a new instance of FlowContext. -func New(cfg *config.Config, dag *blockdag.BlockDAG, - addressManager *addrmgr.AddrManager, txPool *mempool.TxPool, netAdapter *netadapter.NetAdapter) *FlowContext { +func New(cfg *config.Config, dag *blockdag.BlockDAG, addressManager *addrmgr.AddrManager, + txPool *mempool.TxPool, netAdapter *netadapter.NetAdapter, + connectionManager *connmanager.ConnectionManager) *FlowContext { + return &FlowContext{ cfg: cfg, netAdapter: netAdapter, dag: dag, addressManager: addressManager, + connectionManager: connectionManager, txPool: txPool, sharedRequestedTransactions: relaytransactions.NewSharedRequestedTransactions(), sharedRequestedBlocks: blockrelay.NewSharedRequestedBlocks(), diff --git a/protocol/flowcontext/network.go b/protocol/flowcontext/network.go index 2b8ed900e..9a92b6b8d 100644 --- a/protocol/flowcontext/network.go +++ b/protocol/flowcontext/network.go @@ -1,6 +1,7 @@ package flowcontext import ( + "github.com/kaspanet/kaspad/connmanager" "github.com/kaspanet/kaspad/netadapter" "github.com/kaspanet/kaspad/netadapter/id" "github.com/kaspanet/kaspad/protocol/common" @@ -14,6 +15,11 @@ func (f *FlowContext) NetAdapter() *netadapter.NetAdapter { return f.netAdapter } +// ConnectionManager returns the connection manager that is associated to the flow context. +func (f *FlowContext) ConnectionManager() *connmanager.ConnectionManager { + return f.connectionManager +} + // AddToPeers marks this peer as ready and adds it to the ready peers list. func (f *FlowContext) AddToPeers(peer *peerpkg.Peer) error { f.peersMutex.RLock() diff --git a/protocol/manager.go b/protocol/manager.go index 534b03186..2d6bc8569 100644 --- a/protocol/manager.go +++ b/protocol/manager.go @@ -4,6 +4,7 @@ import ( "github.com/kaspanet/kaspad/addrmgr" "github.com/kaspanet/kaspad/blockdag" "github.com/kaspanet/kaspad/config" + "github.com/kaspanet/kaspad/connmanager" "github.com/kaspanet/kaspad/mempool" "github.com/kaspanet/kaspad/netadapter" "github.com/kaspanet/kaspad/protocol/flowcontext" @@ -18,7 +19,8 @@ type Manager struct { // NewManager creates a new instance of the p2p protocol manager func NewManager(cfg *config.Config, dag *blockdag.BlockDAG, - addressManager *addrmgr.AddrManager, txPool *mempool.TxPool) (*Manager, error) { + addressManager *addrmgr.AddrManager, txPool *mempool.TxPool, + connectionManager *connmanager.ConnectionManager) (*Manager, error) { netAdapter, err := netadapter.NewNetAdapter(cfg) if err != nil { @@ -26,7 +28,7 @@ func NewManager(cfg *config.Config, dag *blockdag.BlockDAG, } manager := Manager{ - context: flowcontext.New(cfg, dag, addressManager, txPool, netAdapter), + context: flowcontext.New(cfg, dag, addressManager, txPool, netAdapter, connectionManager), } netAdapter.SetRouterInitializer(manager.routerInitializer) return &manager, nil diff --git a/protocol/protocol.go b/protocol/protocol.go index 27399ed1f..85167efb4 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -28,8 +28,7 @@ func (m *Manager) routerInitializer(netConnection *netadapter.NetConnection) (*r if err != nil { if protocolErr := &(protocolerrors.ProtocolError{}); errors.As(err, &protocolErr) { if protocolErr.ShouldBan { - // TODO(libp2p) Ban peer - panic("unimplemented") + m.context.ConnectionManager().Ban(netConnection) } err = m.context.NetAdapter().Disconnect(netConnection) if err != nil { @@ -57,6 +56,12 @@ func (m *Manager) startFlows(netConnection *netadapter.NetConnection, router *ro stop := make(chan error) stopped := uint32(0) + netConnection.SetOnInvalidMessageHandler(func(err error) { + if atomic.AddUint32(&stopped, 1) == 1 { + stop <- protocolerrors.Wrap(true, err, "received bad message") + } + }) + peer, closed, err := handshake.HandleHandshake(m.context, router, netConnection) if err != nil { return err