Add a size limit to the address manager (#1652)

* Remove a random address from the address manager if it's full.

* Implement TestOverfillAddressManager.

* Add connectionFailedCount to addresses.

* Mark connection failures.

* Mark connection successes.

* Implement removing by most connection failures.

* Expand TestOverfillAddressManager.

* Add comments.

* Use a better method for finding the address with the greatest connectionFailedCount.

* Fix a comment.

* Compare addresses by IP in TestOverfillAddressManager.

* Add a comment for updateNotBanned.

Co-authored-by: Ori Newman <orinewman1@gmail.com>
This commit is contained in:
stasatdaglabs 2021-04-05 17:56:13 +03:00 committed by GitHub
parent 0be1bba408
commit a795a9e619
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 236 additions and 64 deletions

View File

@ -15,6 +15,8 @@ import (
"github.com/pkg/errors"
)
const maxAddresses = 4096
// addressRandomizer is the interface for the randomizer needed for the AddressManager.
type addressRandomizer interface {
RandomAddress(addresses []*appmessage.NetAddress) *appmessage.NetAddress
@ -27,6 +29,11 @@ type addressKey struct {
address ipv6
}
type address struct {
netAddress *appmessage.NetAddress
connectionFailedCount uint64
}
type ipv6 [net.IPv6len]byte
func (i ipv6) equal(other ipv6) bool {
@ -45,17 +52,6 @@ func netAddressKey(netAddress *appmessage.NetAddress) addressKey {
return key
}
// netAddressKeys returns a key of the ip address to use it in maps.
func netAddressesKeys(netAddresses []*appmessage.NetAddress) map[addressKey]bool {
result := make(map[addressKey]bool, len(netAddresses))
for _, netAddress := range netAddresses {
key := netAddressKey(netAddress)
result[key] = true
}
return result
}
// AddressManager provides a concurrency safe address manager for caching potential
// peers on the Kaspa network.
type AddressManager struct {
@ -85,13 +81,42 @@ func New(cfg *Config, database database.Database) (*AddressManager, error) {
}, nil
}
func (am *AddressManager) addAddressNoLock(address *appmessage.NetAddress) error {
if !IsRoutable(address, am.cfg.AcceptUnroutable) {
func (am *AddressManager) addAddressNoLock(netAddress *appmessage.NetAddress) error {
if !IsRoutable(netAddress, am.cfg.AcceptUnroutable) {
return nil
}
key := netAddressKey(netAddress)
address := &address{netAddress: netAddress, connectionFailedCount: 0}
err := am.store.add(key, address)
if err != nil {
return err
}
if am.store.notBannedCount() > maxAddresses {
allAddresses := am.store.getAllNotBanned()
maxConnectionFailedCount := uint64(0)
toRemove := allAddresses[0]
for _, address := range allAddresses[1:] {
if address.connectionFailedCount > maxConnectionFailedCount {
maxConnectionFailedCount = address.connectionFailedCount
toRemove = address
}
}
toRemoveKey := netAddressKey(toRemove.netAddress)
err := am.store.remove(toRemoveKey)
if err != nil {
return err
}
}
return nil
}
func (am *AddressManager) removeAddressNoLock(address *appmessage.NetAddress) error {
key := netAddressKey(address)
return am.store.add(key, address)
return am.store.remove(key)
}
// AddAddress adds address to the address manager
@ -121,8 +146,37 @@ func (am *AddressManager) RemoveAddress(address *appmessage.NetAddress) error {
am.mutex.Lock()
defer am.mutex.Unlock()
return am.removeAddressNoLock(address)
}
// MarkConnectionFailure notifies the address manager that the given address
// has failed to connect
func (am *AddressManager) MarkConnectionFailure(address *appmessage.NetAddress) error {
am.mutex.Lock()
defer am.mutex.Unlock()
key := netAddressKey(address)
return am.store.remove(key)
entry, ok := am.store.getNotBanned(key)
if !ok {
return errors.Errorf("address %s is not registered with the address manager", address.TCPAddress())
}
entry.connectionFailedCount = entry.connectionFailedCount + 1
return am.store.updateNotBanned(key, entry)
}
// MarkConnectionSuccess notifies the address manager that the given address
// has successfully connected
func (am *AddressManager) MarkConnectionSuccess(address *appmessage.NetAddress) error {
am.mutex.Lock()
defer am.mutex.Unlock()
key := netAddressKey(address)
entry, ok := am.store.getNotBanned(key)
if !ok {
return errors.Errorf("address %s is not registered with the address manager", address.TCPAddress())
}
entry.connectionFailedCount = 0
return am.store.updateNotBanned(key, entry)
}
// Addresses returns all addresses
@ -130,7 +184,7 @@ func (am *AddressManager) Addresses() []*appmessage.NetAddress {
am.mutex.Lock()
defer am.mutex.Unlock()
return am.store.getAllNotBanned()
return am.store.getAllNotBannedNetAddresses()
}
// BannedAddresses returns all banned addresses
@ -138,7 +192,7 @@ func (am *AddressManager) BannedAddresses() []*appmessage.NetAddress {
am.mutex.Lock()
defer am.mutex.Unlock()
return am.store.getAllBanned()
return am.store.getAllBannedNetAddresses()
}
// notBannedAddressesWithException returns all not banned addresses with excpetion
@ -146,7 +200,7 @@ func (am *AddressManager) notBannedAddressesWithException(exceptions []*appmessa
am.mutex.Lock()
defer am.mutex.Unlock()
return am.store.getAllNotBannedWithout(exceptions)
return am.store.getAllNotBannedNetAddressesWithout(exceptions)
}
// RandomAddress returns a random address that isn't banned and isn't in exceptions
@ -174,7 +228,7 @@ func (am *AddressManager) Ban(addressToBan *appmessage.NetAddress) error {
keyToBan := netAddressKey(addressToBan)
keysToDelete := make([]addressKey, 0)
for _, address := range am.store.getAllNotBanned() {
for _, address := range am.store.getAllNotBannedNetAddresses() {
key := netAddressKey(address)
if key.address.equal(keyToBan.address) {
keysToDelete = append(keysToDelete, key)
@ -187,7 +241,8 @@ func (am *AddressManager) Ban(addressToBan *appmessage.NetAddress) error {
}
}
return am.store.addBanned(keyToBan, addressToBan)
address := &address{netAddress: addressToBan}
return am.store.addBanned(keyToBan, address)
}
// Unban unmarks the given address as banned
@ -223,7 +278,6 @@ func (am *AddressManager) IsBanned(address *appmessage.NetAddress) (bool, error)
}
return true, nil
}
func (am *AddressManager) unbanIfOldEnough(key addressKey) error {
@ -233,7 +287,7 @@ func (am *AddressManager) unbanIfOldEnough(key addressKey) error {
}
const maxBanTime = 24 * time.Hour
if mstime.Since(address.Timestamp) > maxBanTime {
if mstime.Since(address.netAddress.Timestamp) > maxBanTime {
err := am.store.removeBanned(key)
if err != nil {
return err

View File

@ -303,3 +303,71 @@ func TestRestoreAddressManager(t *testing.T) {
t.Fatalf("Banned address %s not returned from BannedAddresses()", addressToBan.IP)
}
}
func TestOverfillAddressManager(t *testing.T) {
addressManager, teardown := newAddressManagerForTest(t, "TestAddressManager")
defer teardown()
generateTestAddresses := func(amount int) []*appmessage.NetAddress {
testAddresses := make([]*appmessage.NetAddress, 0, amount)
for i := byte(0); i < 128; i++ {
for j := byte(0); j < 128; j++ {
testAddress := &appmessage.NetAddress{IP: net.IP{1, 2, i, j}, Timestamp: mstime.Now()}
testAddresses = append(testAddresses, testAddress)
if len(testAddresses) == amount {
break
}
}
if len(testAddresses) == amount {
break
}
}
return testAddresses
}
// Add a single test address to the address manager
testAddress := &appmessage.NetAddress{IP: net.IP{5, 6, 0, 0}, Timestamp: mstime.Now()}
err := addressManager.AddAddress(testAddress)
if err != nil {
t.Fatalf("AddAddress: %s", err)
}
// Add `maxAddresses-1` addresses to the address manager
addresses := generateTestAddresses(maxAddresses - 1)
err = addressManager.AddAddresses(addresses...)
if err != nil {
t.Fatalf("AddAddresses: %s", err)
}
// Make sure that it now contains exactly `maxAddresses` entries
returnedAddresses := addressManager.Addresses()
if len(returnedAddresses) != maxAddresses {
t.Fatalf("Unexpected address amount. Want: %d, got: %d", maxAddresses, len(returnedAddresses))
}
// Mark the first test address as a connection failure
err = addressManager.MarkConnectionFailure(testAddress)
if err != nil {
t.Fatalf("MarkConnectionFailure: %s", err)
}
// Add one more address to the address manager
err = addressManager.AddAddress(&appmessage.NetAddress{IP: net.IP{7, 8, 0, 0}, Timestamp: mstime.Now()})
if err != nil {
t.Fatalf("AddAddress: %s", err)
}
// Make sure that it now still contains exactly `maxAddresses` entries
returnedAddresses = addressManager.Addresses()
if len(returnedAddresses) != maxAddresses {
t.Fatalf("Unexpected address amount. Want: %d, got: %d", maxAddresses, len(returnedAddresses))
}
// Make sure that the first address is no longer in the
// connection manager
for _, address := range returnedAddresses {
if address.IP.Equal(testAddress.IP) {
t.Fatalf("Unexpectedly found testAddress returned addresses")
}
}
}

View File

@ -5,6 +5,7 @@ import (
"github.com/kaspanet/kaspad/app/appmessage"
"github.com/kaspanet/kaspad/infrastructure/db/database"
"github.com/kaspanet/kaspad/util/mstime"
"github.com/pkg/errors"
"net"
)
@ -13,15 +14,15 @@ var bannedAddressBucket = database.MakeBucket([]byte("banned-addresses"))
type addressStore struct {
database database.Database
notBannedAddresses map[addressKey]*appmessage.NetAddress
bannedAddresses map[ipv6]*appmessage.NetAddress
notBannedAddresses map[addressKey]*address
bannedAddresses map[ipv6]*address
}
func newAddressStore(database database.Database) (*addressStore, error) {
addressStore := &addressStore{
database: database,
notBannedAddresses: map[addressKey]*appmessage.NetAddress{},
bannedAddresses: map[ipv6]*appmessage.NetAddress{},
notBannedAddresses: map[addressKey]*address{},
bannedAddresses: map[ipv6]*address{},
}
err := addressStore.restoreNotBannedAddresses()
if err != nil {
@ -56,7 +57,7 @@ func (as *addressStore) restoreNotBannedAddresses() error {
if err != nil {
return err
}
netAddress := as.deserializeNetAddress(serializedNetAddress)
netAddress := as.deserializeAddress(serializedNetAddress)
as.notBannedAddresses[key] = netAddress
}
return nil
@ -80,13 +81,17 @@ func (as *addressStore) restoreBannedAddresses() error {
if err != nil {
return err
}
netAddress := as.deserializeNetAddress(serializedNetAddress)
netAddress := as.deserializeAddress(serializedNetAddress)
as.bannedAddresses[ipv6] = netAddress
}
return nil
}
func (as *addressStore) add(key addressKey, address *appmessage.NetAddress) error {
func (as *addressStore) notBannedCount() int {
return len(as.notBannedAddresses)
}
func (as *addressStore) add(key addressKey, address *address) error {
if _, ok := as.notBannedAddresses[key]; ok {
return nil
}
@ -94,10 +99,28 @@ func (as *addressStore) add(key addressKey, address *appmessage.NetAddress) erro
as.notBannedAddresses[key] = address
databaseKey := as.notBannedDatabaseKey(key)
serializedAddress := as.serializeNetAddress(address)
serializedAddress := as.serializeAddress(address)
return as.database.Put(databaseKey, serializedAddress)
}
// updateNotBanned updates the not-banned address collection
func (as *addressStore) updateNotBanned(key addressKey, address *address) error {
if _, ok := as.notBannedAddresses[key]; !ok {
return errors.Errorf("address %s is not in the store", address.netAddress.TCPAddress())
}
as.notBannedAddresses[key] = address
databaseKey := as.notBannedDatabaseKey(key)
serializedAddress := as.serializeAddress(address)
return as.database.Put(databaseKey, serializedAddress)
}
func (as *addressStore) getNotBanned(key addressKey) (*address, bool) {
address, ok := as.notBannedAddresses[key]
return address, ok
}
func (as *addressStore) remove(key addressKey) error {
delete(as.notBannedAddresses, key)
@ -105,21 +128,29 @@ func (as *addressStore) remove(key addressKey) error {
return as.database.Delete(databaseKey)
}
func (as *addressStore) getAllNotBanned() []*appmessage.NetAddress {
addresses := make([]*appmessage.NetAddress, 0, len(as.notBannedAddresses))
func (as *addressStore) getAllNotBanned() []*address {
addresses := make([]*address, 0, len(as.notBannedAddresses))
for _, address := range as.notBannedAddresses {
addresses = append(addresses, address)
}
return addresses
}
func (as *addressStore) getAllNotBannedWithout(ignoredAddresses []*appmessage.NetAddress) []*appmessage.NetAddress {
func (as *addressStore) getAllNotBannedNetAddresses() []*appmessage.NetAddress {
addresses := make([]*appmessage.NetAddress, 0, len(as.notBannedAddresses))
for _, address := range as.notBannedAddresses {
addresses = append(addresses, address.netAddress)
}
return addresses
}
func (as *addressStore) getAllNotBannedNetAddressesWithout(ignoredAddresses []*appmessage.NetAddress) []*appmessage.NetAddress {
ignoredKeys := netAddressesKeys(ignoredAddresses)
addresses := make([]*appmessage.NetAddress, 0, len(as.notBannedAddresses))
for key, address := range as.notBannedAddresses {
if !ignoredKeys[key] {
addresses = append(addresses, address)
addresses = append(addresses, address.netAddress)
}
}
return addresses
@ -130,7 +161,7 @@ func (as *addressStore) isNotBanned(key addressKey) bool {
return ok
}
func (as *addressStore) addBanned(key addressKey, address *appmessage.NetAddress) error {
func (as *addressStore) addBanned(key addressKey, address *address) error {
if _, ok := as.bannedAddresses[key.address]; ok {
return nil
}
@ -138,7 +169,7 @@ func (as *addressStore) addBanned(key addressKey, address *appmessage.NetAddress
as.bannedAddresses[key.address] = address
databaseKey := as.bannedDatabaseKey(key)
serializedAddress := as.serializeNetAddress(address)
serializedAddress := as.serializeAddress(address)
return as.database.Put(databaseKey, serializedAddress)
}
@ -149,10 +180,10 @@ func (as *addressStore) removeBanned(key addressKey) error {
return as.database.Delete(databaseKey)
}
func (as *addressStore) getAllBanned() []*appmessage.NetAddress {
func (as *addressStore) getAllBannedNetAddresses() []*appmessage.NetAddress {
bannedAddresses := make([]*appmessage.NetAddress, 0, len(as.bannedAddresses))
for _, bannedAddress := range as.bannedAddresses {
bannedAddresses = append(bannedAddresses, bannedAddress)
bannedAddresses = append(bannedAddresses, bannedAddress.netAddress)
}
return bannedAddresses
}
@ -162,11 +193,21 @@ func (as *addressStore) isBanned(key addressKey) bool {
return ok
}
func (as *addressStore) getBanned(key addressKey) (*appmessage.NetAddress, bool) {
func (as *addressStore) getBanned(key addressKey) (*address, bool) {
bannedAddress, ok := as.bannedAddresses[key.address]
return bannedAddress, ok
}
// netAddressKeys returns a key of the ip address to use it in maps.
func netAddressesKeys(netAddresses []*appmessage.NetAddress) map[addressKey]bool {
result := make(map[addressKey]bool, len(netAddresses))
for _, netAddress := range netAddresses {
key := netAddressKey(netAddress)
result[key] = true
}
return result
}
func (as *addressStore) notBannedDatabaseKey(key addressKey) *database.Key {
serializedKey := as.serializeAddressKey(key)
return notBannedAddressBucket.Key(serializedKey)
@ -198,27 +239,32 @@ func (as *addressStore) deserializeAddressKey(serializedKey []byte) addressKey {
}
}
func (as *addressStore) serializeNetAddress(netAddress *appmessage.NetAddress) []byte {
serializedSize := 16 + 2 + 8 // ipv6 + port + timestamp
func (as *addressStore) serializeAddress(address *address) []byte {
serializedSize := 16 + 2 + 8 + 8 // ipv6 + port + timestamp + connectionFailedCount
serializedNetAddress := make([]byte, serializedSize)
copy(serializedNetAddress[:], netAddress.IP[:])
binary.LittleEndian.PutUint16(serializedNetAddress[16:], netAddress.Port)
binary.LittleEndian.PutUint64(serializedNetAddress[18:], uint64(netAddress.Timestamp.UnixMilliseconds()))
copy(serializedNetAddress[:], address.netAddress.IP[:])
binary.LittleEndian.PutUint16(serializedNetAddress[16:], address.netAddress.Port)
binary.LittleEndian.PutUint64(serializedNetAddress[18:], uint64(address.netAddress.Timestamp.UnixMilliseconds()))
binary.LittleEndian.PutUint64(serializedNetAddress[26:], uint64(address.connectionFailedCount))
return serializedNetAddress
}
func (as *addressStore) deserializeNetAddress(serializedNetAddress []byte) *appmessage.NetAddress {
func (as *addressStore) deserializeAddress(serializedAddress []byte) *address {
ip := make(net.IP, 16)
copy(ip[:], serializedNetAddress[:])
copy(ip[:], serializedAddress[:])
port := binary.LittleEndian.Uint16(serializedNetAddress[16:])
timestamp := mstime.UnixMilliseconds(int64(binary.LittleEndian.Uint64(serializedNetAddress[18:])))
port := binary.LittleEndian.Uint16(serializedAddress[16:])
timestamp := mstime.UnixMilliseconds(int64(binary.LittleEndian.Uint64(serializedAddress[18:])))
connectionFailedCount := binary.LittleEndian.Uint64(serializedAddress[26:])
return &appmessage.NetAddress{
IP: ip,
Port: port,
Timestamp: timestamp,
return &address{
netAddress: &appmessage.NetAddress{
IP: ip,
Port: port,
Timestamp: timestamp,
},
connectionFailedCount: connectionFailedCount,
}
}

View File

@ -24,21 +24,24 @@ func TestAddressKeySerialization(t *testing.T) {
}
}
func TestNetAddressSerialization(t *testing.T) {
addressManager, teardown := newAddressManagerForTest(t, "TestNetAddressSerialization")
func TestAddressSerialization(t *testing.T) {
addressManager, teardown := newAddressManagerForTest(t, "TestAddressSerialization")
defer teardown()
addressStore := addressManager.store
testAddress := &appmessage.NetAddress{
IP: net.ParseIP("2602:100:abcd::102"),
Port: 12345,
Timestamp: mstime.Now(),
testAddress := &address{
netAddress: &appmessage.NetAddress{
IP: net.ParseIP("2602:100:abcd::102"),
Port: 12345,
Timestamp: mstime.Now(),
},
connectionFailedCount: 98465,
}
serializedTestNetAddress := addressStore.serializeNetAddress(testAddress)
deserializedTestNetAddress := addressStore.deserializeNetAddress(serializedTestNetAddress)
if !reflect.DeepEqual(testAddress, deserializedTestNetAddress) {
t.Fatalf("testAddress and deserializedTestNetAddress are not equal\n"+
"testAddress:%+v\ndeserializedTestNetAddress:%+v", testAddress, deserializedTestNetAddress)
serializedTestAddress := addressStore.serializeAddress(testAddress)
deserializedTestAddress := addressStore.deserializeAddress(serializedTestAddress)
if !reflect.DeepEqual(testAddress, deserializedTestAddress) {
t.Fatalf("testAddress and deserializedTestAddress are not equal\n"+
"testAddress:%+v\ndeserializedTestAddress:%+v", testAddress, deserializedTestAddress)
}
}

View File

@ -42,9 +42,10 @@ func (c *ConnectionManager) checkOutgoingConnections(connSet connectionSet) {
err := c.initiateConnection(addressString)
if err != nil {
log.Infof("Couldn't connect to %s: %s", addressString, err)
c.addressManager.RemoveAddress(netAddress)
c.addressManager.MarkConnectionFailure(netAddress)
continue
}
c.addressManager.MarkConnectionSuccess(netAddress)
c.activeOutgoing[addressString] = struct{}{}
}