[NOD-1152] Move banning from netAdapter to connectionManager (#820)

* [NOD-1152] Move banning out of netadapter.

* [NOD-1152] Add a comment.

* [NOD-1152] Fix a comment.
This commit is contained in:
stasatdaglabs 2020-07-26 13:42:48 +03:00 committed by GitHub
parent 6a18b56587
commit 683ceda3a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 48 additions and 44 deletions

View File

@ -33,6 +33,9 @@ type ConnectionManager struct {
activeIncoming map[string]struct{} activeIncoming map[string]struct{}
maxIncoming int maxIncoming int
bannedAddresses map[string]struct{}
bannedAddressesLock sync.RWMutex
stop uint32 stop uint32
connectionRequestsLock sync.Mutex connectionRequestsLock sync.Mutex
@ -50,6 +53,7 @@ func New(cfg *config.Config, netAdapter *netadapter.NetAdapter, addressManager *
pendingRequested: map[string]*connectionRequest{}, pendingRequested: map[string]*connectionRequest{},
activeOutgoing: map[string]struct{}{}, activeOutgoing: map[string]struct{}{},
activeIncoming: map[string]struct{}{}, activeIncoming: map[string]struct{}{},
bannedAddresses: map[string]struct{}{},
resetLoopChan: make(chan struct{}), resetLoopChan: make(chan struct{}),
loopTicker: time.NewTicker(connectionsLoopInterval), loopTicker: time.NewTicker(connectionsLoopInterval),
} }
@ -124,9 +128,31 @@ func (c *ConnectionManager) ConnectionCount() int {
return c.netAdapter.ConnectionCount() return c.netAdapter.ConnectionCount()
} }
// Ban prevents the given netConnection from connecting again // Ban marks the given netConnection as banned
func (c *ConnectionManager) Ban(netConnection *netadapter.NetConnection) { func (c *ConnectionManager) Ban(netConnection *netadapter.NetConnection) {
c.netAdapter.Ban(netConnection) c.banIP(netConnection.IP())
}
// IsBanned returns whether the given netConnection is banned
func (c *ConnectionManager) IsBanned(netConnection *netadapter.NetConnection) bool {
return c.isIPBanned(netConnection.IP())
}
// banIP marks the given IP as banned
func (c *ConnectionManager) banIP(ip string) {
c.bannedAddressesLock.Lock()
defer c.bannedAddressesLock.Unlock()
c.bannedAddresses[ip] = struct{}{}
}
// isIPBanned returns whether the given IP is banned
func (c *ConnectionManager) isIPBanned(ip string) bool {
c.bannedAddressesLock.RLock()
defer c.bannedAddressesLock.RUnlock()
_, ok := c.bannedAddresses[ip]
return ok
} }
func (c *ConnectionManager) waitTillNextIteration() { func (c *ConnectionManager) waitTillNextIteration() {

View File

@ -32,7 +32,7 @@ func (c *ConnectionManager) checkOutgoingConnections(connSet connectionSet) {
netAddress := address.NetAddress() netAddress := address.NetAddress()
tcpAddress := netAddress.TCPAddress() tcpAddress := netAddress.TCPAddress()
if c.netAdapter.IsBanned(tcpAddress) { if c.isIPBanned(tcpAddress.IP.String()) {
continue continue
} }

View File

@ -17,7 +17,7 @@ import (
// RouterInitializer is a function that initializes a new // RouterInitializer is a function that initializes a new
// router to be used with a new connection // router to be used with a new connection
type RouterInitializer func(netConnection *NetConnection) (*routerpkg.Router, error) type RouterInitializer func(netConnection *NetConnection) *routerpkg.Router
// NetAdapter is an abstraction layer over networking. // NetAdapter is an abstraction layer over networking.
// This type expects a RouteInitializer function. This // This type expects a RouteInitializer function. This
@ -102,10 +102,7 @@ func (na *NetAdapter) ConnectionCount() int {
func (na *NetAdapter) onConnectedHandler(connection server.Connection) error { func (na *NetAdapter) onConnectedHandler(connection server.Connection) error {
netConnection := newNetConnection(connection) netConnection := newNetConnection(connection)
router, err := na.routerInitializer(netConnection) router := na.routerInitializer(netConnection)
if err != nil {
return err
}
connection.Start(router) connection.Start(router)
na.connectionsToRouters[netConnection] = router na.connectionsToRouters[netConnection] = router
@ -222,14 +219,3 @@ func (na *NetAdapter) Disconnect(netConnection *NetConnection) error {
} }
return nil return nil
} }
// IsBanned checks whether the given address had previously
// been banned
func (na *NetAdapter) IsBanned(address *net.TCPAddr) bool {
return na.server.IsBanned(address)
}
// Ban prevents the given netConnection from connecting again
func (na *NetAdapter) Ban(netConnection *NetConnection) {
na.server.Ban(netConnection.connection.Address())
}

View File

@ -43,6 +43,11 @@ func (c *NetConnection) IsOutbound() bool {
return c.connection.IsOutbound() return c.connection.IsOutbound()
} }
// IP returns the IP address associated with this connection
func (c *NetConnection) IP() string {
return c.connection.Address().IP.String()
}
// SetOnInvalidMessageHandler sets a handler function // SetOnInvalidMessageHandler sets a handler function
// for invalid messages // for invalid messages
func (c *NetConnection) SetOnInvalidMessageHandler(onInvalidMessageHandler server.OnInvalidMessageHandler) { func (c *NetConnection) SetOnInvalidMessageHandler(onInvalidMessageHandler server.OnInvalidMessageHandler) {

View File

@ -19,7 +19,6 @@ type gRPCServer struct {
onConnectedHandler server.OnConnectedHandler onConnectedHandler server.OnConnectedHandler
listeningAddrs []string listeningAddrs []string
server *grpc.Server server *grpc.Server
bannedAddresses map[string]struct{}
} }
// NewGRPCServer creates and starts a gRPC server, listening on the // NewGRPCServer creates and starts a gRPC server, listening on the
@ -28,7 +27,6 @@ func NewGRPCServer(listeningAddrs []string) (server.Server, error) {
s := &gRPCServer{ s := &gRPCServer{
server: grpc.NewServer(), server: grpc.NewServer(),
listeningAddrs: listeningAddrs, listeningAddrs: listeningAddrs,
bannedAddresses: make(map[string]struct{}),
} }
protowire.RegisterP2PServer(s.server, newP2PServer(s)) protowire.RegisterP2PServer(s.server, newP2PServer(s))
@ -114,15 +112,3 @@ func (s *gRPCServer) Connect(address string) (server.Connection, error) {
return connection, nil return connection, nil
} }
// IsBanned checks whether the given address had previously
// been banned
func (s *gRPCServer) IsBanned(address *net.TCPAddr) bool {
_, ok := s.bannedAddresses[address.IP.String()]
return ok
}
// Ban prevents the given address from connecting
func (s *gRPCServer) Ban(address *net.TCPAddr) {
s.bannedAddresses[address.IP.String()] = struct{}{}
}

View File

@ -29,11 +29,6 @@ func (p *p2pServer) MessageStream(stream protowire.P2P_MessageStreamServer) erro
return errors.Errorf("non-tcp connections are not supported") return errors.Errorf("non-tcp connections are not supported")
} }
if p.server.IsBanned(tcpAddress) {
log.Debugf("received connection attempt from banned peer %s", peerInfo.Addr)
return nil
}
connection := newConnection(p.server, tcpAddress, false, stream) connection := newConnection(p.server, tcpAddress, false, stream)
err := p.server.onConnectedHandler(connection) err := p.server.onConnectedHandler(connection)

View File

@ -28,8 +28,6 @@ type Server interface {
Start() error Start() error
Stop() error Stop() error
SetOnConnectedHandler(onConnectedHandler OnConnectedHandler) SetOnConnectedHandler(onConnectedHandler OnConnectedHandler)
IsBanned(address *net.TCPAddr) bool
Ban(address *net.TCPAddr)
} }
// Connection represents a p2p server connection. // Connection represents a p2p server connection.

View File

@ -21,9 +21,17 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
func (m *Manager) routerInitializer(netConnection *netadapter.NetConnection) (*routerpkg.Router, error) { func (m *Manager) routerInitializer(netConnection *netadapter.NetConnection) *routerpkg.Router {
router := routerpkg.NewRouter() router := routerpkg.NewRouter()
spawn("newRouterInitializer-startFlows", func() { spawn("newRouterInitializer-startFlows", func() {
if m.context.ConnectionManager().IsBanned(netConnection) {
err := m.context.NetAdapter().Disconnect(netConnection)
if err != nil {
panic(err)
}
return
}
err := m.startFlows(netConnection, router) err := m.startFlows(netConnection, router)
if err != nil { if err != nil {
if protocolErr := &(protocolerrors.ProtocolError{}); errors.As(err, &protocolErr) { if protocolErr := &(protocolerrors.ProtocolError{}); errors.As(err, &protocolErr) {
@ -49,7 +57,7 @@ func (m *Manager) routerInitializer(netConnection *netadapter.NetConnection) (*r
panic(err) panic(err)
} }
}) })
return router, nil return router
} }
func (m *Manager) startFlows(netConnection *netadapter.NetConnection, router *routerpkg.Router) error { func (m *Manager) startFlows(netConnection *netadapter.NetConnection, router *routerpkg.Router) error {