From d835f72e74af66629874f2522fbb5110685fc845 Mon Sep 17 00:00:00 2001 From: stasatdaglabs <39559713+stasatdaglabs@users.noreply.github.com> Date: Sun, 14 Feb 2021 19:08:06 +0200 Subject: [PATCH] Make AddressManager persistent (#1525) * Move existing address/bannedAddress functionality to a new addressStore object. * Implement TestAddressManager. * Implement serializeAddressKey and deserializeAddressKey. * Implement serializeNetAddress and deserializeNetAddress. * Store addresses and banned addresses to disk. * Implement restoreNotBannedAddresses. * Fix bannedDatabaseKey. * Implement restoreBannedAddresses. * Implement TestRestoreAddressManager. * Defer closing the database in TestRestoreAddressManager. * Defer closing the database in TestRestoreAddressManager. * Add a log. * Return errors where appropriate. Co-authored-by: Elichai Turkel --- app/component_manager.go | 2 +- .../flows/addressexchange/receiveaddresses.go | 3 +- app/protocol/flows/handshake/handshake.go | 5 +- .../network/addressmanager/addressmanager.go | 114 +++++---- .../addressmanager/addressmanager_test.go | 210 +++++++++++++++- .../network/addressmanager/network_test.go | 4 +- .../network/addressmanager/store.go | 225 ++++++++++++++++++ .../network/addressmanager/store_test.go | 45 ++++ .../network/addressmanager/test_utils.go | 3 +- .../network/connmanager/connmanager.go | 6 +- 10 files changed, 538 insertions(+), 79 deletions(-) create mode 100644 infrastructure/network/addressmanager/store.go create mode 100644 infrastructure/network/addressmanager/store_test.go diff --git a/app/component_manager.go b/app/component_manager.go index e2f20beae..f8481dd05 100644 --- a/app/component_manager.go +++ b/app/component_manager.go @@ -90,7 +90,7 @@ func NewComponentManager(cfg *config.Config, db infrastructuredatabase.Database, return nil, err } - addressManager, err := addressmanager.New(addressmanager.NewConfig(cfg)) + addressManager, err := addressmanager.New(addressmanager.NewConfig(cfg), db) if err != nil { return nil, err } diff --git a/app/protocol/flows/addressexchange/receiveaddresses.go b/app/protocol/flows/addressexchange/receiveaddresses.go index 366f350b0..654c823ff 100644 --- a/app/protocol/flows/addressexchange/receiveaddresses.go +++ b/app/protocol/flows/addressexchange/receiveaddresses.go @@ -35,6 +35,5 @@ func ReceiveAddresses(context ReceiveAddressesContext, incomingRoute *router.Rou return protocolerrors.Errorf(true, "address count exceeded %d", addressmanager.GetAddressesMax) } - context.AddressManager().AddAddresses(msgAddresses.AddressList...) - return nil + return context.AddressManager().AddAddresses(msgAddresses.AddressList...) } diff --git a/app/protocol/flows/handshake/handshake.go b/app/protocol/flows/handshake/handshake.go index aa6c5f686..b3a332961 100644 --- a/app/protocol/flows/handshake/handshake.go +++ b/app/protocol/flows/handshake/handshake.go @@ -89,7 +89,10 @@ func HandleHandshake(context HandleHandshakeContext, netConnection *netadapter.N } if peerAddress != nil { - context.AddressManager().AddAddresses(peerAddress) + err := context.AddressManager().AddAddresses(peerAddress) + if err != nil { + return nil, err + } } return peer, nil } diff --git a/infrastructure/network/addressmanager/addressmanager.go b/infrastructure/network/addressmanager/addressmanager.go index df01497b4..1bd13a075 100644 --- a/infrastructure/network/addressmanager/addressmanager.go +++ b/infrastructure/network/addressmanager/addressmanager.go @@ -5,6 +5,7 @@ package addressmanager import ( + "github.com/kaspanet/kaspad/infrastructure/db/database" "github.com/kaspanet/kaspad/util/mstime" "net" "sync" @@ -58,67 +59,70 @@ func netAddressesKeys(netAddresses []*appmessage.NetAddress) map[addressKey]bool // AddressManager provides a concurrency safe address manager for caching potential // peers on the Kaspa network. type AddressManager struct { - addresses map[addressKey]*appmessage.NetAddress - bannedAddresses map[ipv6]*appmessage.NetAddress - localAddresses *localAddressManager - mutex sync.Mutex - cfg *Config - random addressRandomizer + store *addressStore + localAddresses *localAddressManager + mutex sync.Mutex + cfg *Config + random addressRandomizer } // New returns a new Kaspa address manager. -func New(cfg *Config) (*AddressManager, error) { +func New(cfg *Config, database database.Database) (*AddressManager, error) { + addressStore, err := newAddressStore(database) + if err != nil { + return nil, err + } localAddresses, err := newLocalAddressManager(cfg) if err != nil { return nil, err } return &AddressManager{ - addresses: map[addressKey]*appmessage.NetAddress{}, - bannedAddresses: map[ipv6]*appmessage.NetAddress{}, - localAddresses: localAddresses, - random: NewAddressRandomize(), - cfg: cfg, + store: addressStore, + localAddresses: localAddresses, + random: NewAddressRandomize(), + cfg: cfg, }, nil } -func (am *AddressManager) addAddressNoLock(address *appmessage.NetAddress) { +func (am *AddressManager) addAddressNoLock(address *appmessage.NetAddress) error { if !IsRoutable(address, am.cfg.AcceptUnroutable) { - return + return nil } key := netAddressKey(address) - _, ok := am.addresses[key] - if !ok { - am.addresses[key] = address - } + return am.store.add(key, address) } // AddAddress adds address to the address manager -func (am *AddressManager) AddAddress(address *appmessage.NetAddress) { +func (am *AddressManager) AddAddress(address *appmessage.NetAddress) error { am.mutex.Lock() defer am.mutex.Unlock() - am.addAddressNoLock(address) + return am.addAddressNoLock(address) } // AddAddresses adds addresses to the address manager -func (am *AddressManager) AddAddresses(addresses ...*appmessage.NetAddress) { +func (am *AddressManager) AddAddresses(addresses ...*appmessage.NetAddress) error { am.mutex.Lock() defer am.mutex.Unlock() for _, address := range addresses { - am.addAddressNoLock(address) + err := am.addAddressNoLock(address) + if err != nil { + return err + } } + return nil } // RemoveAddress removes addresses from the address manager -func (am *AddressManager) RemoveAddress(address *appmessage.NetAddress) { +func (am *AddressManager) RemoveAddress(address *appmessage.NetAddress) error { am.mutex.Lock() defer am.mutex.Unlock() key := netAddressKey(address) - delete(am.addresses, key) + return am.store.remove(key) } // Addresses returns all addresses @@ -126,12 +130,7 @@ func (am *AddressManager) Addresses() []*appmessage.NetAddress { am.mutex.Lock() defer am.mutex.Unlock() - result := make([]*appmessage.NetAddress, 0, len(am.addresses)) - for _, address := range am.addresses { - result = append(result, address) - } - - return result + return am.store.getAllNotBanned() } // BannedAddresses returns all banned addresses @@ -139,28 +138,15 @@ func (am *AddressManager) BannedAddresses() []*appmessage.NetAddress { am.mutex.Lock() defer am.mutex.Unlock() - result := make([]*appmessage.NetAddress, 0, len(am.bannedAddresses)) - for _, address := range am.bannedAddresses { - result = append(result, address) - } - - return result + return am.store.getAllBanned() } // notBannedAddressesWithException returns all not banned addresses with excpetion func (am *AddressManager) notBannedAddressesWithException(exceptions []*appmessage.NetAddress) []*appmessage.NetAddress { - exceptionsKeys := netAddressesKeys(exceptions) am.mutex.Lock() defer am.mutex.Unlock() - result := make([]*appmessage.NetAddress, 0, len(am.addresses)) - for key, address := range am.addresses { - if !exceptionsKeys[key] { - result = append(result, address) - } - } - - return result + return am.store.getAllNotBannedWithout(exceptions) } // RandomAddress returns a random address that isn't banned and isn't in exceptions @@ -182,23 +168,26 @@ func (am *AddressManager) BestLocalAddress(remoteAddress *appmessage.NetAddress) } // Ban marks the given address as banned -func (am *AddressManager) Ban(addressToBan *appmessage.NetAddress) { +func (am *AddressManager) Ban(addressToBan *appmessage.NetAddress) error { am.mutex.Lock() defer am.mutex.Unlock() keyToBan := netAddressKey(addressToBan) keysToDelete := make([]addressKey, 0) - for _, address := range am.addresses { + for _, address := range am.store.getAllNotBanned() { key := netAddressKey(address) if key.address.equal(keyToBan.address) { keysToDelete = append(keysToDelete, key) } } for _, key := range keysToDelete { - delete(am.addresses, key) + err := am.store.remove(key) + if err != nil { + return err + } } - am.bannedAddresses[keyToBan.address] = addressToBan + return am.store.addBanned(keyToBan, addressToBan) } // Unban unmarks the given address as banned @@ -207,14 +196,12 @@ func (am *AddressManager) Unban(address *appmessage.NetAddress) error { defer am.mutex.Unlock() key := netAddressKey(address) - _, ok := am.bannedAddresses[key.address] - if !ok { + if !am.store.isBanned(key) { return errors.Wrapf(ErrAddressNotFound, "address %s "+ "is not registered with the address manager as banned", address.TCPAddress()) } - delete(am.bannedAddresses, key.address) - return nil + return am.store.removeBanned(key) } // IsBanned returns true if the given address is marked as banned @@ -223,9 +210,12 @@ func (am *AddressManager) IsBanned(address *appmessage.NetAddress) (bool, error) defer am.mutex.Unlock() key := netAddressKey(address) - am.unbanIfOldEnough(key.address) - if _, ok := am.bannedAddresses[key.address]; !ok { - if _, ok = am.addresses[key]; !ok { + err := am.unbanIfOldEnough(key) + if err != nil { + return false, err + } + if !am.store.isBanned(key) { + if !am.store.isNotBanned(key) { return false, errors.Wrapf(ErrAddressNotFound, "address %s "+ "is not registered with the address manager", address.TCPAddress()) } @@ -236,14 +226,18 @@ func (am *AddressManager) IsBanned(address *appmessage.NetAddress) (bool, error) } -func (am *AddressManager) unbanIfOldEnough(ipv6Address ipv6) { - address, ok := am.bannedAddresses[ipv6Address] +func (am *AddressManager) unbanIfOldEnough(key addressKey) error { + address, ok := am.store.getBanned(key) if !ok { - return + return nil } const maxBanTime = 24 * time.Hour if mstime.Since(address.Timestamp) > maxBanTime { - delete(am.bannedAddresses, ipv6Address) + err := am.store.removeBanned(key) + if err != nil { + return err + } } + return nil } diff --git a/infrastructure/network/addressmanager/addressmanager_test.go b/infrastructure/network/addressmanager/addressmanager_test.go index 1a5564a5f..14bb6f1fc 100644 --- a/infrastructure/network/addressmanager/addressmanager_test.go +++ b/infrastructure/network/addressmanager/addressmanager_test.go @@ -5,23 +5,32 @@ package addressmanager import ( - "net" - "testing" - "github.com/kaspanet/kaspad/app/appmessage" + "github.com/kaspanet/kaspad/infrastructure/db/database/ldb" + "github.com/kaspanet/kaspad/util/mstime" + "net" + "reflect" + "testing" "github.com/kaspanet/kaspad/infrastructure/config" ) -func newAddrManagerForTest(t *testing.T, testName string) (addressManager *AddressManager, teardown func()) { +func newAddressManagerForTest(t *testing.T, testName string) (addressManager *AddressManager, teardown func()) { cfg := config.DefaultConfig() - addressManager, err := New(NewConfig(cfg)) + datadir := t.TempDir() + database, err := ldb.NewLevelDB(datadir, 8) if err != nil { - t.Fatalf("error creating address manager: %s", err) + t.Fatalf("%s: could not create a database: %s", testName, err) + } + + addressManager, err = New(NewConfig(cfg), database) + if err != nil { + t.Fatalf("%s: error creating address manager: %s", testName, err) } return addressManager, func() { + database.Close() } } @@ -66,7 +75,7 @@ func TestBestLocalAddress(t *testing.T) { }, } - amgr, teardown := newAddrManagerForTest(t, "TestGetBestLocalAddress") + amgr, teardown := newAddressManagerForTest(t, "TestGetBestLocalAddress") defer teardown() // Test against default when there's no address @@ -107,3 +116,190 @@ func TestBestLocalAddress(t *testing.T) { } } } + +func TestAddressManager(t *testing.T) { + addressManager, teardown := newAddressManagerForTest(t, "TestAddressManager") + defer teardown() + + testAddress1 := &appmessage.NetAddress{IP: net.ParseIP("1.2.3.4"), Timestamp: mstime.Now()} + testAddress2 := &appmessage.NetAddress{IP: net.ParseIP("5.6.8.8"), Timestamp: mstime.Now()} + testAddress3 := &appmessage.NetAddress{IP: net.ParseIP("9.0.1.2"), Timestamp: mstime.Now()} + testAddresses := []*appmessage.NetAddress{testAddress1, testAddress2, testAddress3} + + // Add a few addresses + err := addressManager.AddAddresses(testAddresses...) + if err != nil { + t.Fatalf("AddAddresses() failed: %s", err) + } + + // Make sure that all the addresses are returned + addresses := addressManager.Addresses() + if len(testAddresses) != len(addresses) { + t.Fatalf("Unexpected amount of addresses returned from Addresses. "+ + "Want: %d, got: %d", len(testAddresses), len(addresses)) + } + for _, testAddress := range testAddresses { + found := false + for _, address := range addresses { + if reflect.DeepEqual(testAddress, address) { + found = true + break + } + } + if !found { + t.Fatalf("Address %s not returned from Addresses().", testAddress.IP) + } + } + + // Remove an address + addressToRemove := testAddress2 + err = addressManager.RemoveAddress(addressToRemove) + if err != nil { + t.Fatalf("RemoveAddress() failed: %s", err) + } + + // Make sure that the removed address is not returned + addresses = addressManager.Addresses() + if len(addresses) != len(testAddresses)-1 { + t.Fatalf("Unexpected amount of addresses returned from Addresses(). "+ + "Want: %d, got: %d", len(addresses), len(testAddresses)-1) + } + for _, address := range addresses { + if reflect.DeepEqual(addressToRemove, address) { + t.Fatalf("Removed addresses %s returned from Addresses()", addressToRemove.IP) + } + } + + // Add that address back + err = addressManager.AddAddress(addressToRemove) + if err != nil { + t.Fatalf("AddAddress() failed: %s", err) + } + + // Ban a different address + addressToBan := testAddress3 + err = addressManager.Ban(addressToBan) + if err != nil { + t.Fatalf("Ban() failed: %s", err) + } + + // Make sure that the banned address is not returned + addresses = addressManager.Addresses() + if len(addresses) != len(testAddresses)-1 { + t.Fatalf("Unexpected amount of addresses returned from Addresses(). "+ + "Want: %d, got: %d", len(addresses), len(testAddresses)-1) + } + for _, address := range addresses { + if reflect.DeepEqual(addressToBan, address) { + t.Fatalf("Banned addresses %s returned from Addresses()", addressToBan.IP) + } + } + + // Check that the address is banned + isBanned, err := addressManager.IsBanned(addressToBan) + if err != nil { + t.Fatalf("IsBanned() failed: %s", err) + } + if !isBanned { + t.Fatalf("Adderss %s is unexpectedly not banned", addressToBan.IP) + } + + // Check that BannedAddresses() returns the banned address + bannedAddresses := addressManager.BannedAddresses() + if len(bannedAddresses) != 1 { + t.Fatalf("Unexpected amount of addresses returned from BannedAddresses(). "+ + "Want: %d, got: %d", 1, len(bannedAddresses)) + } + if !reflect.DeepEqual(addressToBan, bannedAddresses[0]) { + t.Fatalf("Banned address %s not returned from BannedAddresses()", addressToBan.IP) + } + + // Unban the address + err = addressManager.Unban(addressToBan) + if err != nil { + t.Fatalf("Unban() failed: %s", err) + } + + // Check that BannedAddresses() not longer returns the banned address + bannedAddresses = addressManager.BannedAddresses() + if len(bannedAddresses) != 0 { + t.Fatalf("Unexpected amount of addresses returned from BannedAddresses(). "+ + "Want: %d, got: %d", 0, len(bannedAddresses)) + } +} + +func TestRestoreAddressManager(t *testing.T) { + cfg := config.DefaultConfig() + + // Create an empty database + datadir := t.TempDir() + database, err := ldb.NewLevelDB(datadir, 8) + if err != nil { + t.Fatalf("Could not create a database: %s", err) + } + defer database.Close() + + // Create an addressManager with the empty database + addressManager, err := New(NewConfig(cfg), database) + if err != nil { + t.Fatalf("Error creating address manager: %s", err) + } + + testAddress1 := &appmessage.NetAddress{IP: net.ParseIP("1.2.3.4"), Timestamp: mstime.Now()} + testAddress2 := &appmessage.NetAddress{IP: net.ParseIP("5.6.8.8"), Timestamp: mstime.Now()} + testAddress3 := &appmessage.NetAddress{IP: net.ParseIP("9.0.1.2"), Timestamp: mstime.Now()} + testAddresses := []*appmessage.NetAddress{testAddress1, testAddress2, testAddress3} + + // Add some addresses + err = addressManager.AddAddresses(testAddresses...) + if err != nil { + t.Fatalf("AddAddresses() failed: %s", err) + } + + // Ban one of the addresses + addressToBan := testAddress1 + err = addressManager.Ban(addressToBan) + if err != nil { + t.Fatalf("Ban() failed: %s", err) + } + + // Close the database + err = database.Close() + if err != nil { + t.Fatalf("Close() failed: %s", err) + } + + // Reopen the database + database, err = ldb.NewLevelDB(datadir, 8) + if err != nil { + t.Fatalf("Could not create a database: %s", err) + } + + // Recreate an addressManager with a the previous database + addressManager, err = New(NewConfig(cfg), database) + if err != nil { + t.Fatalf("Error creating address manager: %s", err) + } + + // Make sure that Addresses() returns the correct addresses + addresses := addressManager.Addresses() + if len(addresses) != len(testAddresses)-1 { + t.Fatalf("Unexpected amount of addresses returned from Addresses(). "+ + "Want: %d, got: %d", len(addresses), len(testAddresses)-1) + } + for _, address := range addresses { + if reflect.DeepEqual(addressToBan, address) { + t.Fatalf("Banned addresses %s returned from Addresses()", addressToBan.IP) + } + } + + // Make sure that BannedAddresses() returns the correct addresses + bannedAddresses := addressManager.BannedAddresses() + if len(bannedAddresses) != 1 { + t.Fatalf("Unexpected amount of addresses returned from BannedAddresses(). "+ + "Want: %d, got: %d", 1, len(bannedAddresses)) + } + if !reflect.DeepEqual(addressToBan, bannedAddresses[0]) { + t.Fatalf("Banned address %s not returned from BannedAddresses()", addressToBan.IP) + } +} diff --git a/infrastructure/network/addressmanager/network_test.go b/infrastructure/network/addressmanager/network_test.go index 6d47062ff..9904db233 100644 --- a/infrastructure/network/addressmanager/network_test.go +++ b/infrastructure/network/addressmanager/network_test.go @@ -14,7 +14,7 @@ import ( // TestIPTypes ensures the various functions which determine the type of an IP // address based on RFCs work as intended. func TestIPTypes(t *testing.T) { - amgr, teardown := newAddrManagerForTest(t, "TestAddAddressByIP") + amgr, teardown := newAddressManagerForTest(t, "TestAddAddressByIP") defer teardown() type ipTest struct { in appmessage.NetAddress @@ -146,7 +146,7 @@ func TestIPTypes(t *testing.T) { // TestGroupKey tests the GroupKey function to ensure it properly groups various // IP addresses. func TestGroupKey(t *testing.T) { - amgr, teardown := newAddrManagerForTest(t, "TestAddAddressByIP") + amgr, teardown := newAddressManagerForTest(t, "TestAddAddressByIP") defer teardown() tests := []struct { diff --git a/infrastructure/network/addressmanager/store.go b/infrastructure/network/addressmanager/store.go new file mode 100644 index 000000000..49c3cd606 --- /dev/null +++ b/infrastructure/network/addressmanager/store.go @@ -0,0 +1,225 @@ +package addressmanager + +import ( + "encoding/binary" + "github.com/kaspanet/kaspad/app/appmessage" + "github.com/kaspanet/kaspad/infrastructure/db/database" + "github.com/kaspanet/kaspad/util/mstime" + "net" +) + +var notBannedAddressBucket = database.MakeBucket([]byte("not-banned-addresses")) +var bannedAddressBucket = database.MakeBucket([]byte("banned-addresses")) + +type addressStore struct { + database database.Database + notBannedAddresses map[addressKey]*appmessage.NetAddress + bannedAddresses map[ipv6]*appmessage.NetAddress +} + +func newAddressStore(database database.Database) (*addressStore, error) { + addressStore := &addressStore{ + database: database, + notBannedAddresses: map[addressKey]*appmessage.NetAddress{}, + bannedAddresses: map[ipv6]*appmessage.NetAddress{}, + } + err := addressStore.restoreNotBannedAddresses() + if err != nil { + return nil, err + } + err = addressStore.restoreBannedAddresses() + if err != nil { + return nil, err + } + + log.Infof("Loaded %d addresses and %d banned addresses", + len(addressStore.notBannedAddresses), len(addressStore.bannedAddresses)) + + return addressStore, nil +} + +func (as *addressStore) restoreNotBannedAddresses() error { + cursor, err := as.database.Cursor(notBannedAddressBucket) + if err != nil { + return err + } + for ok := cursor.First(); ok; ok = cursor.Next() { + databaseKey, err := cursor.Key() + if err != nil { + return err + } + serializedKey := databaseKey.Suffix() + key := as.deserializeAddressKey(serializedKey) + + serializedNetAddress, err := cursor.Value() + if err != nil { + return err + } + netAddress := as.deserializeNetAddress(serializedNetAddress) + as.notBannedAddresses[key] = netAddress + } + return nil +} + +func (as *addressStore) restoreBannedAddresses() error { + cursor, err := as.database.Cursor(bannedAddressBucket) + if err != nil { + return err + } + for ok := cursor.First(); ok; ok = cursor.Next() { + databaseKey, err := cursor.Key() + if err != nil { + return err + } + var ipv6 ipv6 + copy(ipv6[:], databaseKey.Suffix()) + + serializedNetAddress, err := cursor.Value() + if err != nil { + return err + } + netAddress := as.deserializeNetAddress(serializedNetAddress) + as.bannedAddresses[ipv6] = netAddress + } + return nil +} + +func (as *addressStore) add(key addressKey, address *appmessage.NetAddress) error { + if _, ok := as.notBannedAddresses[key]; ok { + return nil + } + + as.notBannedAddresses[key] = address + + databaseKey := as.notBannedDatabaseKey(key) + serializedAddress := as.serializeNetAddress(address) + return as.database.Put(databaseKey, serializedAddress) +} + +func (as *addressStore) remove(key addressKey) error { + delete(as.notBannedAddresses, key) + + databaseKey := as.notBannedDatabaseKey(key) + return as.database.Delete(databaseKey) +} + +func (as *addressStore) getAllNotBanned() []*appmessage.NetAddress { + addresses := make([]*appmessage.NetAddress, 0, len(as.notBannedAddresses)) + for _, address := range as.notBannedAddresses { + addresses = append(addresses, address) + } + return addresses +} + +func (as *addressStore) getAllNotBannedWithout(ignoredAddresses []*appmessage.NetAddress) []*appmessage.NetAddress { + ignoredKeys := netAddressesKeys(ignoredAddresses) + + addresses := make([]*appmessage.NetAddress, 0, len(as.notBannedAddresses)) + for key, address := range as.notBannedAddresses { + if !ignoredKeys[key] { + addresses = append(addresses, address) + } + } + return addresses +} + +func (as *addressStore) isNotBanned(key addressKey) bool { + _, ok := as.notBannedAddresses[key] + return ok +} + +func (as *addressStore) addBanned(key addressKey, address *appmessage.NetAddress) error { + if _, ok := as.bannedAddresses[key.address]; ok { + return nil + } + + as.bannedAddresses[key.address] = address + + databaseKey := as.bannedDatabaseKey(key) + serializedAddress := as.serializeNetAddress(address) + return as.database.Put(databaseKey, serializedAddress) +} + +func (as *addressStore) removeBanned(key addressKey) error { + delete(as.bannedAddresses, key.address) + + databaseKey := as.bannedDatabaseKey(key) + return as.database.Delete(databaseKey) +} + +func (as *addressStore) getAllBanned() []*appmessage.NetAddress { + bannedAddresses := make([]*appmessage.NetAddress, 0, len(as.bannedAddresses)) + for _, bannedAddress := range as.bannedAddresses { + bannedAddresses = append(bannedAddresses, bannedAddress) + } + return bannedAddresses +} + +func (as *addressStore) isBanned(key addressKey) bool { + _, ok := as.bannedAddresses[key.address] + return ok +} + +func (as *addressStore) getBanned(key addressKey) (*appmessage.NetAddress, bool) { + bannedAddress, ok := as.bannedAddresses[key.address] + return bannedAddress, ok +} + +func (as *addressStore) notBannedDatabaseKey(key addressKey) *database.Key { + serializedKey := as.serializeAddressKey(key) + return notBannedAddressBucket.Key(serializedKey) +} + +func (as *addressStore) bannedDatabaseKey(key addressKey) *database.Key { + return bannedAddressBucket.Key(key.address[:]) +} + +func (as *addressStore) serializeAddressKey(key addressKey) []byte { + serializedSize := 16 + 2 // ipv6 + port + serializedKey := make([]byte, serializedSize) + + copy(serializedKey[:], key.address[:]) + binary.LittleEndian.PutUint16(serializedKey[16:], key.port) + + return serializedKey +} + +func (as *addressStore) deserializeAddressKey(serializedKey []byte) addressKey { + var ip ipv6 + copy(ip[:], serializedKey[:]) + + port := binary.LittleEndian.Uint16(serializedKey[16:]) + + return addressKey{ + port: port, + address: ip, + } +} + +func (as *addressStore) serializeNetAddress(netAddress *appmessage.NetAddress) []byte { + serializedSize := 16 + 2 + 8 + 8 // ipv6 + port + timestamp + services + serializedNetAddress := make([]byte, serializedSize) + + copy(serializedNetAddress[:], netAddress.IP[:]) + binary.LittleEndian.PutUint16(serializedNetAddress[16:], netAddress.Port) + binary.LittleEndian.PutUint64(serializedNetAddress[18:], uint64(netAddress.Timestamp.UnixMilliseconds())) + binary.LittleEndian.PutUint64(serializedNetAddress[26:], uint64(netAddress.Services)) + + return serializedNetAddress +} + +func (as *addressStore) deserializeNetAddress(serializedNetAddress []byte) *appmessage.NetAddress { + ip := make(net.IP, 16) + copy(ip[:], serializedNetAddress[:]) + + port := binary.LittleEndian.Uint16(serializedNetAddress[16:]) + timestamp := mstime.UnixMilliseconds(int64(binary.LittleEndian.Uint64(serializedNetAddress[18:]))) + services := appmessage.ServiceFlag(binary.LittleEndian.Uint64(serializedNetAddress[26:])) + + return &appmessage.NetAddress{ + IP: ip, + Port: port, + Timestamp: timestamp, + Services: services, + } +} diff --git a/infrastructure/network/addressmanager/store_test.go b/infrastructure/network/addressmanager/store_test.go new file mode 100644 index 000000000..5ea35ebc8 --- /dev/null +++ b/infrastructure/network/addressmanager/store_test.go @@ -0,0 +1,45 @@ +package addressmanager + +import ( + "github.com/kaspanet/kaspad/app/appmessage" + "github.com/kaspanet/kaspad/util/mstime" + "net" + "reflect" + "testing" +) + +func TestAddressKeySerialization(t *testing.T) { + addressManager, teardown := newAddressManagerForTest(t, "TestAddressKeySerialization") + defer teardown() + addressStore := addressManager.store + + testAddress := &appmessage.NetAddress{IP: net.ParseIP("2602:100:abcd::102"), Port: 12345} + testAddressKey := netAddressKey(testAddress) + + serializedTestAddressKey := addressStore.serializeAddressKey(testAddressKey) + deserializedTestAddressKey := addressStore.deserializeAddressKey(serializedTestAddressKey) + if !reflect.DeepEqual(testAddressKey, deserializedTestAddressKey) { + t.Fatalf("testAddressKey and deserializedTestAddressKey are not equal\n"+ + "testAddressKey:%+v\ndeserializedTestAddressKey:%+v", testAddressKey, deserializedTestAddressKey) + } +} + +func TestNetAddressSerialization(t *testing.T) { + addressManager, teardown := newAddressManagerForTest(t, "TestNetAddressSerialization") + defer teardown() + addressStore := addressManager.store + + testAddress := &appmessage.NetAddress{ + IP: net.ParseIP("2602:100:abcd::102"), + Port: 12345, + Timestamp: mstime.Now(), + Services: appmessage.ServiceFlag(6789), + } + + serializedTestNetAddress := addressStore.serializeNetAddress(testAddress) + deserializedTestNetAddress := addressStore.deserializeNetAddress(serializedTestNetAddress) + if !reflect.DeepEqual(testAddress, deserializedTestNetAddress) { + t.Fatalf("testAddress and deserializedTestNetAddress are not equal\n"+ + "testAddress:%+v\ndeserializedTestNetAddress:%+v", testAddress, deserializedTestNetAddress) + } +} diff --git a/infrastructure/network/addressmanager/test_utils.go b/infrastructure/network/addressmanager/test_utils.go index e86713997..e7301bde1 100644 --- a/infrastructure/network/addressmanager/test_utils.go +++ b/infrastructure/network/addressmanager/test_utils.go @@ -27,6 +27,5 @@ func AddAddressByIP(am *AddressManager, addressIP string, subnetworkID *external return errors.Errorf("invalid port %s: %s", portString, err) } netAddress := appmessage.NewNetAddressIPPort(ip, uint16(port), 0) - am.AddAddresses(netAddress) - return nil + return am.AddAddresses(netAddress) } diff --git a/infrastructure/network/connmanager/connmanager.go b/infrastructure/network/connmanager/connmanager.go index 7b3a91e81..4fa458754 100644 --- a/infrastructure/network/connmanager/connmanager.go +++ b/infrastructure/network/connmanager/connmanager.go @@ -137,8 +137,7 @@ func (c *ConnectionManager) Ban(netConnection *netadapter.NetConnection) error { return errors.Wrapf(ErrCannotBanPermanent, "Cannot ban %s because it's a permanent connection", netConnection.Address()) } - c.addressManager.Ban(netConnection.NetAddress()) - return nil + return c.addressManager.Ban(netConnection.NetAddress()) } // BanByIP bans the given IP and disconnects from all the connection with that IP. @@ -159,8 +158,7 @@ func (c *ConnectionManager) BanByIP(ip net.IP) error { } } - c.addressManager.Ban(appmessage.NewNetAddressIPPort(ip, 0, 0)) - return nil + return c.addressManager.Ban(appmessage.NewNetAddressIPPort(ip, 0, 0)) } // IsBanned returns whether the given netConnection is banned