[NOD-1152] Move banning from netAdapter to connectionManager (#820)

* [NOD-1152] Move banning out of netadapter.

* [NOD-1152] Add a comment.

* [NOD-1152] Fix a comment.
This commit is contained in:
stasatdaglabs 2020-07-26 13:42:48 +03:00 committed by GitHub
parent 6a18b56587
commit 683ceda3a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 48 additions and 44 deletions

View File

@ -33,6 +33,9 @@ type ConnectionManager struct {
activeIncoming map[string]struct{}
maxIncoming int
bannedAddresses map[string]struct{}
bannedAddressesLock sync.RWMutex
stop uint32
connectionRequestsLock sync.Mutex
@ -50,6 +53,7 @@ func New(cfg *config.Config, netAdapter *netadapter.NetAdapter, addressManager *
pendingRequested: map[string]*connectionRequest{},
activeOutgoing: map[string]struct{}{},
activeIncoming: map[string]struct{}{},
bannedAddresses: map[string]struct{}{},
resetLoopChan: make(chan struct{}),
loopTicker: time.NewTicker(connectionsLoopInterval),
}
@ -124,9 +128,31 @@ func (c *ConnectionManager) ConnectionCount() int {
return c.netAdapter.ConnectionCount()
}
// Ban prevents the given netConnection from connecting again
// Ban marks the given netConnection as banned
func (c *ConnectionManager) Ban(netConnection *netadapter.NetConnection) {
c.netAdapter.Ban(netConnection)
c.banIP(netConnection.IP())
}
// IsBanned returns whether the given netConnection is banned
func (c *ConnectionManager) IsBanned(netConnection *netadapter.NetConnection) bool {
return c.isIPBanned(netConnection.IP())
}
// banIP marks the given IP as banned
func (c *ConnectionManager) banIP(ip string) {
c.bannedAddressesLock.Lock()
defer c.bannedAddressesLock.Unlock()
c.bannedAddresses[ip] = struct{}{}
}
// isIPBanned returns whether the given IP is banned
func (c *ConnectionManager) isIPBanned(ip string) bool {
c.bannedAddressesLock.RLock()
defer c.bannedAddressesLock.RUnlock()
_, ok := c.bannedAddresses[ip]
return ok
}
func (c *ConnectionManager) waitTillNextIteration() {

View File

@ -32,7 +32,7 @@ func (c *ConnectionManager) checkOutgoingConnections(connSet connectionSet) {
netAddress := address.NetAddress()
tcpAddress := netAddress.TCPAddress()
if c.netAdapter.IsBanned(tcpAddress) {
if c.isIPBanned(tcpAddress.IP.String()) {
continue
}

View File

@ -17,7 +17,7 @@ import (
// RouterInitializer is a function that initializes a new
// router to be used with a new connection
type RouterInitializer func(netConnection *NetConnection) (*routerpkg.Router, error)
type RouterInitializer func(netConnection *NetConnection) *routerpkg.Router
// NetAdapter is an abstraction layer over networking.
// This type expects a RouteInitializer function. This
@ -102,10 +102,7 @@ func (na *NetAdapter) ConnectionCount() int {
func (na *NetAdapter) onConnectedHandler(connection server.Connection) error {
netConnection := newNetConnection(connection)
router, err := na.routerInitializer(netConnection)
if err != nil {
return err
}
router := na.routerInitializer(netConnection)
connection.Start(router)
na.connectionsToRouters[netConnection] = router
@ -222,14 +219,3 @@ func (na *NetAdapter) Disconnect(netConnection *NetConnection) error {
}
return nil
}
// IsBanned checks whether the given address had previously
// been banned
func (na *NetAdapter) IsBanned(address *net.TCPAddr) bool {
return na.server.IsBanned(address)
}
// Ban prevents the given netConnection from connecting again
func (na *NetAdapter) Ban(netConnection *NetConnection) {
na.server.Ban(netConnection.connection.Address())
}

View File

@ -43,6 +43,11 @@ func (c *NetConnection) IsOutbound() bool {
return c.connection.IsOutbound()
}
// IP returns the IP address associated with this connection
func (c *NetConnection) IP() string {
return c.connection.Address().IP.String()
}
// SetOnInvalidMessageHandler sets a handler function
// for invalid messages
func (c *NetConnection) SetOnInvalidMessageHandler(onInvalidMessageHandler server.OnInvalidMessageHandler) {

View File

@ -19,16 +19,14 @@ type gRPCServer struct {
onConnectedHandler server.OnConnectedHandler
listeningAddrs []string
server *grpc.Server
bannedAddresses map[string]struct{}
}
// NewGRPCServer creates and starts a gRPC server, listening on the
// provided addresses/ports
func NewGRPCServer(listeningAddrs []string) (server.Server, error) {
s := &gRPCServer{
server: grpc.NewServer(),
listeningAddrs: listeningAddrs,
bannedAddresses: make(map[string]struct{}),
server: grpc.NewServer(),
listeningAddrs: listeningAddrs,
}
protowire.RegisterP2PServer(s.server, newP2PServer(s))
@ -114,15 +112,3 @@ func (s *gRPCServer) Connect(address string) (server.Connection, error) {
return connection, nil
}
// IsBanned checks whether the given address had previously
// been banned
func (s *gRPCServer) IsBanned(address *net.TCPAddr) bool {
_, ok := s.bannedAddresses[address.IP.String()]
return ok
}
// Ban prevents the given address from connecting
func (s *gRPCServer) Ban(address *net.TCPAddr) {
s.bannedAddresses[address.IP.String()] = struct{}{}
}

View File

@ -29,11 +29,6 @@ func (p *p2pServer) MessageStream(stream protowire.P2P_MessageStreamServer) erro
return errors.Errorf("non-tcp connections are not supported")
}
if p.server.IsBanned(tcpAddress) {
log.Debugf("received connection attempt from banned peer %s", peerInfo.Addr)
return nil
}
connection := newConnection(p.server, tcpAddress, false, stream)
err := p.server.onConnectedHandler(connection)

View File

@ -28,8 +28,6 @@ type Server interface {
Start() error
Stop() error
SetOnConnectedHandler(onConnectedHandler OnConnectedHandler)
IsBanned(address *net.TCPAddr) bool
Ban(address *net.TCPAddr)
}
// Connection represents a p2p server connection.

View File

@ -21,9 +21,17 @@ import (
"github.com/pkg/errors"
)
func (m *Manager) routerInitializer(netConnection *netadapter.NetConnection) (*routerpkg.Router, error) {
func (m *Manager) routerInitializer(netConnection *netadapter.NetConnection) *routerpkg.Router {
router := routerpkg.NewRouter()
spawn("newRouterInitializer-startFlows", func() {
if m.context.ConnectionManager().IsBanned(netConnection) {
err := m.context.NetAdapter().Disconnect(netConnection)
if err != nil {
panic(err)
}
return
}
err := m.startFlows(netConnection, router)
if err != nil {
if protocolErr := &(protocolerrors.ProtocolError{}); errors.As(err, &protocolErr) {
@ -49,7 +57,7 @@ func (m *Manager) routerInitializer(netConnection *netadapter.NetConnection) (*r
panic(err)
}
})
return router, nil
return router
}
func (m *Manager) startFlows(netConnection *netadapter.NetConnection, router *routerpkg.Router) error {