[NOD-847] Fix CIDR protection and prevent connecting to the same address twice (#714)

* [NOD-847] Fix CIDR protection and prevent connecting to the same address twice

* [NOD-847] Fix Tests

* [NOD-847] Add TestDuplicateOutboundConnections and TestSameOutboundGroupConnections

* [NOD-847] Fix TestRetryPermanent, TestNetworkFailure and wait 10 ms before restoring the previous active config

* [NOD-847] Add "is" before boolean methods

* [NOD-847] Fix Connect's lock

* [NOD-847] Make numAddressesInAddressManager an argument

* [NOD-847] Add teardown function for address manager

* [NOD-847] Add stack trace to ConnManager errors

* [NOD-847] Change emptyAddressManagerForTest->createEmptyAddressManagerForTest and fix typos

* [NOD-847] Fix wrong test name for addressManagerForTest

* [NOD-847] Change error message if New fails

* [NOD-847] Add new line on releaseAddress

* [NOD-847] Always try to reconnect on disconnect
This commit is contained in:
Ori Newman 2020-05-12 13:47:15 +03:00 committed by GitHub
parent c8a381d5bb
commit 585510d76c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 500 additions and 188 deletions

View File

@ -7,6 +7,9 @@ package connmgr
import (
nativeerrors "errors"
"fmt"
"github.com/kaspanet/kaspad/addrmgr"
"github.com/kaspanet/kaspad/config"
"github.com/kaspanet/kaspad/wire"
"net"
"sync"
"sync/atomic"
@ -30,10 +33,6 @@ var (
// defaultRetryDuration is the default duration of time for retrying
// persistent connections.
defaultRetryDuration = time.Second * 5
// defaultTargetOutbound is the default number of outbound connections to
// maintain.
defaultTargetOutbound = uint32(8)
)
var (
@ -54,6 +53,9 @@ var (
// ErrPeerNotFound is an error that is thrown if the peer was not found.
ErrPeerNotFound = errors.New("peer not found")
//ErrAddressManagerNil is used to indicate that Address Manager cannot be nil in the configuration.
ErrAddressManagerNil = errors.New("Config: Address manager cannot be nil")
)
// ConnState represents the state of the requested connection.
@ -77,7 +79,7 @@ type ConnReq struct {
// The following variables must only be used atomically.
id uint64
Addr net.Addr
Addr *net.TCPAddr
Permanent bool
conn net.Conn
@ -159,9 +161,7 @@ type Config struct {
// connection is disconnected.
OnDisconnection func(*ConnReq)
// GetNewAddress is a way to get an address to make a network connection
// to. If nil, no new connections will be made automatically.
GetNewAddress func() (net.Addr, error)
AddrManager *addrmgr.AddrManager
// Dial connects to the address on the named network. It cannot be nil.
Dial func(net.Addr) (net.Conn, error)
@ -201,7 +201,9 @@ type ConnManager struct {
start int32
stop int32
newConnReqMtx sync.Mutex
addressMtx sync.Mutex
usedOutboundGroups map[string]int64
usedAddresses map[string]struct{}
cfg Config
wg sync.WaitGroup
@ -237,9 +239,12 @@ func (cm *ConnManager) handleFailedConn(c *ConnReq, err error) {
log.Debugf("Retrying further connections to %s every %s", c, d)
}
spawnAfter(d, func() {
cm.Connect(c)
cm.connect(c)
})
} else if cm.cfg.GetNewAddress != nil {
} else {
if c.Addr != nil {
cm.releaseAddress(c.Addr)
}
cm.failedAttempts++
if cm.failedAttempts >= maxFailedAttempts {
if shouldWriteLog {
@ -254,6 +259,43 @@ func (cm *ConnManager) handleFailedConn(c *ConnReq, err error) {
}
}
func (cm *ConnManager) releaseAddress(addr *net.TCPAddr) {
cm.addressMtx.Lock()
defer cm.addressMtx.Unlock()
groupKey := usedOutboundGroupsKey(addr)
cm.usedOutboundGroups[groupKey]--
if cm.usedOutboundGroups[groupKey] < 0 {
panic(fmt.Errorf("cm.usedOutboundGroups[%s] has a negative value of %d. This should never happen", groupKey, cm.usedOutboundGroups[groupKey]))
}
delete(cm.usedAddresses, usedAddressesKey(addr))
}
func (cm *ConnManager) markAddressAsUsed(addr *net.TCPAddr) {
cm.usedOutboundGroups[usedOutboundGroupsKey(addr)]++
cm.usedAddresses[usedAddressesKey(addr)] = struct{}{}
}
func (cm *ConnManager) isOutboundGroupUsed(addr *net.TCPAddr) bool {
_, ok := cm.usedOutboundGroups[usedOutboundGroupsKey(addr)]
return ok
}
func (cm *ConnManager) isAddressUsed(addr *net.TCPAddr) bool {
_, ok := cm.usedAddresses[usedAddressesKey(addr)]
return ok
}
func usedOutboundGroupsKey(addr *net.TCPAddr) string {
// A fake service flag is used since it doesn't affect the group key.
na := wire.NewNetAddress(addr, wire.SFNodeNetwork)
return addrmgr.GroupKey(na)
}
func usedAddressesKey(addr *net.TCPAddr) string {
return addr.String()
}
// throttledError defines an error type whose logs get throttled. This is to
// prevent flooding the logs with identical errors.
type throttledError error
@ -392,21 +434,16 @@ out:
continue
}
// Otherwise, we will attempt a reconnection if
// we do not have enough peers, or if this is a
// persistent peer. The connection request is
// re added to the pending map, so that
// subsequent processing of connections and
// failures do not ignore the request.
if uint32(len(conns)) < cm.cfg.TargetOutbound ||
connReq.Permanent {
connReq.updateState(ConnPending)
log.Debugf("Reconnecting to %s",
connReq)
pending[msg.id] = connReq
cm.handleFailedConn(connReq, nil)
}
// Otherwise, we will attempt a reconnection.
// The connection request is re added to the
// pending map, so that subsequent processing
// of connections and failures do not ignore
// the request.
connReq.updateState(ConnPending)
log.Debugf("Reconnecting to %s",
connReq)
pending[msg.id] = connReq
cm.handleFailedConn(connReq, nil)
case handleFailed:
connReq := msg.c
@ -448,14 +485,9 @@ func (cm *ConnManager) NotifyConnectionRequestComplete() {
// NewConnReq creates a new connection request and connects to the
// corresponding address.
func (cm *ConnManager) NewConnReq() {
cm.newConnReqMtx.Lock()
defer cm.newConnReqMtx.Unlock()
if atomic.LoadInt32(&cm.stop) != 0 {
return
}
if cm.cfg.GetNewAddress == nil {
return
}
c := &ConnReq{}
atomic.StoreUint64(&c.id, atomic.AddUint64(&cm.connReqCount, 1))
@ -478,8 +510,7 @@ func (cm *ConnManager) NewConnReq() {
case <-cm.quit:
return
}
addr, err := cm.cfg.GetNewAddress()
err := cm.associateAddressToConnReq(c)
if err != nil {
select {
case cm.requests <- handleFailed{c, err}:
@ -488,17 +519,52 @@ func (cm *ConnManager) NewConnReq() {
return
}
c.Addr = addr
cm.connect(c)
}
cm.Connect(c)
func (cm *ConnManager) associateAddressToConnReq(c *ConnReq) error {
cm.addressMtx.Lock()
defer cm.addressMtx.Unlock()
addr, err := cm.getNewAddress()
if err != nil {
return err
}
cm.markAddressAsUsed(addr)
c.Addr = addr
return nil
}
// Connect assigns an id and dials a connection to the address of the
// connection request.
func (cm *ConnManager) Connect(c *ConnReq) {
func (cm *ConnManager) Connect(c *ConnReq) error {
err := func() error {
cm.addressMtx.Lock()
defer cm.addressMtx.Unlock()
if cm.isAddressUsed(c.Addr) {
return fmt.Errorf("address %s is already in use", c.Addr)
}
cm.markAddressAsUsed(c.Addr)
return nil
}()
if err != nil {
return err
}
cm.connect(c)
return nil
}
// connect assigns an id and dials a connection to the address of the
// connection request. This function assumes that the connection address
// has checked and already marked as used.
func (cm *ConnManager) connect(c *ConnReq) {
if atomic.LoadInt32(&cm.stop) != 0 {
return
}
if atomic.LoadUint64(&c.id) == 0 {
atomic.StoreUint64(&c.id, atomic.AddUint64(&cm.connReqCount, 1))
@ -645,23 +711,69 @@ func (cm *ConnManager) Stop() {
log.Trace("Connection manager stopped")
}
func (cm *ConnManager) getNewAddress() (*net.TCPAddr, error) {
for tries := 0; tries < 100; tries++ {
addr := cm.cfg.AddrManager.GetAddress()
if addr == nil {
break
}
// Check if there's already a connection to the same address.
netAddr := addr.NetAddress().TCPAddress()
if cm.isAddressUsed(netAddr) {
continue
}
// Address will not be invalid, local or unroutable
// because addrmanager rejects those on addition.
// Just check that we don't already have an address
// in the same group so that we are not connecting
// to the same network segment at the expense of
// others.
//
// Networks that accept unroutable connections are exempt
// from this rule, since they're meant to run within a
// private subnet, like 10.0.0.0/16.
if !config.ActiveConfig().NetParams().AcceptUnroutable && cm.isOutboundGroupUsed(netAddr) {
continue
}
// only allow recent nodes (10mins) after we failed 30
// times
if tries < 30 && time.Since(addr.LastAttempt()) < 10*time.Minute {
continue
}
// allow nondefault ports after 50 failed tries.
if tries < 50 && fmt.Sprintf("%d", netAddr.Port) !=
config.ActiveConfig().NetParams().DefaultPort {
continue
}
return netAddr, nil
}
return nil, ErrNoAddress
}
// New returns a new connection manager.
// Use Start to start connecting to the network.
func New(cfg *Config) (*ConnManager, error) {
if cfg.Dial == nil {
return nil, ErrDialNil
return nil, errors.WithStack(ErrDialNil)
}
if cfg.AddrManager == nil {
return nil, errors.WithStack(ErrAddressManagerNil)
}
// Default to sane values
if cfg.RetryDuration <= 0 {
cfg.RetryDuration = defaultRetryDuration
}
if cfg.TargetOutbound == 0 {
cfg.TargetOutbound = defaultTargetOutbound
}
cm := ConnManager{
cfg: *cfg, // Copy so caller can't mutate
requests: make(chan interface{}),
quit: make(chan struct{}),
cfg: *cfg, // Copy so caller can't mutate
requests: make(chan interface{}),
quit: make(chan struct{}),
usedAddresses: make(map[string]struct{}),
usedOutboundGroups: make(map[string]int64),
}
return &cm, nil
}

View File

@ -5,9 +5,15 @@
package connmgr
import (
"fmt"
"github.com/kaspanet/kaspad/addrmgr"
"github.com/kaspanet/kaspad/config"
"github.com/kaspanet/kaspad/dagconfig"
"github.com/pkg/errors"
"io"
"io/ioutil"
"net"
"os"
"sync/atomic"
"testing"
"time"
@ -70,13 +76,28 @@ func mockDialer(addr net.Addr) (net.Conn, error) {
// TestNewConfig tests that new ConnManager config is validated as expected.
func TestNewConfig(t *testing.T) {
restoreConfig := overrideActiveConfig()
defer restoreConfig()
_, err := New(&Config{})
if err == nil {
t.Fatalf("New expected error: 'Dial can't be nil', got nil")
if !errors.Is(err, ErrDialNil) {
t.Fatalf("New expected error: %s, got %s", ErrDialNil, err)
}
_, err = New(&Config{
Dial: mockDialer,
})
if !errors.Is(err, ErrAddressManagerNil) {
t.Fatalf("New expected error: %s, got %s", ErrAddressManagerNil, err)
}
amgr, teardown := addressManagerForTest(t, "TestNewConfig", 10)
defer teardown()
_, err = New(&Config{
Dial: mockDialer,
AddrManager: amgr,
})
if err != nil {
t.Fatalf("New unexpected error: %v", err)
}
@ -85,17 +106,19 @@ func TestNewConfig(t *testing.T) {
// TestStartStop tests that the connection manager starts and stops as
// expected.
func TestStartStop(t *testing.T) {
restoreConfig := overrideActiveConfig()
defer restoreConfig()
connected := make(chan *ConnReq)
disconnected := make(chan *ConnReq)
amgr, teardown := addressManagerForTest(t, "TestStartStop", 10)
defer teardown()
cmgr, err := New(&Config{
TargetOutbound: 1,
GetNewAddress: func() (net.Addr, error) {
return &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 18555,
}, nil
},
Dial: mockDialer,
AddrManager: amgr,
Dial: mockDialer,
OnConnection: func(c *ConnReq, conn net.Conn) {
connected <- c
},
@ -104,7 +127,7 @@ func TestStartStop(t *testing.T) {
},
})
if err != nil {
t.Fatalf("New error: %v", err)
t.Fatalf("unexpected error from New: %s", err)
}
cmgr.Start()
gotConnReq := <-connected
@ -119,7 +142,10 @@ func TestStartStop(t *testing.T) {
},
Permanent: true,
}
cmgr.Connect(cr)
err = cmgr.Connect(cr)
if err != nil {
t.Fatalf("Connect error: %s", err)
}
if cr.ID() != 0 {
t.Fatalf("start/stop: got id: %v, want: 0", cr.ID())
}
@ -133,21 +159,78 @@ func TestStartStop(t *testing.T) {
}
}
func overrideActiveConfig() func() {
originalActiveCfg := config.ActiveConfig()
config.SetActiveConfig(&config.Config{
Flags: &config.Flags{
NetworkFlags: config.NetworkFlags{
ActiveNetParams: &dagconfig.SimnetParams},
},
})
return func() {
// Give some extra time to all open NewConnReq goroutines
// to finish before restoring the active config to prevent
// potential panics.
time.Sleep(10 * time.Millisecond)
config.SetActiveConfig(originalActiveCfg)
}
}
func addressManagerForTest(t *testing.T, testName string, numAddresses uint8) (*addrmgr.AddrManager, func()) {
amgr, teardown := createEmptyAddressManagerForTest(t, testName)
for i := uint8(0); i < numAddresses; i++ {
ip := fmt.Sprintf("173.%d.115.66:16511", i)
err := amgr.AddAddressByIP(ip, nil)
if err != nil {
t.Fatalf("AddAddressByIP unexpectedly failed to add IP %s: %s", ip, err)
}
}
return amgr, teardown
}
func createEmptyAddressManagerForTest(t *testing.T, testName string) (*addrmgr.AddrManager, func()) {
path, err := ioutil.TempDir("", fmt.Sprintf("%s-addressmanager", testName))
if err != nil {
t.Fatalf("createEmptyAddressManagerForTest: TempDir unexpectedly "+
"failed: %s", err)
}
return addrmgr.New(path, nil, nil), func() {
// Wait for the connection manager to finish
time.Sleep(10 * time.Millisecond)
err := os.RemoveAll(path)
if err != nil {
t.Fatalf("couldn't remove path %s", path)
}
}
}
// TestConnectMode tests that the connection manager works in the connect mode.
//
// In connect mode, automatic connections are disabled, so we test that
// requests using Connect are handled and that no other connections are made.
func TestConnectMode(t *testing.T) {
restoreConfig := overrideActiveConfig()
defer restoreConfig()
connected := make(chan *ConnReq)
amgr, teardown := addressManagerForTest(t, "TestConnectMode", 10)
defer teardown()
cmgr, err := New(&Config{
TargetOutbound: 2,
TargetOutbound: 0,
Dial: mockDialer,
OnConnection: func(c *ConnReq, conn net.Conn) {
connected <- c
},
AddrManager: amgr,
})
if err != nil {
t.Fatalf("New error: %v", err)
t.Fatalf("unexpected error from New: %s", err)
}
cr := &ConnReq{
Addr: &net.TCPAddr{
@ -176,6 +259,7 @@ func TestConnectMode(t *testing.T) {
break
}
cmgr.Stop()
cmgr.Wait()
}
// TestTargetOutbound tests the target number of outbound connections.
@ -183,23 +267,26 @@ func TestConnectMode(t *testing.T) {
// We wait until all connections are established, then test they there are the
// only connections made.
func TestTargetOutbound(t *testing.T) {
targetOutbound := uint32(10)
restoreConfig := overrideActiveConfig()
defer restoreConfig()
const numAddressesInAddressManager = 10
targetOutbound := uint32(numAddressesInAddressManager - 2)
connected := make(chan *ConnReq)
amgr, teardown := addressManagerForTest(t, "TestTargetOutbound", 10)
defer teardown()
cmgr, err := New(&Config{
TargetOutbound: targetOutbound,
Dial: mockDialer,
GetNewAddress: func() (net.Addr, error) {
return &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 18555,
}, nil
},
AddrManager: amgr,
OnConnection: func(c *ConnReq, conn net.Conn) {
connected <- c
},
})
if err != nil {
t.Fatalf("New error: %v", err)
t.Fatalf("unexpected error from New: %s", err)
}
cmgr.Start()
for i := uint32(0); i < targetOutbound; i++ {
@ -213,6 +300,146 @@ func TestTargetOutbound(t *testing.T) {
break
}
cmgr.Stop()
cmgr.Wait()
}
// TestDuplicateOutboundConnections tests that connection requests cannot use an already used address.
// It checks it by creating one connection request for each address in the address manager, so that
// the next connection request will have to fail because no unused address will be available.
func TestDuplicateOutboundConnections(t *testing.T) {
restoreConfig := overrideActiveConfig()
defer restoreConfig()
const numAddressesInAddressManager = 10
targetOutbound := uint32(numAddressesInAddressManager - 1)
connected := make(chan struct{})
failedConnections := make(chan struct{})
amgr, teardown := addressManagerForTest(t, "TestDuplicateOutboundConnections", 10)
defer teardown()
cmgr, err := New(&Config{
TargetOutbound: targetOutbound,
Dial: mockDialer,
AddrManager: amgr,
OnConnection: func(c *ConnReq, conn net.Conn) {
connected <- struct{}{}
},
OnConnectionFailed: func(_ *ConnReq) {
failedConnections <- struct{}{}
},
})
if err != nil {
t.Fatalf("unexpected error from New: %s", err)
}
cmgr.Start()
for i := uint32(0); i < targetOutbound; i++ {
<-connected
}
time.Sleep(time.Millisecond)
// Here we check that making a manual connection request beyond the target outbound connection
// doesn't fail, so we can know that the reason such connection request will fail is an address
// related issue.
cmgr.NewConnReq()
select {
case <-connected:
break
case <-time.After(time.Millisecond):
t.Fatalf("connection request unexpectedly didn't connect")
}
select {
case <-failedConnections:
t.Fatalf("a connection request unexpectedly failed")
case <-time.After(time.Millisecond):
break
}
// After we created numAddressesInAddressManager connection requests, this request should fail
// because there aren't any more available addresses.
cmgr.NewConnReq()
select {
case <-connected:
t.Fatalf("connection request unexpectedly succeeded")
case <-time.After(time.Millisecond):
t.Fatalf("connection request didn't fail as expected")
case <-failedConnections:
break
}
cmgr.Stop()
cmgr.Wait()
}
// TestSameOutboundGroupConnections tests that connection requests cannot use an address with an already used
// address CIDR group.
// It checks it by creating an address manager with only two addresses, that both belong to the same CIDR group
// and checks that the second connection request fails.
func TestSameOutboundGroupConnections(t *testing.T) {
restoreConfig := overrideActiveConfig()
defer restoreConfig()
amgr, teardown := createEmptyAddressManagerForTest(t, "TestSameOutboundGroupConnections")
defer teardown()
err := amgr.AddAddressByIP("173.190.115.66:16511", nil)
if err != nil {
t.Fatalf("AddAddressByIP unexpectedly failed: %s", err)
}
err = amgr.AddAddressByIP("173.190.115.67:16511", nil)
if err != nil {
t.Fatalf("AddAddressByIP unexpectedly failed: %s", err)
}
connected := make(chan struct{})
failedConnections := make(chan struct{})
cmgr, err := New(&Config{
TargetOutbound: 0,
Dial: mockDialer,
AddrManager: amgr,
OnConnection: func(c *ConnReq, conn net.Conn) {
connected <- struct{}{}
},
OnConnectionFailed: func(_ *ConnReq) {
failedConnections <- struct{}{}
},
})
if err != nil {
t.Fatalf("unexpected error from New: %s", err)
}
cmgr.Start()
cmgr.NewConnReq()
select {
case <-connected:
break
case <-time.After(time.Millisecond):
t.Fatalf("connection request unexpectedly didn't connect")
}
select {
case <-failedConnections:
t.Fatalf("a connection request unexpectedly failed")
case <-time.After(time.Millisecond):
break
}
cmgr.NewConnReq()
select {
case <-connected:
t.Fatalf("connection request unexpectedly succeeded")
case <-time.After(time.Millisecond):
t.Fatalf("connection request didn't fail as expected")
case <-failedConnections:
break
}
cmgr.Stop()
cmgr.Wait()
}
// TestRetryPermanent tests that permanent connection requests are retried.
@ -220,11 +447,18 @@ func TestTargetOutbound(t *testing.T) {
// We make a permanent connection request using Connect, disconnect it using
// Disconnect and we wait for it to be connected back.
func TestRetryPermanent(t *testing.T) {
restoreConfig := overrideActiveConfig()
defer restoreConfig()
connected := make(chan *ConnReq)
disconnected := make(chan *ConnReq)
amgr, teardown := addressManagerForTest(t, "TestRetryPermanent", 10)
defer teardown()
cmgr, err := New(&Config{
RetryDuration: time.Millisecond,
TargetOutbound: 1,
TargetOutbound: 0,
Dial: mockDialer,
OnConnection: func(c *ConnReq, conn net.Conn) {
connected <- c
@ -232,9 +466,10 @@ func TestRetryPermanent(t *testing.T) {
OnDisconnection: func(c *ConnReq) {
disconnected <- c
},
AddrManager: amgr,
})
if err != nil {
t.Fatalf("New error: %v", err)
t.Fatalf("unexpected error from New: %s", err)
}
cr := &ConnReq{
@ -289,6 +524,9 @@ func TestRetryPermanent(t *testing.T) {
cmgr.Remove(cr.ID())
gotConnReq = <-disconnected
// Wait for status to be updated
time.Sleep(10 * time.Millisecond)
wantID = cr.ID()
gotID = gotConnReq.ID()
if gotID != wantID {
@ -300,6 +538,7 @@ func TestRetryPermanent(t *testing.T) {
t.Fatalf("retry: %v - want state %v, got state %v", cr.Addr, wantState, gotState)
}
cmgr.Stop()
cmgr.Wait()
}
// TestMaxRetryDuration tests the maximum retry duration.
@ -307,6 +546,9 @@ func TestRetryPermanent(t *testing.T) {
// We have a timed dialer which initially returns err but after RetryDuration
// hits maxRetryDuration returns a mock conn.
func TestMaxRetryDuration(t *testing.T) {
restoreConfig := overrideActiveConfig()
defer restoreConfig()
networkUp := make(chan struct{})
time.AfterFunc(5*time.Millisecond, func() {
close(networkUp)
@ -320,6 +562,9 @@ func TestMaxRetryDuration(t *testing.T) {
}
}
amgr, teardown := addressManagerForTest(t, "TestMaxRetryDuration", 10)
defer teardown()
connected := make(chan *ConnReq)
cmgr, err := New(&Config{
RetryDuration: time.Millisecond,
@ -328,9 +573,10 @@ func TestMaxRetryDuration(t *testing.T) {
OnConnection: func(c *ConnReq, conn net.Conn) {
connected <- c
},
AddrManager: amgr,
})
if err != nil {
t.Fatalf("New error: %v", err)
t.Fatalf("unexpected error from New: %s", err)
}
cr := &ConnReq{
@ -350,35 +596,40 @@ func TestMaxRetryDuration(t *testing.T) {
case <-time.Tick(100 * time.Millisecond):
t.Fatalf("max retry duration: connection timeout")
}
cmgr.Stop()
cmgr.Wait()
}
// TestNetworkFailure tests that the connection manager handles a network
// failure gracefully.
func TestNetworkFailure(t *testing.T) {
restoreConfig := overrideActiveConfig()
defer restoreConfig()
var dials uint32
errDialer := func(net net.Addr) (net.Conn, error) {
atomic.AddUint32(&dials, 1)
return nil, errors.New("network down")
}
amgr, teardown := addressManagerForTest(t, "TestNetworkFailure", 10)
defer teardown()
cmgr, err := New(&Config{
TargetOutbound: 5,
RetryDuration: 5 * time.Millisecond,
Dial: errDialer,
GetNewAddress: func() (net.Addr, error) {
return &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 18555,
}, nil
},
AddrManager: amgr,
OnConnection: func(c *ConnReq, conn net.Conn) {
t.Fatalf("network failure: got unexpected connection - %v", c.Addr)
},
})
if err != nil {
t.Fatalf("New error: %v", err)
t.Fatalf("unexpected error from New: %s", err)
}
cmgr.Start()
time.AfterFunc(10*time.Millisecond, cmgr.Stop)
time.Sleep(10 * time.Millisecond)
cmgr.Stop()
cmgr.Wait()
wantMaxDials := uint32(75)
if atomic.LoadUint32(&dials) > wantMaxDials {
@ -394,17 +645,25 @@ func TestNetworkFailure(t *testing.T) {
// err so that the handler assumes that the conn manager is stopped and ignores
// the failure.
func TestStopFailed(t *testing.T) {
restoreConfig := overrideActiveConfig()
defer restoreConfig()
done := make(chan struct{}, 1)
waitDialer := func(addr net.Addr) (net.Conn, error) {
done <- struct{}{}
time.Sleep(time.Millisecond)
return nil, errors.New("network down")
}
amgr, teardown := addressManagerForTest(t, "TestStopFailed", 10)
defer teardown()
cmgr, err := New(&Config{
Dial: waitDialer,
Dial: waitDialer,
AddrManager: amgr,
})
if err != nil {
t.Fatalf("New error: %v", err)
t.Fatalf("unexpected error from New: %s", err)
}
cmgr.Start()
go func() {
@ -428,6 +687,9 @@ func TestStopFailed(t *testing.T) {
// TestRemovePendingConnection tests that it's possible to cancel a pending
// connection, removing its internal state from the ConnMgr.
func TestRemovePendingConnection(t *testing.T) {
restoreConfig := overrideActiveConfig()
defer restoreConfig()
// Create a ConnMgr instance with an instance of a dialer that'll never
// succeed.
wait := make(chan struct{})
@ -435,11 +697,16 @@ func TestRemovePendingConnection(t *testing.T) {
<-wait
return nil, errors.Errorf("error")
}
amgr, teardown := addressManagerForTest(t, "TestRemovePendingConnection", 10)
defer teardown()
cmgr, err := New(&Config{
Dial: indefiniteDialer,
Dial: indefiniteDialer,
AddrManager: amgr,
})
if err != nil {
t.Fatalf("New error: %v", err)
t.Fatalf("unexpected error from New: %s", err)
}
cmgr.Start()
@ -474,12 +741,16 @@ func TestRemovePendingConnection(t *testing.T) {
close(wait)
cmgr.Stop()
cmgr.Wait()
}
// TestCancelIgnoreDelayedConnection tests that a canceled connection request will
// not execute the on connection callback, even if an outstanding retry
// succeeds.
func TestCancelIgnoreDelayedConnection(t *testing.T) {
restoreConfig := overrideActiveConfig()
defer restoreConfig()
retryTimeout := 10 * time.Millisecond
// Setup a dialer that will continue to return an error until the
@ -497,18 +768,22 @@ func TestCancelIgnoreDelayedConnection(t *testing.T) {
}
connected := make(chan *ConnReq)
amgr, teardown := addressManagerForTest(t, "TestCancelIgnoreDelayedConnection", 10)
defer teardown()
cmgr, err := New(&Config{
Dial: failingDialer,
RetryDuration: retryTimeout,
OnConnection: func(c *ConnReq, conn net.Conn) {
connected <- c
},
AddrManager: amgr,
})
if err != nil {
t.Fatalf("New error: %v", err)
t.Fatalf("unexpected error from New: %s", err)
}
cmgr.Start()
defer cmgr.Stop()
// Establish a connection request to a random IP we've chosen.
cr := &ConnReq{
@ -552,7 +827,8 @@ func TestCancelIgnoreDelayedConnection(t *testing.T) {
t.Fatalf("on-connect should not be called for canceled req")
case <-time.After(5 * retryTimeout):
}
cmgr.Stop()
cmgr.Wait()
}
// mockListener implements the net.Listener interface and is used to test
@ -617,21 +893,29 @@ func newMockListener(localAddr string) *mockListener {
// TestListeners ensures providing listeners to the connection manager along
// with an accept callback works properly.
func TestListeners(t *testing.T) {
restoreConfig := overrideActiveConfig()
defer restoreConfig()
// Setup a connection manager with a couple of mock listeners that
// notify a channel when they receive mock connections.
receivedConns := make(chan net.Conn)
listener1 := newMockListener("127.0.0.1:16111")
listener2 := newMockListener("127.0.0.1:9333")
listeners := []net.Listener{listener1, listener2}
amgr, teardown := addressManagerForTest(t, "TestListeners", 10)
defer teardown()
cmgr, err := New(&Config{
Listeners: listeners,
OnAccept: func(conn net.Conn) {
receivedConns <- conn
},
Dial: mockDialer,
Dial: mockDialer,
AddrManager: amgr,
})
if err != nil {
t.Fatalf("New error: %v", err)
t.Fatalf("unexpected error from New: %s", err)
}
cmgr.Start()

View File

@ -8,7 +8,6 @@ package p2p
import (
"crypto/rand"
"encoding/binary"
"fmt"
"math"
"net"
"runtime"
@ -150,7 +149,6 @@ type peerState struct {
outboundPeers map[int32]*Peer
persistentPeers map[int32]*Peer
banned map[string]time.Time
outboundGroups map[string]int
}
// Count returns the count of all known peers.
@ -665,9 +663,6 @@ func (s *Server) handleDonePeerMsg(state *peerState, sp *Peer) {
list = state.outboundPeers
}
if _, ok := list[sp.ID()]; ok {
if !sp.Inbound() && sp.VersionKnown() {
state.outboundGroups[addrmgr.GroupKey(sp.NA())]--
}
if !sp.Inbound() && sp.connReq != nil {
s.connManager.Disconnect(sp.connReq.ID())
}
@ -785,11 +780,6 @@ type GetPeersMsg struct {
Reply chan []*Peer
}
type getOutboundGroup struct {
key string
reply chan int
}
//GetManualNodesMsg is the message type which is used by the rpc server to get the list of persistent peers from the p2p server
type GetManualNodesMsg struct {
Reply chan []*Peer
@ -843,15 +833,15 @@ func (s *Server) handleQuery(state *peerState, querymsg interface{}) {
// TODO: duplicate oneshots?
// Limit max number of total peers.
if state.countOutboundPeers() >= config.ActiveConfig().TargetOutboundPeers {
msg.Reply <- connmgr.ErrMaxOutboundPeers
msg.Reply <- errors.WithStack(connmgr.ErrMaxOutboundPeers)
return
}
for _, peer := range state.persistentPeers {
if peer.Addr() == msg.Addr {
if msg.Permanent {
msg.Reply <- connmgr.ErrAlreadyConnected
msg.Reply <- errors.WithStack(connmgr.ErrAlreadyConnected)
} else {
msg.Reply <- connmgr.ErrAlreadyPermanent
msg.Reply <- errors.WithStack(connmgr.ErrAlreadyPermanent)
}
return
}
@ -872,23 +862,12 @@ func (s *Server) handleQuery(state *peerState, querymsg interface{}) {
})
msg.Reply <- nil
case RemoveNodeMsg:
found := disconnectPeer(state.persistentPeers, msg.Cmp, func(sp *Peer) {
// Keep group counts ok since we remove from
// the list now.
state.outboundGroups[addrmgr.GroupKey(sp.NA())]--
})
found := disconnectPeer(state.persistentPeers, msg.Cmp)
if found {
msg.Reply <- nil
} else {
msg.Reply <- connmgr.ErrPeerNotFound
}
case getOutboundGroup:
count, ok := state.outboundGroups[msg.key]
if ok {
msg.reply <- count
} else {
msg.reply <- 0
msg.Reply <- errors.WithStack(connmgr.ErrPeerNotFound)
}
// Request a list of the persistent (added) peers.
case GetManualNodesMsg:
@ -901,32 +880,26 @@ func (s *Server) handleQuery(state *peerState, querymsg interface{}) {
case DisconnectNodeMsg:
// Check inbound peers. We pass a nil callback since we don't
// require any additional actions on disconnect for inbound peers.
found := disconnectPeer(state.inboundPeers, msg.Cmp, nil)
found := disconnectPeer(state.inboundPeers, msg.Cmp)
if found {
msg.Reply <- nil
return
}
// Check outbound peers.
found = disconnectPeer(state.outboundPeers, msg.Cmp, func(sp *Peer) {
// Keep group counts ok since we remove from
// the list now.
state.outboundGroups[addrmgr.GroupKey(sp.NA())]--
})
found = disconnectPeer(state.outboundPeers, msg.Cmp)
if found {
// If there are multiple outbound connections to the same
// ip:port, continue disconnecting them all until no such
// peers are found.
for found {
found = disconnectPeer(state.outboundPeers, msg.Cmp, func(sp *Peer) {
state.outboundGroups[addrmgr.GroupKey(sp.NA())]--
})
found = disconnectPeer(state.outboundPeers, msg.Cmp)
}
msg.Reply <- nil
return
}
msg.Reply <- connmgr.ErrPeerNotFound
msg.Reply <- errors.WithStack(connmgr.ErrPeerNotFound)
}
}
@ -937,13 +910,9 @@ func (s *Server) handleQuery(state *peerState, querymsg interface{}) {
// to be located. If the peer is found, and the passed callback: `whenFound'
// isn't nil, we call it with the peer as the argument before it is removed
// from the peerList, and is disconnected from the server.
func disconnectPeer(peerList map[int32]*Peer, compareFunc func(*Peer) bool, whenFound func(*Peer)) bool {
func disconnectPeer(peerList map[int32]*Peer, compareFunc func(*Peer) bool) bool {
for addr, peer := range peerList {
if compareFunc(peer) {
if whenFound != nil {
whenFound(peer)
}
// This is ok because we are not continuing
// to iterate so won't corrupt the loop.
delete(peerList, addr)
@ -1026,7 +995,6 @@ func (s *Server) outboundPeerConnected(state *peerState, msg *outboundPeerConnec
s.peerDoneHandler(sp)
})
s.addrManager.Attempt(sp.NA())
state.outboundGroups[addrmgr.GroupKey(sp.NA())]++
}
// outboundPeerConnected is invoked by the connection manager when a new
@ -1097,7 +1065,6 @@ func (s *Server) peerHandler() {
persistentPeers: make(map[int32]*Peer),
outboundPeers: make(map[int32]*Peer),
banned: make(map[string]time.Time),
outboundGroups: make(map[string]int),
}
if !config.ActiveConfig().DisableDNSSeed {
@ -1226,14 +1193,6 @@ func (s *Server) ConnectedCount() int32 {
return <-replyChan
}
// OutboundGroupCount returns the number of peers connected to the given
// outbound group key.
func (s *Server) OutboundGroupCount(key string) int {
replyChan := make(chan int)
s.Query <- getOutboundGroup{key: key, reply: replyChan}
return <-replyChan
}
// AddBytesSent adds the passed number of bytes to the total bytes sent counter
// for the server. It is safe for concurrent access.
func (s *Server) AddBytesSent(bytesSent uint64) {
@ -1602,57 +1561,6 @@ func NewServer(listenAddrs []string, dagParams *dagconfig.Params, interrupt <-ch
return nil, err
}
// Only setup a function to return new addresses to connect to when
// not running in connect-only mode. The simulation network is always
// in connect-only mode since it is only intended to connect to
// specified peers and actively avoid advertising and connecting to
// discovered peers in order to prevent it from becoming a public test
// network.
var newAddressFunc func() (net.Addr, error)
if !config.ActiveConfig().Simnet && len(config.ActiveConfig().ConnectPeers) == 0 {
newAddressFunc = func() (net.Addr, error) {
for tries := 0; tries < 100; tries++ {
addr := s.addrManager.GetAddress()
if addr == nil {
break
}
// Address will not be invalid, local or unroutable
// because addrmanager rejects those on addition.
// Just check that we don't already have an address
// in the same group so that we are not connecting
// to the same network segment at the expense of
// others.
//
// Networks that accept unroutable connections are exempt
// from this rule, since they're meant to run within a
// private subnet, like 10.0.0.0/16.
if !config.ActiveConfig().NetParams().AcceptUnroutable {
key := addrmgr.GroupKey(addr.NetAddress())
if s.OutboundGroupCount(key) != 0 {
continue
}
}
// only allow recent nodes (10mins) after we failed 30
// times
if tries < 30 && time.Since(addr.LastAttempt()) < 10*time.Minute {
continue
}
// allow nondefault ports after 50 failed tries.
if tries < 50 && fmt.Sprintf("%d", addr.NetAddress().Port) !=
config.ActiveConfig().NetParams().DefaultPort {
continue
}
addrString := addrmgr.NetAddressKey(addr.NetAddress())
return addrStringToNetAddr(addrString)
}
return nil, connmgr.ErrNoAddress
}
}
// Create a connection manager.
cmgr, err := connmgr.New(&connmgr.Config{
Listeners: listeners,
@ -1671,7 +1579,7 @@ func NewServer(listenAddrs []string, dagParams *dagconfig.Params, interrupt <-ch
connReq: c,
}
},
GetNewAddress: newAddressFunc,
AddrManager: s.addrManager,
})
if err != nil {
return nil, err
@ -1782,7 +1690,7 @@ func initListeners(amgr *addrmgr.AddrManager, listenAddrs []string, services wir
// a net.Addr which maps to the original address with any host names resolved
// to IP addresses. It also handles tor addresses properly by returning a
// net.Addr that encapsulates the address.
func addrStringToNetAddr(addr string) (net.Addr, error) {
func addrStringToNetAddr(addr string) (*net.TCPAddr, error) {
host, strPort, err := net.SplitHostPort(addr)
if err != nil {
return nil, err

View File

@ -48,6 +48,14 @@ func (na *NetAddress) AddService(service ServiceFlag) {
na.Services |= service
}
// TCPAddress converts the NetAddress to *net.TCPAddr
func (na *NetAddress) TCPAddress() *net.TCPAddr {
return &net.TCPAddr{
IP: na.IP,
Port: int(na.Port),
}
}
// NewNetAddressIPPort returns a new NetAddress using the provided IP, port, and
// supported services with defaults for the remaining fields.
func NewNetAddressIPPort(ip net.IP, port uint16, services ServiceFlag) *NetAddress {