mirror of
https://github.com/kaspanet/kaspad.git
synced 2025-03-30 15:08:33 +00:00

* [NOD-1198] Make router a property of netConnection, and remove map from connection to router in netAdapter * [NOD-1198] Moved all router logic from netAdapter to netConnection * [NOD-1198] Move disconnect to NetConnection * [NOD-1198] Unexport netConnection.start * [NOD-1198] Remove error from Disconnect functions * [NOD-1198] Make sure OnDisconnectedHandler doesn't run when it shouldn't
213 lines
5.4 KiB
Go
213 lines
5.4 KiB
Go
package netadapter
|
|
|
|
import (
|
|
"net"
|
|
"strconv"
|
|
"sync"
|
|
"sync/atomic"
|
|
|
|
"github.com/kaspanet/kaspad/config"
|
|
"github.com/kaspanet/kaspad/netadapter/id"
|
|
routerpkg "github.com/kaspanet/kaspad/netadapter/router"
|
|
"github.com/kaspanet/kaspad/netadapter/server"
|
|
"github.com/kaspanet/kaspad/netadapter/server/grpcserver"
|
|
"github.com/kaspanet/kaspad/wire"
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
// RouterInitializer is a function that initializes a new
|
|
// router to be used with a new connection
|
|
type RouterInitializer func(*routerpkg.Router, *NetConnection)
|
|
|
|
// NetAdapter is an abstraction layer over networking.
// This type expects a RouterInitializer function. This
// function weaves together the various "routes" (messages
// and message handlers) without exposing anything related
// to networking internals.
type NetAdapter struct {
	// cfg is the node configuration; GetBestLocalAddress reads
	// ExternalIPs, Listeners and NetParams from it.
	cfg *config.Config
	// id is generated once in NewNetAdapter and returned by ID.
	id *id.ID
	// server is the underlying server started by Start and stopped by Stop.
	server server.Server
	// routerInitializer is set via SetRouterInitializer and handed to every
	// NetConnection created in onConnectedHandler.
	routerInitializer RouterInitializer
	// stop is atomically incremented by Stop so repeated calls can be detected.
	stop uint32

	// connections is the set of currently tracked connections. Entries are
	// added in onConnectedHandler and removed by each connection's
	// disconnect handler. Guarded by connectionsLock.
	connections     map[*NetConnection]struct{}
	connectionsLock sync.RWMutex
}
|
|
|
|
// NewNetAdapter creates and starts a new NetAdapter on the
|
|
// given listeningPort
|
|
func NewNetAdapter(cfg *config.Config) (*NetAdapter, error) {
|
|
netAdapterID, err := id.GenerateID()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
s, err := grpcserver.NewGRPCServer(cfg.Listeners)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
adapter := NetAdapter{
|
|
cfg: cfg,
|
|
id: netAdapterID,
|
|
server: s,
|
|
|
|
connections: make(map[*NetConnection]struct{}),
|
|
}
|
|
|
|
adapter.server.SetOnConnectedHandler(adapter.onConnectedHandler)
|
|
|
|
return &adapter, nil
|
|
}
|
|
|
|
// Start begins the operation of the NetAdapter
|
|
func (na *NetAdapter) Start() error {
|
|
err := na.server.Start()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Stop safely closes the NetAdapter
|
|
func (na *NetAdapter) Stop() error {
|
|
if atomic.AddUint32(&na.stop, 1) != 1 {
|
|
return errors.New("net adapter stopped more than once")
|
|
}
|
|
return na.server.Stop()
|
|
}
|
|
|
|
// Connect tells the NetAdapter's underlying server to initiate a connection
|
|
// to the given address
|
|
func (na *NetAdapter) Connect(address string) error {
|
|
_, err := na.server.Connect(address)
|
|
return err
|
|
}
|
|
|
|
// Connections returns a list of connections currently connected and active
|
|
func (na *NetAdapter) Connections() []*NetConnection {
|
|
na.connectionsLock.RLock()
|
|
defer na.connectionsLock.RUnlock()
|
|
|
|
netConnections := make([]*NetConnection, 0, len(na.connections))
|
|
|
|
for netConnection := range na.connections {
|
|
netConnections = append(netConnections, netConnection)
|
|
}
|
|
|
|
return netConnections
|
|
}
|
|
|
|
// ConnectionCount returns the count of the connected connections
|
|
func (na *NetAdapter) ConnectionCount() int {
|
|
na.connectionsLock.RLock()
|
|
defer na.connectionsLock.RUnlock()
|
|
|
|
return len(na.connections)
|
|
}
|
|
|
|
func (na *NetAdapter) onConnectedHandler(connection server.Connection) error {
|
|
netConnection := newNetConnection(connection, na.routerInitializer)
|
|
|
|
na.connectionsLock.Lock()
|
|
defer na.connectionsLock.Unlock()
|
|
|
|
netConnection.setOnDisconnectedHandler(func() {
|
|
na.connectionsLock.Lock()
|
|
defer na.connectionsLock.Unlock()
|
|
|
|
delete(na.connections, netConnection)
|
|
})
|
|
|
|
na.connections[netConnection] = struct{}{}
|
|
|
|
netConnection.start()
|
|
|
|
return nil
|
|
}
|
|
|
|
// SetRouterInitializer sets the routerInitializer function
// for the net adapter.
//
// NOTE(review): the field is written here without synchronization and read
// in onConnectedHandler, so it should be set before the server can report
// connections, presumably before Start is called.
func (na *NetAdapter) SetRouterInitializer(routerInitializer RouterInitializer) {
	na.routerInitializer = routerInitializer
}
|
|
|
|
// ID returns this netAdapter's ID in the network. The ID is generated once
// in NewNetAdapter and never changes afterwards.
func (na *NetAdapter) ID() *id.ID {
	return na.id
}
|
|
|
|
// Broadcast sends the given `message` to every peer corresponding
|
|
// to each NetConnection in the given netConnections
|
|
func (na *NetAdapter) Broadcast(netConnections []*NetConnection, message wire.Message) error {
|
|
na.connectionsLock.RLock()
|
|
defer na.connectionsLock.RUnlock()
|
|
|
|
for _, netConnection := range netConnections {
|
|
err := netConnection.router.OutgoingRoute().Enqueue(message)
|
|
if err != nil {
|
|
if errors.Is(err, routerpkg.ErrRouteClosed) {
|
|
log.Debugf("Cannot enqueue message to %s: router is closed", netConnection)
|
|
continue
|
|
}
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetBestLocalAddress returns the most appropriate local address to use
|
|
// for the given remote address.
|
|
func (na *NetAdapter) GetBestLocalAddress() (*wire.NetAddress, error) {
|
|
//TODO(libp2p) Reimplement this, and check reachability to the other node
|
|
if len(na.cfg.ExternalIPs) > 0 {
|
|
host, portString, err := net.SplitHostPort(na.cfg.ExternalIPs[0])
|
|
if err != nil {
|
|
portString = na.cfg.NetParams().DefaultPort
|
|
}
|
|
portInt, err := strconv.Atoi(portString)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
ip := net.ParseIP(host)
|
|
if ip == nil {
|
|
hostAddrs, err := net.LookupHost(host)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
ip = net.ParseIP(hostAddrs[0])
|
|
if ip == nil {
|
|
return nil, errors.Errorf("Cannot resolve IP address for host '%s'", host)
|
|
}
|
|
}
|
|
return wire.NewNetAddressIPPort(ip, uint16(portInt), wire.SFNodeNetwork), nil
|
|
|
|
}
|
|
listenAddress := na.cfg.Listeners[0]
|
|
_, portString, err := net.SplitHostPort(listenAddress)
|
|
if err != nil {
|
|
portString = na.cfg.NetParams().DefaultPort
|
|
}
|
|
|
|
portInt, err := strconv.Atoi(portString)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
addresses, err := net.InterfaceAddrs()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for _, address := range addresses {
|
|
ip, _, err := net.ParseCIDR(address.String())
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
return wire.NewNetAddressIPPort(ip, uint16(portInt), wire.SFNodeNetwork), nil
|
|
}
|
|
return nil, errors.New("no address was found")
|
|
}
|