Compare commits

...

16 Commits

Author SHA1 Message Date
oudeis
eaf9117225 Update to version 0.7.3 2020-10-13 06:40:22 +00:00
Kirill
b88e34fd84 [NOD-1313] Refactor AddressManager (#918)
* [NOD-1313] Refactor AddressManager.

* [NOD-1313] Remove old tests. Minor improvements, fixes.

* [NOD-1313] After merge fixes. Fix import cycle.

* [NOD-1313] Integration tests fixes.

* [NOD-1313] Allocate new slice for the returned key.

* [NOD-1313] AddressManager improvements and fixes.

* Move local and banned addresses to separate lists.
* Move AddressManager config to the separate file.
* Add LocalAddressManager.
* Remove redundant KnownAddress structure.
* Restore local addresses functionality.
* Call initListeners from the LocalAddressManager.
* AddressManager minor improvements and fixes.

* [NOD-1313] Minor fixes.

* [NOD-1313] Implement HandleGetPeerAddresses. Refactoring.

* [NOD-1313] After-merge fixes.

* [NOD-1313] Minor improvements.

* AddressManager: added BannedAddresses() method.
* AddressManager: HandleGetPeerAddresses() add banned addresses
  separately.
* AddressManager: remove addressEntry redundant struct.
* ConnectionManager: checkOutgoingConnections() minor improvements and
  fixes.
* Minor refactoring.
* Minor fixes.

* [NOD-1313] GetPeerAddresses RPC message update

* GetPeerAddresses RPC: add BannedAddresses in the separate field.
* Update protobuf.
2020-10-08 17:05:47 +03:00
Ori Newman
689098082f [NOD-1444] Implement getHeaders RPC command (#944)
* [NOD-1444] Implement getHeaders RPC command

* [NOD-1444] Fix tests and comments

* [NOD-1444] Fix error message

* [NOD-1444] Make GetHeaders propagate header serialization errors

* [NOD-1444] RLock the dag on GetHeaders

* [NOD-1444] Change the error field number to 1000
2020-10-06 12:40:32 +03:00
Svarog
a359e2248b [NOD-1447] checkEntryAmounts should check against totalSompiInAfter, not totalSompiInBefore (#945)
* [NOD-1447] checkEntryAmounts should check against totalSompiInAfter, not totalSompiInBefore

* [NOD-1447] Remove lastSompiIn, and use totalSompiInBefore instead
2020-10-05 13:10:09 +03:00
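
The one-line fix above is about validation ordering: each entry must be checked against the running total after its amount is added (totalSompiInAfter), otherwise the last entry can push the sum past the cap, or wrap around, unnoticed. A minimal standalone sketch of that pattern, with illustrative names and an illustrative cap rather than kaspad's actual consensus code:

package main

import (
	"errors"
	"fmt"
	"math"
)

// maxSompi stands in for the maximum allowed total amount; the real
// limit in kaspad is a consensus constant, this value is illustrative.
const maxSompi = uint64(21_000_000_000 * 100_000_000)

// checkEntryAmounts mimics the pattern fixed in #945: validate the
// running total after each addition, not the total before it.
func checkEntryAmounts(amounts []uint64) (uint64, error) {
	totalSompiInBefore := uint64(0)
	for _, amount := range amounts {
		totalSompiInAfter := totalSompiInBefore + amount
		// Checking totalSompiInBefore here would miss an overflow
		// caused by the entry we just added.
		if totalSompiInAfter < totalSompiInBefore || totalSompiInAfter > maxSompi {
			return 0, errors.New("total input amount is out of range")
		}
		totalSompiInBefore = totalSompiInAfter
	}
	return totalSompiInBefore, nil
}

func main() {
	_, err := checkEntryAmounts([]uint64{1000, 2000, math.MaxUint64 - 500})
	fmt.Println(err) // the wraparound is caught on the last entry
}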
Svarog
513ffa7e0c [NOD-1420] Restructure main (#942)
* [NOD-1420] Moved setting limits to executor

* [NOD-1420] Moved all code dealing with windows service to separate package

* [NOD-1420] Move practically all main to restructured app package

* [NOD-1420] Check for running as interactive only after checking if we are doing any service operation

* [NOD-1420] Add comments

* [NOD-1420] Add a comment
2020-10-01 08:28:16 +03:00
oudeis
ef6c46a231 Update to version 0.7.2 2020-09-30 14:07:18 +00:00
Svarog
22237a4a8d [NOD-1439] Added Stop command (#940)
* [NOD-1439] Added Stop command

* [NOD-1439] Added comment explaining why we wait before closing the StopChan

* [NOD-1439] Warnf -> Warn

* [NOD-1439] Rename Stop command to Shut Down

* [NOD-1439] Clean up pauseBeforeShutDown

* [NOD-1439] Add ShutDownRequestMessage case for toRPCPayload

* [NOD-1439] Minor stylistic changes
2020-09-29 15:59:47 +03:00
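
As the diffs below show, the command is wired by handing the process's interrupt channel into the RPC context as ShutDownChan, so a ShutDown request can trigger the same teardown path as an OS signal. The handler body itself is not part of this excerpt; the following is a hypothetical reconstruction based on the commit notes about waiting before closing the channel (pauseBeforeShutDown and the handler's exact logic are assumptions):

package rpchandlers

import (
	"time"

	"github.com/kaspanet/kaspad/app/appmessage"
	"github.com/kaspanet/kaspad/app/rpc/rpccontext"
	"github.com/kaspanet/kaspad/infrastructure/network/netadapter/router"
)

// pauseBeforeShutDown gives the response below time to reach the
// client before teardown begins. The duration is illustrative.
const pauseBeforeShutDown = time.Second

// HandleShutDown is a hypothetical sketch of the ShutDown handler;
// the real implementation is not included in this diff excerpt.
func HandleShutDown(context *rpccontext.Context, _ *router.Router, _ appmessage.Message) (appmessage.Message, error) {
	log.Warn("ShutDown RPC called - shutting down")
	// Close ShutDownChan asynchronously, after a short pause, so the
	// response can be delivered before the connection is torn down.
	spawn("HandleShutDown-pauseAndShutDown", func() {
		time.Sleep(pauseBeforeShutDown)
		close(context.ShutDownChan)
	})
	return appmessage.NewShutDownResponseMessage(), nil
}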
Ori Newman
6ab8ada9ff [NOD-1406] remove mempool utxo diff (#938)
* [NOD-1406] Remove mempool UTXO diff

* [NOD-1406] Fix mempool tests

* [NOD-1406] Fetch mempool transactions before locking the dag in NewBlockTemplate

* [NOD-1406] Remove redundant comment

* [NOD-1406] Move mempool UTXO set to a different file

* [NOD-1406] Fix transactionRelatedUTXOEntries receiver's name

* [NOD-1406] Fix variable names and fix comments

* [NOD-1406] Rename inputsWithUTXOEntries->referencedUTXOEntries

* [NOD-1406] Remove debug logs
2020-09-27 16:40:07 +03:00
Svarog
9a756939d8 [NOD-1412] Remove ffldb, remove dataStore from database, store blocks directly in levelDB (#939)
* [NOD-1412] Remove ffldb, and make ldb implement all the database
interfaces

* [NOD-1412] Removed any reference to dataStore and updated block dbaccess to work directly with key/values
2020-09-27 15:40:15 +03:00
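
With ffldb and the dataStore layer gone, a block is just a value in LevelDB keyed by its hash under some prefix. A standalone sketch of that storage model using goleveldb, the library kaspad's ldb wrapper is built on (the key prefix and path here are illustrative, not kaspad's actual layout):

package main

import (
	"fmt"

	"github.com/syndtr/goleveldb/leveldb"
)

// blockKey namespaces block entries so they can share the database
// with other stores. The "blocks/" prefix is illustrative.
func blockKey(hash []byte) []byte {
	return append([]byte("blocks/"), hash...)
}

func main() {
	db, err := leveldb.OpenFile("/tmp/blockstore-demo", nil)
	if err != nil {
		panic(err)
	}
	defer db.Close()

	hash := []byte("example-block-hash")
	serializedBlock := []byte("serialized block bytes")

	// Storing a block is a plain Put of its serialized bytes...
	if err := db.Put(blockKey(hash), serializedBlock, nil); err != nil {
		panic(err)
	}
	// ...and fetching it is a Get by the same key - no flat-file
	// locations and no separate dataStore layer in between.
	gotBlock, err := db.Get(blockKey(hash), nil)
	if err != nil {
		panic(err)
	}
	fmt.Printf("%s\n", gotBlock)
}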
Kirill
aea3baf897 [NOD-1320] Flush UTXOs to disk (#902)
* [NOD-1320] Flush UTXOs to disk.

* [NOD-1320] Minor improvements and fixes.

* FullUTXOSet: change size type from int64 to uint64.
* Rename FullUTXOSet.size to FullUTXOSet.estimatedSize
* Fill NewFullUTXOSetFromContext with db context on virtual block
  creation.
* Typo fixes.

* [NOD-1320] Stylistic fixes.

* [NOD-1320] Update tests. Improvements and fixes.

* Update blockdag/dag tests: prepare DB for tests.
* Update blockdag/utxoset tests: prepare DB for tests.
* Update blockdag/test_utils utils.
* Update blockdag/common tests.
* FullUTXOSet: remove embedded utxoCollection type. Rename
  utxoCollection to utxoCache.
* Fix blockdag/block_utxo genesisPastUTXO func.
* Minor fixes and improvements.
2020-09-27 13:16:11 +03:00
Ori Newman
f8d0f7f67a [NOD-1405] Add getMempoolEntries RPC command (#937)
* [NOD-1405] Add getMempoolEntries RPC command

* [NOD-1405] Remove redundant fields from GetMempoolEntryResponseMessage
2020-09-23 15:51:02 +03:00
stasatdaglabs
fed34273a1 [NOD-1404] Remove most of the notification manager to fix a deadlock (#936)
* [NOD-1404] Remove most of the notification manager to fix a deadlock.

* [NOD-1404] Rename a couple of fields.

* [NOD-1404] Fix merge errors.

* [NOD-1404] Remove most of the notification manager to fix a deadlock (#935)

* [NOD-1404] Remove most of the notification manager to fix a deadlock.

* [NOD-1404] Rename a couple of fields.
2020-09-23 14:00:05 +03:00
Ori Newman
34a1b30006 [NOD-1397] Add to rpc client ErrRPC (#934) 2020-09-17 19:26:03 +03:00
stasatdaglabs
798abf2103 [NOD-1395] Increase the grpcclient's max message size to match the grpcserver's. (#932) 2020-09-17 16:53:33 +03:00
stasatdaglabs
75e539f4d2 [NOD-1357] Implement getMempoolEntry (#931)
* [NOD-1357] Implement getMempoolEntry.

* [NOD-1357] Fix a nil pointer reference.

* [NOD-1357] Add a comment above BuildTransactionVerboseData.
2020-09-16 16:53:36 +03:00
oudeis
946e65d1c6 Update to version 0.7.1 2020-09-16 11:59:51 +00:00
132 changed files with 4638 additions and 7253 deletions

View File

@@ -2,259 +2,194 @@ package app
import (
"fmt"
"sync/atomic"
"os"
"path/filepath"
"runtime"
"time"
"github.com/kaspanet/kaspad/infrastructure/network/addressmanager"
"github.com/kaspanet/kaspad/infrastructure/network/netadapter/id"
"github.com/kaspanet/kaspad/app/appmessage"
"github.com/kaspanet/kaspad/app/protocol"
"github.com/kaspanet/kaspad/app/rpc"
"github.com/kaspanet/kaspad/domain/blockdag"
"github.com/kaspanet/kaspad/domain/blockdag/indexers"
"github.com/kaspanet/kaspad/domain/mempool"
"github.com/kaspanet/kaspad/domain/mining"
"github.com/kaspanet/kaspad/domain/txscript"
"github.com/kaspanet/kaspad/infrastructure/config"
"github.com/kaspanet/kaspad/infrastructure/db/dbaccess"
"github.com/kaspanet/kaspad/infrastructure/network/connmanager"
"github.com/kaspanet/kaspad/infrastructure/network/dnsseed"
"github.com/kaspanet/kaspad/infrastructure/network/netadapter"
"github.com/kaspanet/kaspad/util"
"github.com/kaspanet/kaspad/domain/blockdag/indexers"
"github.com/kaspanet/kaspad/infrastructure/os/signal"
"github.com/kaspanet/kaspad/util/profiling"
"github.com/kaspanet/kaspad/version"
"github.com/kaspanet/kaspad/util/panics"
"github.com/kaspanet/kaspad/infrastructure/config"
"github.com/kaspanet/kaspad/infrastructure/os/execenv"
"github.com/kaspanet/kaspad/infrastructure/os/limits"
"github.com/kaspanet/kaspad/infrastructure/os/winservice"
)
// App is a wrapper for all the kaspad services
type App struct {
cfg *config.Config
addressManager *addressmanager.AddressManager
protocolManager *protocol.Manager
rpcManager *rpc.Manager
connectionManager *connmanager.ConnectionManager
netAdapter *netadapter.NetAdapter
started, shutdown int32
var desiredLimits = &limits.DesiredLimits{
FileLimitWant: 2048,
FileLimitMin: 1024,
}
// Start launches all the kaspad services.
func (a *App) Start() {
// Already started?
if atomic.AddInt32(&a.started, 1) != 1 {
return
}
log.Trace("Starting kaspad")
err := a.netAdapter.Start()
if err != nil {
panics.Exit(log, fmt.Sprintf("Error starting the net adapter: %+v", err))
}
a.maybeSeedFromDNS()
a.connectionManager.Start()
var serviceDescription = &winservice.ServiceDescription{
Name: "kaspadsvc",
DisplayName: "Kaspad Service",
Description: "Downloads and stays synchronized with the Kaspa blockDAG and " +
"provides DAG services to applications.",
}
// Stop gracefully shuts down all the kaspad services.
func (a *App) Stop() {
// Make sure this only happens once.
if atomic.AddInt32(&a.shutdown, 1) != 1 {
log.Infof("Kaspad is already in the process of shutting down")
return
}
log.Warnf("Kaspad shutting down")
a.connectionManager.Stop()
err := a.netAdapter.Stop()
if err != nil {
log.Errorf("Error stopping the net adapter: %+v", err)
}
err = a.addressManager.Stop()
if err != nil {
log.Errorf("Error stopping address manager: %s", err)
}
return
type kaspadApp struct {
cfg *config.Config
}
// New returns a new App instance configured to listen on addr for the
// kaspa network type specified by dagParams. Use Start to begin accepting
// connections from peers.
func New(cfg *config.Config, databaseContext *dbaccess.DatabaseContext, interrupt <-chan struct{}) (*App, error) {
indexManager, acceptanceIndex := setupIndexes(cfg)
// StartApp starts the kaspad app, and blocks until it finishes running
func StartApp() error {
execenv.Initialize(desiredLimits)
sigCache := txscript.NewSigCache(cfg.SigCacheMaxSize)
// Create a new block DAG instance with the appropriate configuration.
dag, err := setupDAG(cfg, databaseContext, interrupt, sigCache, indexManager)
// Load configuration and parse command line. This function also
// initializes logging and configures it accordingly.
cfg, err := config.LoadConfig()
if err != nil {
return nil, err
fmt.Fprint(os.Stderr, err)
return err
}
defer panics.HandlePanic(log, "MAIN", nil)
txMempool := setupMempool(cfg, dag, sigCache)
app := &kaspadApp{cfg: cfg}
netAdapter, err := netadapter.NewNetAdapter(cfg)
if err != nil {
return nil, err
}
addressManager, err := addressmanager.New(cfg, databaseContext)
if err != nil {
return nil, err
}
connectionManager, err := connmanager.New(cfg, netAdapter, addressManager)
if err != nil {
return nil, err
}
protocolManager, err := protocol.NewManager(cfg, dag, netAdapter, addressManager, txMempool, connectionManager)
if err != nil {
return nil, err
}
rpcManager := setupRPC(cfg, txMempool, dag, sigCache, netAdapter, protocolManager, connectionManager, addressManager, acceptanceIndex)
return &App{
cfg: cfg,
protocolManager: protocolManager,
rpcManager: rpcManager,
connectionManager: connectionManager,
netAdapter: netAdapter,
addressManager: addressManager,
}, nil
}
func setupRPC(
cfg *config.Config,
txMempool *mempool.TxPool,
dag *blockdag.BlockDAG,
sigCache *txscript.SigCache,
netAdapter *netadapter.NetAdapter,
protocolManager *protocol.Manager,
connectionManager *connmanager.ConnectionManager,
addressManager *addressmanager.AddressManager,
acceptanceIndex *indexers.AcceptanceIndex) *rpc.Manager {
blockTemplateGenerator := mining.NewBlkTmplGenerator(&mining.Policy{BlockMaxMass: cfg.BlockMaxMass}, txMempool, dag, sigCache)
rpcManager := rpc.NewManager(cfg, netAdapter, dag, protocolManager, connectionManager, blockTemplateGenerator, txMempool, addressManager, acceptanceIndex)
protocolManager.SetOnBlockAddedToDAGHandler(rpcManager.NotifyBlockAddedToDAG)
protocolManager.SetOnTransactionAddedToMempoolHandler(rpcManager.NotifyTransactionAddedToMempool)
dag.Subscribe(func(notification *blockdag.Notification) {
err := handleBlockDAGNotifications(notification, acceptanceIndex, rpcManager)
// Call serviceMain on Windows to handle running as a service. When
// the returned isService flag is true, exit now, since we ran as a
// service. Otherwise, just fall through to normal operation.
if runtime.GOOS == "windows" {
isService, err := winservice.WinServiceMain(app.main, serviceDescription, cfg)
if err != nil {
panic(err)
return err
}
})
return rpcManager
}
func handleBlockDAGNotifications(notification *blockdag.Notification,
acceptanceIndex *indexers.AcceptanceIndex, rpcManager *rpc.Manager) error {
switch notification.Type {
case blockdag.NTChainChanged:
if acceptanceIndex == nil {
if isService {
return nil
}
chainChangedNotificationData := notification.Data.(*blockdag.ChainChangedNotificationData)
err := rpcManager.NotifyChainChanged(chainChangedNotificationData.RemovedChainBlockHashes,
chainChangedNotificationData.AddedChainBlockHashes)
if err != nil {
return err
}
case blockdag.NTFinalityConflict:
finalityConflictNotificationData := notification.Data.(*blockdag.FinalityConflictNotificationData)
err := rpcManager.NotifyFinalityConflict(finalityConflictNotificationData.ViolatingBlockHash.String())
if err != nil {
return err
}
case blockdag.NTFinalityConflictResolved:
finalityConflictResolvedNotificationData := notification.Data.(*blockdag.FinalityConflictResolvedNotificationData)
err := rpcManager.NotifyFinalityConflictResolved(finalityConflictResolvedNotificationData.FinalityBlockHash.String())
}
return app.main(nil)
}
func (app *kaspadApp) main(startedChan chan<- struct{}) error {
// Get a channel that will be closed when a shutdown signal has been
// triggered either from an OS signal such as SIGINT (Ctrl+C) or from
// another subsystem such as the RPC server.
interrupt := signal.InterruptListener()
defer log.Info("Shutdown complete")
// Show version at startup.
log.Infof("Version %s", version.Version())
// Enable http profiling server if requested.
if app.cfg.Profile != "" {
profiling.Start(app.cfg.Profile, log)
}
// Perform upgrades to kaspad as new versions require it.
if err := doUpgrades(); err != nil {
log.Error(err)
return err
}
// Return now if an interrupt signal was triggered.
if signal.InterruptRequested(interrupt) {
return nil
}
if app.cfg.ResetDatabase {
err := removeDatabase(app.cfg)
if err != nil {
log.Error(err)
return err
}
}
// Open the database
databaseContext, err := openDB(app.cfg)
if err != nil {
log.Error(err)
return err
}
defer func() {
log.Infof("Gracefully shutting down the database...")
err := databaseContext.Close()
if err != nil {
log.Errorf("Failed to close the database: %s", err)
}
}()
// Return now if an interrupt signal was triggered.
if signal.InterruptRequested(interrupt) {
return nil
}
// Drop indexes and exit if requested.
if app.cfg.DropAcceptanceIndex {
if err := indexers.DropAcceptanceIndex(databaseContext); err != nil {
log.Errorf("%s", err)
return err
}
return nil
}
// Create componentManager and start it.
componentManager, err := NewComponentManager(app.cfg, databaseContext, interrupt)
if err != nil {
log.Errorf("Unable to start kaspad: %+v", err)
return err
}
defer func() {
log.Infof("Gracefully shutting down kaspad...")
shutdownDone := make(chan struct{})
go func() {
componentManager.Stop()
shutdownDone <- struct{}{}
}()
const shutdownTimeout = 2 * time.Minute
select {
case <-shutdownDone:
case <-time.After(shutdownTimeout):
log.Criticalf("Graceful shutdown timed out %s. Terminating...", shutdownTimeout)
}
log.Infof("Kaspad shutdown complete")
}()
componentManager.Start()
if startedChan != nil {
startedChan <- struct{}{}
}
// Wait until the interrupt signal is received from an OS signal or
// shutdown is requested through one of the subsystems such as the RPC
// server.
<-interrupt
return nil
}
func (a *App) maybeSeedFromDNS() {
if !a.cfg.DisableDNSSeed {
dnsseed.SeedFromDNS(a.cfg.NetParams(), a.cfg.DNSSeed, appmessage.SFNodeNetwork, false, nil,
a.cfg.Lookup, func(addresses []*appmessage.NetAddress) {
// Kaspad uses a lookup of the DNS seeder here. Since the seeder returns
// IPs of nodes and not its own IP, we cannot know the real IP of the
// source, so we take the first returned address as the source.
a.addressManager.AddAddresses(addresses, addresses[0], nil)
})
}
if a.cfg.GRPCSeed != "" {
dnsseed.SeedFromGRPC(a.cfg.NetParams(), a.cfg.GRPCSeed, appmessage.SFNodeNetwork, false, nil,
func(addresses []*appmessage.NetAddress) {
a.addressManager.AddAddresses(addresses, addresses[0], nil)
})
}
}
func setupDAG(cfg *config.Config, databaseContext *dbaccess.DatabaseContext, interrupt <-chan struct{},
sigCache *txscript.SigCache, indexManager blockdag.IndexManager) (*blockdag.BlockDAG, error) {
dag, err := blockdag.New(&blockdag.Config{
Interrupt: interrupt,
DatabaseContext: databaseContext,
DAGParams: cfg.NetParams(),
TimeSource: blockdag.NewTimeSource(),
SigCache: sigCache,
IndexManager: indexManager,
SubnetworkID: cfg.SubnetworkID,
})
return dag, err
// doUpgrades performs upgrades to kaspad as new versions require it.
// Currently it's a placeholder inherited from kaspad upstream that does nothing.
func doUpgrades() error {
return nil
}
func setupIndexes(cfg *config.Config) (blockdag.IndexManager, *indexers.AcceptanceIndex) {
// Create indexes if needed.
var indexes []indexers.Indexer
var acceptanceIndex *indexers.AcceptanceIndex
if cfg.AcceptanceIndex {
log.Info("acceptance index is enabled")
acceptanceIndex = indexers.NewAcceptanceIndex()
indexes = append(indexes, acceptanceIndex)
}
// Create an index manager if any of the optional indexes are enabled.
if len(indexes) == 0 {
return nil, nil
}
indexManager := indexers.NewManager(indexes)
return indexManager, acceptanceIndex
// dbPath returns the path to the block database given a database type.
func databasePath(cfg *config.Config) string {
return filepath.Join(cfg.DataDir, "db")
}
func setupMempool(cfg *config.Config, dag *blockdag.BlockDAG, sigCache *txscript.SigCache) *mempool.TxPool {
mempoolConfig := mempool.Config{
Policy: mempool.Policy{
AcceptNonStd: cfg.RelayNonStd,
MaxOrphanTxs: cfg.MaxOrphanTxs,
MaxOrphanTxSize: config.DefaultMaxOrphanTxSize,
MinRelayTxFee: cfg.MinRelayTxFee,
MaxTxVersion: 1,
},
CalcSequenceLockNoLock: func(tx *util.Tx, utxoSet blockdag.UTXOSet) (*blockdag.SequenceLock, error) {
return dag.CalcSequenceLockNoLock(tx, utxoSet)
},
SigCache: sigCache,
DAG: dag,
}
return mempool.New(&mempoolConfig)
func removeDatabase(cfg *config.Config) error {
dbPath := databasePath(cfg)
return os.RemoveAll(dbPath)
}
// P2PNodeID returns the network ID associated with this App
func (a *App) P2PNodeID() *id.ID {
return a.netAdapter.ID()
}
// AddressManager returns the AddressManager associated with this App
func (a *App) AddressManager() *addressmanager.AddressManager {
return a.addressManager
func openDB(cfg *config.Config) (*dbaccess.DatabaseContext, error) {
dbPath := databasePath(cfg)
log.Infof("Loading database from '%s'", dbPath)
return dbaccess.New(dbPath)
}
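
Note how the restructured StartApp raises OS resource limits first, via execenv.Initialize(desiredLimits) with FileLimitWant 2048 and FileLimitMin 1024. The execenv/limits implementation is not part of this diff; as a rough Unix-only sketch, raising the open-file soft limit toward a wanted value typically looks like this (an assumption about what such an initializer does, not kaspad's actual code):

package main

import (
	"fmt"
	"syscall"
)

// These mirror the desiredLimits values declared in the diff above.
const (
	fileLimitWant = 2048
	fileLimitMin  = 1024
)

func main() {
	var rLimit syscall.Rlimit
	if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit); err != nil {
		panic(err)
	}
	if rLimit.Cur < fileLimitWant {
		// Raise the soft limit toward the wanted value, capped by
		// the hard limit the OS grants this process.
		want := uint64(fileLimitWant)
		if want > rLimit.Max {
			want = rLimit.Max
		}
		if want < fileLimitMin {
			panic(fmt.Sprintf("cannot raise the open-file limit to the minimum of %d", fileLimitMin))
		}
		rLimit.Cur = want
		if err := syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rLimit); err != nil {
			panic(err)
		}
	}
	fmt.Printf("open-file soft limit: %d (hard: %d)\n", rLimit.Cur, rLimit.Max)
}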

View File

@@ -97,6 +97,12 @@ const (
CmdNotifyFinalityConflictsResponseMessage
CmdFinalityConflictNotificationMessage
CmdFinalityConflictResolvedNotificationMessage
CmdGetMempoolEntriesRequestMessage
CmdGetMempoolEntriesResponseMessage
CmdShutDownRequestMessage
CmdShutDownResponseMessage
CmdGetHeadersRequestMessage
CmdGetHeadersResponseMessage
)
// ProtocolMessageCommandToString maps all MessageCommands to their string representation
@@ -170,6 +176,10 @@ var RPCMessageCommandToString = map[MessageCommand]string{
CmdNotifyFinalityConflictsResponseMessage: "NotifyFinalityConflictsResponse",
CmdFinalityConflictNotificationMessage: "FinalityConflictNotification",
CmdFinalityConflictResolvedNotificationMessage: "FinalityConflictResolvedNotification",
CmdGetMempoolEntriesRequestMessage: "GetMempoolEntriesRequestMessage",
CmdGetMempoolEntriesResponseMessage: "GetMempoolEntriesResponseMessage",
CmdGetHeadersRequestMessage: "GetHeadersRequest",
CmdGetHeadersResponseMessage: "GetHeadersResponse",
}
// Message is an interface that describes a kaspa message. A type that
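
Both hunks above follow the same registration convention: every new RPC message contributes a value to the iota-based MessageCommand block and an entry in the command-to-string map, and forgetting the second half means logs print a bare number instead of a name. A minimal standalone illustration of the pattern (types reduced to their bare shape, not kaspad's full definitions):

package main

import "fmt"

// MessageCommand identifies a message type; new commands are appended
// to an iota block, as in the first hunk above.
type MessageCommand uint32

const (
	CmdGetMempoolEntriesRequestMessage MessageCommand = iota
	CmdGetMempoolEntriesResponseMessage
	CmdGetHeadersRequestMessage
	CmdGetHeadersResponseMessage
)

// Each command must also be registered for printing, as in the
// second hunk above.
var RPCMessageCommandToString = map[MessageCommand]string{
	CmdGetMempoolEntriesRequestMessage:  "GetMempoolEntriesRequestMessage",
	CmdGetMempoolEntriesResponseMessage: "GetMempoolEntriesResponseMessage",
	CmdGetHeadersRequestMessage:         "GetHeadersRequest",
	CmdGetHeadersResponseMessage:        "GetHeadersResponse",
}

func main() {
	fmt.Println(RPCMessageCommandToString[CmdGetHeadersRequestMessage])
}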

View File

@@ -18,7 +18,7 @@ func TestRequstIBDBlocks(t *testing.T) {
t.Errorf("NewHashFromStr: %v", err)
}
hashStr = "3ba27aa200b1cecaad478d2b00432346c3f1f3986da1afd33e506"
hashStr = "000000000003ba27aa200b1cecaad478d2b00432346c3f1f3986da1afd33e506"
highHash, err := daghash.NewHashFromStr(hashStr)
if err != nil {
t.Errorf("NewHashFromStr: %v", err)

View File

@@ -23,7 +23,7 @@ import (
func TestTx(t *testing.T) {
pver := ProtocolVersion
txIDStr := "3ba27aa200b1cecaad478d2b00432346c3f1f3986da1afd33e506"
txIDStr := "000000000003ba27aa200b1cecaad478d2b00432346c3f1f3986da1afd33e506"
txID, err := daghash.NewTxIDFromStr(txIDStr)
if err != nil {
t.Errorf("NewTxIDFromStr: %v", err)

View File

@@ -0,0 +1,45 @@
package appmessage
// GetHeadersRequestMessage is an appmessage corresponding to
// its respective RPC message
type GetHeadersRequestMessage struct {
baseMessage
StartHash string
Limit uint64
IsAscending bool
}
// Command returns the protocol command string for the message
func (msg *GetHeadersRequestMessage) Command() MessageCommand {
return CmdGetHeadersRequestMessage
}
// NewGetHeadersRequestMessage returns an instance of the message
func NewGetHeadersRequestMessage(startHash string, limit uint64, isAscending bool) *GetHeadersRequestMessage {
return &GetHeadersRequestMessage{
StartHash: startHash,
Limit: limit,
IsAscending: isAscending,
}
}
// GetHeadersResponseMessage is an appmessage corresponding to
// its respective RPC message
type GetHeadersResponseMessage struct {
baseMessage
Headers []string
Error *RPCError
}
// Command returns the protocol command string for the message
func (msg *GetHeadersResponseMessage) Command() MessageCommand {
return CmdGetHeadersResponseMessage
}
// NewGetHeadersResponseMessage returns an instance of the message
func NewGetHeadersResponseMessage(headers []string) *GetHeadersResponseMessage {
return &GetHeadersResponseMessage{
Headers: headers,
}
}

View File

@@ -0,0 +1,38 @@
package appmessage
// GetMempoolEntriesRequestMessage is an appmessage corresponding to
// its respective RPC message
type GetMempoolEntriesRequestMessage struct {
baseMessage
}
// Command returns the protocol command string for the message
func (msg *GetMempoolEntriesRequestMessage) Command() MessageCommand {
return CmdGetMempoolEntriesRequestMessage
}
// NewGetMempoolEntriesRequestMessage returns an instance of the message
func NewGetMempoolEntriesRequestMessage() *GetMempoolEntriesRequestMessage {
return &GetMempoolEntriesRequestMessage{}
}
// GetMempoolEntriesResponseMessage is an appmessage corresponding to
// its respective RPC message
type GetMempoolEntriesResponseMessage struct {
baseMessage
Entries []*MempoolEntry
Error *RPCError
}
// Command returns the protocol command string for the message
func (msg *GetMempoolEntriesResponseMessage) Command() MessageCommand {
return CmdGetMempoolEntriesResponseMessage
}
// NewGetMempoolEntriesResponseMessage returns an instance of the message
func NewGetMempoolEntriesResponseMessage(entries []*MempoolEntry) *GetMempoolEntriesResponseMessage {
return &GetMempoolEntriesResponseMessage{
Entries: entries,
}
}

View File

@@ -21,15 +21,28 @@ func NewGetMempoolEntryRequestMessage(txID string) *GetMempoolEntryRequestMessag
// its respective RPC message
type GetMempoolEntryResponseMessage struct {
baseMessage
Entry *MempoolEntry
Error *RPCError
}
// MempoolEntry represents a transaction in the mempool.
type MempoolEntry struct {
Fee uint64
TransactionVerboseData *TransactionVerboseData
}
// Command returns the protocol command string for the message
func (msg *GetMempoolEntryResponseMessage) Command() MessageCommand {
return CmdGetMempoolEntryResponseMessage
}
// NewGetMempoolEntryResponseMessage returns an instance of the message
func NewGetMempoolEntryResponseMessage() *GetMempoolEntryResponseMessage {
return &GetMempoolEntryResponseMessage{}
func NewGetMempoolEntryResponseMessage(fee uint64, transactionVerboseData *TransactionVerboseData) *GetMempoolEntryResponseMessage {
return &GetMempoolEntryResponseMessage{
Entry: &MempoolEntry{
Fee: fee,
TransactionVerboseData: transactionVerboseData,
},
}
}

View File

@@ -20,7 +20,8 @@ func NewGetPeerAddressesRequestMessage() *GetPeerAddressesRequestMessage {
// its respective RPC message
type GetPeerAddressesResponseMessage struct {
baseMessage
Addresses []*GetPeerAddressesKnownAddressMessage
Addresses []*GetPeerAddressesKnownAddressMessage
BannedAddresses []*GetPeerAddressesKnownAddressMessage
Error *RPCError
}
@@ -31,9 +32,10 @@ func (msg *GetPeerAddressesResponseMessage) Command() MessageCommand {
}
// NewGetPeerAddressesResponseMessage returns an instance of the message
func NewGetPeerAddressesResponseMessage(addresses []*GetPeerAddressesKnownAddressMessage) *GetPeerAddressesResponseMessage {
func NewGetPeerAddressesResponseMessage(addresses []*GetPeerAddressesKnownAddressMessage, bannedAddresses []*GetPeerAddressesKnownAddressMessage) *GetPeerAddressesResponseMessage {
return &GetPeerAddressesResponseMessage{
Addresses: addresses,
Addresses: addresses,
BannedAddresses: bannedAddresses,
}
}

View File

@@ -0,0 +1,34 @@
package appmessage
// ShutDownRequestMessage is an appmessage corresponding to
// its respective RPC message
type ShutDownRequestMessage struct {
baseMessage
}
// Command returns the protocol command string for the message
func (msg *ShutDownRequestMessage) Command() MessageCommand {
return CmdShutDownRequestMessage
}
// NewShutDownRequestMessage returns an instance of the message
func NewShutDownRequestMessage() *ShutDownRequestMessage {
return &ShutDownRequestMessage{}
}
// ShutDownResponseMessage is an appmessage corresponding to
// its respective RPC message
type ShutDownResponseMessage struct {
baseMessage
Error *RPCError
}
// Command returns the protocol command string for the message
func (msg *ShutDownResponseMessage) Command() MessageCommand {
return CmdShutDownResponseMessage
}
// NewShutDownResponseMessage returns an instance of the message
func NewShutDownResponseMessage() *ShutDownResponseMessage {
return &ShutDownResponseMessage{}
}

app/component_manager.go (new file, 259 lines)
View File

@@ -0,0 +1,259 @@
package app
import (
"fmt"
"sync/atomic"
"github.com/kaspanet/kaspad/infrastructure/network/addressmanager"
"github.com/kaspanet/kaspad/infrastructure/network/netadapter/id"
"github.com/kaspanet/kaspad/app/appmessage"
"github.com/kaspanet/kaspad/app/protocol"
"github.com/kaspanet/kaspad/app/rpc"
"github.com/kaspanet/kaspad/domain/blockdag"
"github.com/kaspanet/kaspad/domain/blockdag/indexers"
"github.com/kaspanet/kaspad/domain/mempool"
"github.com/kaspanet/kaspad/domain/mining"
"github.com/kaspanet/kaspad/domain/txscript"
"github.com/kaspanet/kaspad/infrastructure/config"
"github.com/kaspanet/kaspad/infrastructure/db/dbaccess"
"github.com/kaspanet/kaspad/infrastructure/network/connmanager"
"github.com/kaspanet/kaspad/infrastructure/network/dnsseed"
"github.com/kaspanet/kaspad/infrastructure/network/netadapter"
"github.com/kaspanet/kaspad/util/panics"
)
// ComponentManager is a wrapper for all the kaspad services
type ComponentManager struct {
cfg *config.Config
addressManager *addressmanager.AddressManager
protocolManager *protocol.Manager
rpcManager *rpc.Manager
connectionManager *connmanager.ConnectionManager
netAdapter *netadapter.NetAdapter
started, shutdown int32
}
// Start launches all the kaspad services.
func (a *ComponentManager) Start() {
// Already started?
if atomic.AddInt32(&a.started, 1) != 1 {
return
}
log.Trace("Starting kaspad")
err := a.netAdapter.Start()
if err != nil {
panics.Exit(log, fmt.Sprintf("Error starting the net adapter: %+v", err))
}
a.maybeSeedFromDNS()
a.connectionManager.Start()
}
// Stop gracefully shuts down all the kaspad services.
func (a *ComponentManager) Stop() {
// Make sure this only happens once.
if atomic.AddInt32(&a.shutdown, 1) != 1 {
log.Infof("Kaspad is already in the process of shutting down")
return
}
log.Warnf("Kaspad shutting down")
a.connectionManager.Stop()
err := a.netAdapter.Stop()
if err != nil {
log.Errorf("Error stopping the net adapter: %+v", err)
}
return
}
// NewComponentManager returns a new ComponentManager instance.
// Use Start() to begin all services within this ComponentManager
func NewComponentManager(cfg *config.Config, databaseContext *dbaccess.DatabaseContext, interrupt chan<- struct{}) (*ComponentManager, error) {
indexManager, acceptanceIndex := setupIndexes(cfg)
sigCache := txscript.NewSigCache(cfg.SigCacheMaxSize)
// Create a new block DAG instance with the appropriate configuration.
dag, err := setupDAG(cfg, databaseContext, sigCache, indexManager)
if err != nil {
return nil, err
}
txMempool := setupMempool(cfg, dag, sigCache)
netAdapter, err := netadapter.NewNetAdapter(cfg)
if err != nil {
return nil, err
}
addressManager, err := addressmanager.New(addressmanager.NewConfig(cfg))
if err != nil {
return nil, err
}
connectionManager, err := connmanager.New(cfg, netAdapter, addressManager)
if err != nil {
return nil, err
}
protocolManager, err := protocol.NewManager(cfg, dag, netAdapter, addressManager, txMempool, connectionManager)
if err != nil {
return nil, err
}
rpcManager := setupRPC(cfg, txMempool, dag, sigCache, netAdapter, protocolManager, connectionManager, addressManager, acceptanceIndex, interrupt)
return &ComponentManager{
cfg: cfg,
protocolManager: protocolManager,
rpcManager: rpcManager,
connectionManager: connectionManager,
netAdapter: netAdapter,
addressManager: addressManager,
}, nil
}
func setupRPC(
cfg *config.Config,
txMempool *mempool.TxPool,
dag *blockdag.BlockDAG,
sigCache *txscript.SigCache,
netAdapter *netadapter.NetAdapter,
protocolManager *protocol.Manager,
connectionManager *connmanager.ConnectionManager,
addressManager *addressmanager.AddressManager,
acceptanceIndex *indexers.AcceptanceIndex,
shutDownChan chan<- struct{},
) *rpc.Manager {
blockTemplateGenerator := mining.NewBlkTmplGenerator(&mining.Policy{BlockMaxMass: cfg.BlockMaxMass}, txMempool, dag, sigCache)
rpcManager := rpc.NewManager(cfg, netAdapter, dag, protocolManager, connectionManager, blockTemplateGenerator, txMempool, addressManager, acceptanceIndex, shutDownChan)
protocolManager.SetOnBlockAddedToDAGHandler(rpcManager.NotifyBlockAddedToDAG)
protocolManager.SetOnTransactionAddedToMempoolHandler(rpcManager.NotifyTransactionAddedToMempool)
dag.Subscribe(func(notification *blockdag.Notification) {
err := handleBlockDAGNotifications(notification, acceptanceIndex, rpcManager)
if err != nil {
panic(err)
}
})
return rpcManager
}
func handleBlockDAGNotifications(notification *blockdag.Notification,
acceptanceIndex *indexers.AcceptanceIndex, rpcManager *rpc.Manager) error {
switch notification.Type {
case blockdag.NTChainChanged:
if acceptanceIndex == nil {
return nil
}
chainChangedNotificationData := notification.Data.(*blockdag.ChainChangedNotificationData)
err := rpcManager.NotifyChainChanged(chainChangedNotificationData.RemovedChainBlockHashes,
chainChangedNotificationData.AddedChainBlockHashes)
if err != nil {
return err
}
case blockdag.NTFinalityConflict:
finalityConflictNotificationData := notification.Data.(*blockdag.FinalityConflictNotificationData)
err := rpcManager.NotifyFinalityConflict(finalityConflictNotificationData.ViolatingBlockHash.String())
if err != nil {
return err
}
case blockdag.NTFinalityConflictResolved:
finalityConflictResolvedNotificationData := notification.Data.(*blockdag.FinalityConflictResolvedNotificationData)
err := rpcManager.NotifyFinalityConflictResolved(finalityConflictResolvedNotificationData.FinalityBlockHash.String())
if err != nil {
return err
}
}
return nil
}
func (a *ComponentManager) maybeSeedFromDNS() {
if !a.cfg.DisableDNSSeed {
dnsseed.SeedFromDNS(a.cfg.NetParams(), a.cfg.DNSSeed, appmessage.SFNodeNetwork, false, nil,
a.cfg.Lookup, func(addresses []*appmessage.NetAddress) {
// Kaspad uses a lookup of the DNS seeder here. Since the seeder returns
// IPs of nodes and not its own IP, we cannot know the real IP of the
// source, so we take the first returned address as the source.
a.addressManager.AddAddresses(addresses...)
})
}
if a.cfg.GRPCSeed != "" {
dnsseed.SeedFromGRPC(a.cfg.NetParams(), a.cfg.GRPCSeed, appmessage.SFNodeNetwork, false, nil,
func(addresses []*appmessage.NetAddress) {
a.addressManager.AddAddresses(addresses...)
})
}
}
func setupDAG(cfg *config.Config, databaseContext *dbaccess.DatabaseContext,
sigCache *txscript.SigCache, indexManager blockdag.IndexManager) (*blockdag.BlockDAG, error) {
dag, err := blockdag.New(&blockdag.Config{
DatabaseContext: databaseContext,
DAGParams: cfg.NetParams(),
TimeSource: blockdag.NewTimeSource(),
SigCache: sigCache,
IndexManager: indexManager,
SubnetworkID: cfg.SubnetworkID,
MaxUTXOCacheSize: cfg.MaxUTXOCacheSize,
})
return dag, err
}
func setupIndexes(cfg *config.Config) (blockdag.IndexManager, *indexers.AcceptanceIndex) {
// Create indexes if needed.
var indexes []indexers.Indexer
var acceptanceIndex *indexers.AcceptanceIndex
if cfg.AcceptanceIndex {
log.Info("acceptance index is enabled")
acceptanceIndex = indexers.NewAcceptanceIndex()
indexes = append(indexes, acceptanceIndex)
}
// Create an index manager if any of the optional indexes are enabled.
if len(indexes) == 0 {
return nil, nil
}
indexManager := indexers.NewManager(indexes)
return indexManager, acceptanceIndex
}
func setupMempool(cfg *config.Config, dag *blockdag.BlockDAG, sigCache *txscript.SigCache) *mempool.TxPool {
mempoolConfig := mempool.Config{
Policy: mempool.Policy{
AcceptNonStd: cfg.RelayNonStd,
MaxOrphanTxs: cfg.MaxOrphanTxs,
MaxOrphanTxSize: config.DefaultMaxOrphanTxSize,
MinRelayTxFee: cfg.MinRelayTxFee,
MaxTxVersion: 1,
},
CalcTxSequenceLockFromReferencedUTXOEntries: dag.CalcTxSequenceLockFromReferencedUTXOEntries,
SigCache: sigCache,
DAG: dag,
}
return mempool.New(&mempoolConfig)
}
// P2PNodeID returns the network ID associated with this ComponentManager
func (a *ComponentManager) P2PNodeID() *id.ID {
return a.netAdapter.ID()
}
// AddressManager returns the AddressManager associated with this ComponentManager
func (a *ComponentManager) AddressManager() *addressmanager.AddressManager {
return a.addressManager
}
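
Start and Stop above guard against repeated calls with an atomic counter rather than a mutex and a bool: the first caller gets 1 back from AddInt32 and proceeds, every later or concurrent caller gets a larger value and returns early. A standalone demonstration of the idiom:

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

type service struct {
	started int32
}

// Start runs its body exactly once no matter how many goroutines race
// into it - the same idiom ComponentManager.Start uses above.
func (s *service) Start() {
	if atomic.AddInt32(&s.started, 1) != 1 {
		return // someone already started (or is starting) the service
	}
	fmt.Println("started once")
}

func main() {
	s := &service{}
	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			s.Start()
		}()
	}
	wg.Wait() // "started once" is printed a single time
}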

View File

@@ -19,7 +19,10 @@ func (f *FlowContext) OnNewBlock(block *util.Block) error {
return err
}
if f.onBlockAddedToDAGHandler != nil {
f.onBlockAddedToDAGHandler(block)
err := f.onBlockAddedToDAGHandler(block)
if err != nil {
return err
}
}
return f.broadcastTransactionsAfterBlockAdded(block, transactionsAcceptedToMempool)

View File

@@ -20,7 +20,7 @@ import (
// OnBlockAddedToDAGHandler is a handler function that's triggered
// when a block is added to the DAG
type OnBlockAddedToDAGHandler func(block *util.Block)
type OnBlockAddedToDAGHandler func(block *util.Block) error
// OnTransactionAddedToMempoolHandler is a handler function that's triggered
// when a transaction is added to the mempool

View File

@@ -20,10 +20,6 @@ type ReceiveAddressesContext interface {
func ReceiveAddresses(context ReceiveAddressesContext, incomingRoute *router.Route, outgoingRoute *router.Route,
peer *peerpkg.Peer) error {
if !context.AddressManager().NeedMoreAddresses() {
return nil
}
subnetworkID := peer.SubnetworkID()
msgGetAddresses := appmessage.NewMsgRequestAddresses(false, subnetworkID)
err := outgoingRoute.Enqueue(msgGetAddresses)
@@ -51,7 +47,6 @@ func ReceiveAddresses(context ReceiveAddressesContext, incomingRoute *router.Rou
context.Config().SubnetworkID, msgAddresses.Command(), msgAddresses.SubnetworkID)
}
sourceAddress := peer.Connection().NetAddress()
context.AddressManager().AddAddresses(msgAddresses.AddrList, sourceAddress, msgAddresses.SubnetworkID)
context.AddressManager().AddAddresses(msgAddresses.AddrList...)
return nil
}

View File

@@ -1,10 +1,11 @@
package addressexchange
import (
"math/rand"
"github.com/kaspanet/kaspad/app/appmessage"
"github.com/kaspanet/kaspad/infrastructure/network/addressmanager"
"github.com/kaspanet/kaspad/infrastructure/network/netadapter/router"
"math/rand"
)
// SendAddressesContext is the interface for the context needed for the SendAddresses flow.
@@ -20,8 +21,7 @@ func SendAddresses(context SendAddressesContext, incomingRoute *router.Route, ou
}
msgGetAddresses := message.(*appmessage.MsgRequestAddresses)
addresses := context.AddressManager().AddressCache(msgGetAddresses.IncludeAllSubnetworks,
msgGetAddresses.SubnetworkID)
addresses := context.AddressManager().Addresses()
msgAddresses := appmessage.NewMsgAddresses(msgGetAddresses.IncludeAllSubnetworks, msgGetAddresses.SubnetworkID)
err = msgAddresses.AddAddresses(shuffleAddresses(addresses)...)
if err != nil {

View File

@@ -85,9 +85,7 @@ func HandleHandshake(context HandleHandshakeContext, netConnection *netadapter.N
}
if peerAddress != nil {
subnetworkID := peer.SubnetworkID()
context.AddressManager().AddAddress(peerAddress, peerAddress, subnetworkID)
context.AddressManager().Good(peerAddress, subnetworkID)
context.AddressManager().AddAddresses(peerAddress)
}
context.StartIBDIfRequired()

View File

@@ -50,7 +50,7 @@ func (flow *sendVersionFlow) start() error {
subnetworkID := flow.Config().SubnetworkID
// Version message.
localAddress := flow.AddressManager().GetBestLocalAddress(flow.peer.Connection().NetAddress())
localAddress := flow.AddressManager().BestLocalAddress(flow.peer.Connection().NetAddress())
msg := appmessage.NewMsgVersion(localAddress, flow.NetAdapter().ID(),
flow.Config().ActiveNetParams.Name, selectedTipHash, subnetworkID)
msg.AddUserAgent(userAgentName, userAgentVersion, flow.Config().UserAgentComments...)

View File

@@ -31,7 +31,8 @@ func NewManager(
blockTemplateGenerator *mining.BlkTmplGenerator,
mempool *mempool.TxPool,
addressManager *addressmanager.AddressManager,
acceptanceIndex *indexers.AcceptanceIndex) *Manager {
acceptanceIndex *indexers.AcceptanceIndex,
shutDownChan chan<- struct{}) *Manager {
manager := Manager{
context: rpccontext.NewContext(
@@ -44,6 +45,7 @@ func NewManager(
mempool,
addressManager,
acceptanceIndex,
shutDownChan,
),
}
netAdapter.SetRPCRouterInitializer(manager.routerInitializer)
@@ -52,11 +54,11 @@ func NewManager(
}
// NotifyBlockAddedToDAG notifies the manager that a block has been added to the DAG
func (m *Manager) NotifyBlockAddedToDAG(block *util.Block) {
func (m *Manager) NotifyBlockAddedToDAG(block *util.Block) error {
m.context.BlockTemplateState.NotifyBlockAdded(block)
notification := appmessage.NewBlockAddedNotificationMessage(block.MsgBlock())
m.context.NotificationManager.NotifyBlockAdded(notification)
return m.context.NotificationManager.NotifyBlockAdded(notification)
}
// NotifyChainChanged notifies the manager that the DAG's selected parent chain has changed
@@ -70,22 +72,19 @@ func (m *Manager) NotifyChainChanged(removedChainBlockHashes []*daghash.Hash, ad
removedChainBlockHashStrings[i] = removedChainBlockHash.String()
}
notification := appmessage.NewChainChangedNotificationMessage(removedChainBlockHashStrings, addedChainBlocks)
m.context.NotificationManager.NotifyChainChanged(notification)
return nil
return m.context.NotificationManager.NotifyChainChanged(notification)
}
// NotifyFinalityConflict notifies the manager that there's a finality conflict in the DAG
func (m *Manager) NotifyFinalityConflict(violatingBlockHash string) error {
notification := appmessage.NewFinalityConflictNotificationMessage(violatingBlockHash)
m.context.NotificationManager.NotifyFinalityConflict(notification)
return nil
return m.context.NotificationManager.NotifyFinalityConflict(notification)
}
// NotifyFinalityConflictResolved notifies the manager that a finality conflict in the DAG has been resolved
func (m *Manager) NotifyFinalityConflictResolved(finalityBlockHash string) error {
notification := appmessage.NewFinalityConflictResolvedNotificationMessage(finalityBlockHash)
m.context.NotificationManager.NotifyFinalityConflictResolved(notification)
return nil
return m.context.NotificationManager.NotifyFinalityConflictResolved(notification)
}
// NotifyTransactionAddedToMempool notifies the manager that a transaction has been added to the mempool

View File

@@ -31,6 +31,9 @@ var handlers = map[appmessage.MessageCommand]handler{
appmessage.CmdGetBlockDAGInfoRequestMessage: rpchandlers.HandleGetBlockDAGInfo,
appmessage.CmdResolveFinalityConflictRequestMessage: rpchandlers.HandleResolveFinalityConflict,
appmessage.CmdNotifyFinalityConflictsRequestMessage: rpchandlers.HandleNotifyFinalityConflicts,
appmessage.CmdGetMempoolEntriesRequestMessage: rpchandlers.HandleGetMempoolEntries,
appmessage.CmdShutDownRequestMessage: rpchandlers.HandleShutDown,
appmessage.CmdGetHeadersRequestMessage: rpchandlers.HandleGetHeaders,
}
func (m *Manager) routerInitializer(router *router.Router, netConnection *netadapter.NetConnection) {
@@ -42,16 +45,12 @@ func (m *Manager) routerInitializer(router *router.Router, netConnection *netada
if err != nil {
panic(err)
}
spawn("routerInitializer-handleIncomingMessages", func() {
err := m.handleIncomingMessages(router, incomingRoute)
m.handleError(err, netConnection)
})
m.context.NotificationManager.AddListener(router)
notificationListener := m.context.NotificationManager.AddListener(router)
spawn("routerInitializer-handleOutgoingNotifications", func() {
spawn("routerInitializer-handleIncomingMessages", func() {
defer m.context.NotificationManager.RemoveListener(router)
err := m.handleOutgoingNotifications(notificationListener)
err := m.handleIncomingMessages(router, incomingRoute)
m.handleError(err, netConnection)
})
}
@@ -78,15 +77,6 @@ func (m *Manager) handleIncomingMessages(router *router.Router, incomingRoute *r
}
}
func (m *Manager) handleOutgoingNotifications(notificationListener *rpccontext.NotificationListener) error {
for {
err := notificationListener.ProcessNextNotification()
if err != nil {
return err
}
}
}
func (m *Manager) handleError(err error, netConnection *netadapter.NetConnection) {
if errors.Is(err, router.ErrTimeout) {
log.Warnf("Got timeout from %s. Disconnecting...", netConnection)

View File

@@ -23,14 +23,14 @@ type Context struct {
Mempool *mempool.TxPool
AddressManager *addressmanager.AddressManager
AcceptanceIndex *indexers.AcceptanceIndex
ShutDownChan chan<- struct{}
BlockTemplateState *BlockTemplateState
NotificationManager *NotificationManager
}
// NewContext creates a new RPC context
func NewContext(
cfg *config.Config,
func NewContext(cfg *config.Config,
netAdapter *netadapter.NetAdapter,
dag *blockdag.BlockDAG,
protocolManager *protocol.Manager,
@@ -38,7 +38,9 @@ func NewContext(
blockTemplateGenerator *mining.BlkTmplGenerator,
mempool *mempool.TxPool,
addressManager *addressmanager.AddressManager,
acceptanceIndex *indexers.AcceptanceIndex) *Context {
acceptanceIndex *indexers.AcceptanceIndex,
shutDownChan chan<- struct{}) *Context {
context := &Context{
Config: cfg,
NetAdapter: netAdapter,
@@ -49,6 +51,7 @@ func NewContext(
Mempool: mempool,
AddressManager: addressManager,
AcceptanceIndex: acceptanceIndex,
ShutDownChan: shutDownChan,
}
context.BlockTemplateState = NewBlockTemplateState(context)
context.NotificationManager = NewNotificationManager()

View File

@@ -2,7 +2,7 @@ package rpccontext
import (
"github.com/kaspanet/kaspad/app/appmessage"
"github.com/kaspanet/kaspad/infrastructure/network/netadapter/router"
routerpkg "github.com/kaspanet/kaspad/infrastructure/network/netadapter/router"
"github.com/pkg/errors"
"sync"
)
@@ -10,68 +10,43 @@ import (
// NotificationManager manages notifications for the RPC
type NotificationManager struct {
sync.RWMutex
listeners map[*router.Router]*NotificationListener
listeners map[*routerpkg.Router]*NotificationListener
}
// OnBlockAddedListener is a listener function for when a block is added to the DAG
type OnBlockAddedListener func(notification *appmessage.BlockAddedNotificationMessage) error
// OnChainChangedListener is a listener function for when the DAG's selected parent chain changes
type OnChainChangedListener func(notification *appmessage.ChainChangedNotificationMessage) error
// OnFinalityConflictListener is a listener function for when there's a finality conflict in the DAG
type OnFinalityConflictListener func(notification *appmessage.FinalityConflictNotificationMessage) error
// OnFinalityConflictResolvedListener is a listener function for when a finality conflict in the DAG has been resolved
type OnFinalityConflictResolvedListener func(notification *appmessage.FinalityConflictResolvedNotificationMessage) error
// NotificationListener represents a registered RPC notification listener
type NotificationListener struct {
onBlockAddedListener OnBlockAddedListener
onBlockAddedNotificationChan chan *appmessage.BlockAddedNotificationMessage
onChainChangedListener OnChainChangedListener
onChainChangedNotificationChan chan *appmessage.ChainChangedNotificationMessage
onFinalityConflictListener OnFinalityConflictListener
onFinalityConflictNotificationChan chan *appmessage.FinalityConflictNotificationMessage
onFinalityConflictResolvedListener OnFinalityConflictResolvedListener
onFinalityConflictResolvedNotificationChan chan *appmessage.FinalityConflictResolvedNotificationMessage
closeChan chan struct{}
propagateBlockAddedNotifications bool
propagateChainChangedNotifications bool
propagateFinalityConflictNotifications bool
propagateFinalityConflictResolvedNotifications bool
}
// NewNotificationManager creates a new NotificationManager
func NewNotificationManager() *NotificationManager {
return &NotificationManager{
listeners: make(map[*router.Router]*NotificationListener),
listeners: make(map[*routerpkg.Router]*NotificationListener),
}
}
// AddListener registers a listener with the given router
func (nm *NotificationManager) AddListener(router *router.Router) *NotificationListener {
func (nm *NotificationManager) AddListener(router *routerpkg.Router) {
nm.Lock()
defer nm.Unlock()
listener := newNotificationListener()
nm.listeners[router] = listener
return listener
}
// RemoveListener unregisters the given router
func (nm *NotificationManager) RemoveListener(router *router.Router) {
func (nm *NotificationManager) RemoveListener(router *routerpkg.Router) {
nm.Lock()
defer nm.Unlock()
listener, ok := nm.listeners[router]
if !ok {
return
}
listener.close()
delete(nm.listeners, router)
}
// Listener retrieves the listener registered with the given router
func (nm *NotificationManager) Listener(router *router.Router) (*NotificationListener, error) {
func (nm *NotificationManager) Listener(router *routerpkg.Router) (*NotificationListener, error) {
nm.RLock()
defer nm.RUnlock()
@@ -83,115 +58,98 @@ func (nm *NotificationManager) Listener(router *router.Router) (*NotificationLis
}
// NotifyBlockAdded notifies the notification manager that a block has been added to the DAG
func (nm *NotificationManager) NotifyBlockAdded(notification *appmessage.BlockAddedNotificationMessage) {
func (nm *NotificationManager) NotifyBlockAdded(notification *appmessage.BlockAddedNotificationMessage) error {
nm.RLock()
defer nm.RUnlock()
for _, listener := range nm.listeners {
if listener.onBlockAddedListener != nil {
select {
case listener.onBlockAddedNotificationChan <- notification:
case <-listener.closeChan:
continue
for router, listener := range nm.listeners {
if listener.propagateBlockAddedNotifications {
err := router.OutgoingRoute().Enqueue(notification)
if err != nil {
return err
}
}
}
return nil
}
// NotifyChainChanged notifies the notification manager that the DAG's selected parent chain has changed
func (nm *NotificationManager) NotifyChainChanged(message *appmessage.ChainChangedNotificationMessage) {
func (nm *NotificationManager) NotifyChainChanged(notification *appmessage.ChainChangedNotificationMessage) error {
nm.RLock()
defer nm.RUnlock()
for _, listener := range nm.listeners {
if listener.onChainChangedListener != nil {
select {
case listener.onChainChangedNotificationChan <- message:
case <-listener.closeChan:
continue
for router, listener := range nm.listeners {
if listener.propagateChainChangedNotifications {
err := router.OutgoingRoute().Enqueue(notification)
if err != nil {
return err
}
}
}
return nil
}
// NotifyFinalityConflict notifies the notification manager that there's a finality conflict in the DAG
func (nm *NotificationManager) NotifyFinalityConflict(message *appmessage.FinalityConflictNotificationMessage) {
func (nm *NotificationManager) NotifyFinalityConflict(notification *appmessage.FinalityConflictNotificationMessage) error {
nm.RLock()
defer nm.RUnlock()
for _, listener := range nm.listeners {
if listener.onFinalityConflictListener != nil {
select {
case listener.onFinalityConflictNotificationChan <- message:
case <-listener.closeChan:
continue
for router, listener := range nm.listeners {
if listener.propagateFinalityConflictNotifications {
err := router.OutgoingRoute().Enqueue(notification)
if err != nil {
return err
}
}
}
return nil
}
// NotifyFinalityConflictResolved notifies the notification manager that a finality conflict in the DAG has been resolved
func (nm *NotificationManager) NotifyFinalityConflictResolved(message *appmessage.FinalityConflictResolvedNotificationMessage) {
func (nm *NotificationManager) NotifyFinalityConflictResolved(notification *appmessage.FinalityConflictResolvedNotificationMessage) error {
nm.RLock()
defer nm.RUnlock()
for _, listener := range nm.listeners {
if listener.onFinalityConflictResolvedListener != nil {
select {
case listener.onFinalityConflictResolvedNotificationChan <- message:
case <-listener.closeChan:
continue
for router, listener := range nm.listeners {
if listener.propagateFinalityConflictResolvedNotifications {
err := router.OutgoingRoute().Enqueue(notification)
if err != nil {
return err
}
}
}
return nil
}
func newNotificationListener() *NotificationListener {
return &NotificationListener{
onBlockAddedNotificationChan: make(chan *appmessage.BlockAddedNotificationMessage),
onChainChangedNotificationChan: make(chan *appmessage.ChainChangedNotificationMessage),
onFinalityConflictNotificationChan: make(chan *appmessage.FinalityConflictNotificationMessage),
onFinalityConflictResolvedNotificationChan: make(chan *appmessage.FinalityConflictResolvedNotificationMessage),
closeChan: make(chan struct{}, 1),
propagateBlockAddedNotifications: false,
propagateChainChangedNotifications: false,
propagateFinalityConflictNotifications: false,
propagateFinalityConflictResolvedNotifications: false,
}
}
// SetOnBlockAddedListener sets the onBlockAddedListener handler for this listener
func (nl *NotificationListener) SetOnBlockAddedListener(onBlockAddedListener OnBlockAddedListener) {
nl.onBlockAddedListener = onBlockAddedListener
// PropagateBlockAddedNotifications instructs the listener to send block added notifications
// to the remote listener
func (nl *NotificationListener) PropagateBlockAddedNotifications() {
nl.propagateBlockAddedNotifications = true
}
// SetOnChainChangedListener sets the onChainChangedListener handler for this listener
func (nl *NotificationListener) SetOnChainChangedListener(onChainChangedListener OnChainChangedListener) {
nl.onChainChangedListener = onChainChangedListener
// PropagateChainChangedNotifications instructs the listener to send chain changed notifications
// to the remote listener
func (nl *NotificationListener) PropagateChainChangedNotifications() {
nl.propagateChainChangedNotifications = true
}
// SetOnFinalityConflictListener sets the onFinalityConflictListener handler for this listener
func (nl *NotificationListener) SetOnFinalityConflictListener(onFinalityConflictListener OnFinalityConflictListener) {
nl.onFinalityConflictListener = onFinalityConflictListener
// PropagateFinalityConflictNotifications instructs the listener to send finality conflict notifications
// to the remote listener
func (nl *NotificationListener) PropagateFinalityConflictNotifications() {
nl.propagateFinalityConflictNotifications = true
}
// SetOnFinalityConflictResolvedListener sets the onFinalityConflictResolvedListener handler for this listener
func (nl *NotificationListener) SetOnFinalityConflictResolvedListener(onFinalityConflictResolvedListener OnFinalityConflictResolvedListener) {
nl.onFinalityConflictResolvedListener = onFinalityConflictResolvedListener
}
// ProcessNextNotification waits until a notification arrives and processes it
func (nl *NotificationListener) ProcessNextNotification() error {
select {
case block := <-nl.onBlockAddedNotificationChan:
return nl.onBlockAddedListener(block)
case notification := <-nl.onChainChangedNotificationChan:
return nl.onChainChangedListener(notification)
case notification := <-nl.onFinalityConflictNotificationChan:
return nl.onFinalityConflictListener(notification)
case notification := <-nl.onFinalityConflictResolvedNotificationChan:
return nl.onFinalityConflictResolvedListener(notification)
case <-nl.closeChan:
return nil
}
}
func (nl *NotificationListener) close() {
nl.closeChan <- struct{}{}
// PropagateFinalityConflictResolvedNotifications instructs the listener to send finality conflict resolved notifications
// to the remote listener
func (nl *NotificationListener) PropagateFinalityConflictResolvedNotifications() {
nl.propagateFinalityConflictResolvedNotifications = true
}

View File

@@ -88,7 +88,7 @@ func (ctx *Context) BuildBlockVerboseData(block *util.Block, includeTransactionV
transactions := block.Transactions()
transactionVerboseData := make([]*appmessage.TransactionVerboseData, len(transactions))
for i, tx := range transactions {
data, err := ctx.buildTransactionVerboseData(tx.MsgTx(), tx.ID().String(),
data, err := ctx.BuildTransactionVerboseData(tx.MsgTx(), tx.ID().String(),
&blockHeader, hash.String(), nil, false)
if err != nil {
return nil, err
@@ -120,7 +120,9 @@ func (ctx *Context) GetDifficultyRatio(bits uint32, params *dagconfig.Params) fl
return diff
}
func (ctx *Context) buildTransactionVerboseData(mtx *appmessage.MsgTx,
// BuildTransactionVerboseData builds a TransactionVerboseData from
// the given parameters
func (ctx *Context) BuildTransactionVerboseData(mtx *appmessage.MsgTx,
txID string, blockHeader *appmessage.BlockHeader, blockHash string,
acceptingBlock *daghash.Hash, isInMempool bool) (*appmessage.TransactionVerboseData, error) {

View File

@@ -12,7 +12,7 @@ import (
const (
// maxBlocksInGetBlocksResponse is the maximum number of blocks
// allowed in a GetBlocksResponse.
maxBlocksInGetBlocksResponse = 1000
maxBlocksInGetBlocksResponse = 100
)
// HandleGetBlocks handles the respectively named RPC command

View File

@@ -0,0 +1,52 @@
package rpchandlers
import (
"bytes"
"encoding/hex"
"github.com/kaspanet/kaspad/app/appmessage"
"github.com/kaspanet/kaspad/app/rpc/rpccontext"
"github.com/kaspanet/kaspad/infrastructure/network/netadapter/router"
"github.com/kaspanet/kaspad/util/daghash"
)
// HandleGetHeaders handles the respectively named RPC command
func HandleGetHeaders(context *rpccontext.Context, _ *router.Router, request appmessage.Message) (appmessage.Message, error) {
getHeadersRequest := request.(*appmessage.GetHeadersRequestMessage)
dag := context.DAG
var startHash *daghash.Hash
if getHeadersRequest.StartHash != "" {
var err error
startHash, err = daghash.NewHashFromStr(getHeadersRequest.StartHash)
if err != nil {
errorMessage := &appmessage.GetHeadersResponseMessage{}
errorMessage.Error = appmessage.RPCErrorf("Start hash could not be parsed: %s", err)
return errorMessage, nil
}
}
const getHeadersDefaultLimit uint64 = 2000
limit := getHeadersDefaultLimit
if getHeadersRequest.Limit != 0 {
limit = getHeadersRequest.Limit
}
headers, err := dag.GetHeaders(startHash, limit, getHeadersRequest.IsAscending)
if err != nil {
errorMessage := &appmessage.GetHeadersResponseMessage{}
errorMessage.Error = appmessage.RPCErrorf("Error getting the headers: %s", err)
return errorMessage, nil
}
headersHex := make([]string, len(headers))
var buf bytes.Buffer
for i, header := range headers {
err := header.Serialize(&buf)
if err != nil {
return nil, err
}
headersHex[i] = hex.EncodeToString(buf.Bytes())
buf.Reset()
}
return appmessage.NewGetHeadersResponseMessage(headersHex), nil
}
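
Since the handler returns every header hex-encoded, a consumer reverses the two steps: hex-decode, then deserialize. A hedged client-side sketch; it assumes appmessage.BlockHeader has a Deserialize counterpart to the Serialize call used above, and the response variable stands in for whatever an RPC client returned:

package main

import (
	"bytes"
	"encoding/hex"
	"fmt"

	"github.com/kaspanet/kaspad/app/appmessage"
)

// decodeHeaders reverses HandleGetHeaders' encoding: each entry is the
// hex of a serialized block header. BlockHeader.Deserialize is assumed
// to be the counterpart of the handler's Serialize call.
func decodeHeaders(headersHex []string) ([]*appmessage.BlockHeader, error) {
	headers := make([]*appmessage.BlockHeader, len(headersHex))
	for i, headerHex := range headersHex {
		serialized, err := hex.DecodeString(headerHex)
		if err != nil {
			return nil, err
		}
		header := &appmessage.BlockHeader{}
		if err := header.Deserialize(bytes.NewReader(serialized)); err != nil {
			return nil, err
		}
		headers[i] = header
	}
	return headers, nil
}

func main() {
	// response would come from a getHeaders RPC call.
	var response appmessage.GetHeadersResponseMessage
	headers, err := decodeHeaders(response.Headers)
	fmt.Println(len(headers), err)
}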

View File

@@ -0,0 +1,25 @@
package rpchandlers
import (
"github.com/kaspanet/kaspad/app/appmessage"
"github.com/kaspanet/kaspad/app/rpc/rpccontext"
"github.com/kaspanet/kaspad/infrastructure/network/netadapter/router"
)
// HandleGetMempoolEntries handles the respectively named RPC command
func HandleGetMempoolEntries(context *rpccontext.Context, _ *router.Router, _ appmessage.Message) (appmessage.Message, error) {
txDescs := context.Mempool.TxDescs()
entries := make([]*appmessage.MempoolEntry, len(txDescs))
for i, txDesc := range txDescs {
transactionVerboseData, err := context.BuildTransactionVerboseData(txDesc.Tx.MsgTx(), txDesc.Tx.ID().String(),
nil, "", nil, true)
if err != nil {
return nil, err
}
entries[i] = &appmessage.MempoolEntry{
Fee: txDesc.Fee,
TransactionVerboseData: transactionVerboseData,
}
}
return appmessage.NewGetMempoolEntriesResponseMessage(entries), nil
}

View File

@@ -17,13 +17,19 @@ func HandleGetMempoolEntry(context *rpccontext.Context, _ *router.Router, reques
return errorMessage, nil
}
_, ok := context.Mempool.FetchTxDesc(txID)
txDesc, ok := context.Mempool.FetchTxDesc(txID)
if !ok {
errorMessage := &appmessage.GetMempoolEntryResponseMessage{}
errorMessage.Error = appmessage.RPCErrorf("transaction is not in the pool")
return errorMessage, nil
}
response := appmessage.NewGetMempoolEntryResponseMessage()
transactionVerboseData, err := context.BuildTransactionVerboseData(txDesc.Tx.MsgTx(), txID.String(),
nil, "", nil, true)
if err != nil {
return nil, err
}
response := appmessage.NewGetMempoolEntryResponseMessage(txDesc.Fee, transactionVerboseData)
return response, nil
}

View File

@@ -1,6 +1,9 @@
package rpchandlers
import (
"net"
"strconv"
"github.com/kaspanet/kaspad/app/appmessage"
"github.com/kaspanet/kaspad/app/rpc/rpccontext"
"github.com/kaspanet/kaspad/infrastructure/network/netadapter/router"
@@ -8,14 +11,20 @@ import (
// HandleGetPeerAddresses handles the respectively named RPC command
func HandleGetPeerAddresses(context *rpccontext.Context, _ *router.Router, _ appmessage.Message) (appmessage.Message, error) {
peersState, err := context.AddressManager.PeersStateForSerialization()
if err != nil {
return nil, err
netAddresses := context.AddressManager.Addresses()
addressMessages := make([]*appmessage.GetPeerAddressesKnownAddressMessage, len(netAddresses))
for i, netAddress := range netAddresses {
addressWithPort := net.JoinHostPort(netAddress.IP.String(), strconv.FormatUint(uint64(netAddress.Port), 10))
addressMessages[i] = &appmessage.GetPeerAddressesKnownAddressMessage{Addr: addressWithPort}
}
addresses := make([]*appmessage.GetPeerAddressesKnownAddressMessage, len(peersState.Addresses))
for i, address := range peersState.Addresses {
addresses[i] = &appmessage.GetPeerAddressesKnownAddressMessage{Addr: string(address.Address)}
bannedAddresses := context.AddressManager.BannedAddresses()
bannedAddressMessages := make([]*appmessage.GetPeerAddressesKnownAddressMessage, len(bannedAddresses))
for i, netAddress := range bannedAddresses {
addressWithPort := net.JoinHostPort(netAddress.IP.String(), strconv.FormatUint(uint64(netAddress.Port), 10))
bannedAddressMessages[i] = &appmessage.GetPeerAddressesKnownAddressMessage{Addr: addressWithPort}
}
response := appmessage.NewGetPeerAddressesResponseMessage(addresses)
response := appmessage.NewGetPeerAddressesResponseMessage(addressMessages, bannedAddressMessages)
return response, nil
}
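net.JoinHostPort is used here rather than plain string concatenation because it adds the square brackets IPv6 literals require. A quick standalone demonstration (the port number is arbitrary):

package main

import (
    "fmt"
    "net"
    "strconv"
)

func main() {
    port := strconv.FormatUint(uint64(16111), 10)
    // IPv4 addresses are joined as-is.
    fmt.Println(net.JoinHostPort("203.0.113.7", port)) // 203.0.113.7:16111
    // IPv6 literals get the required square brackets.
    fmt.Println(net.JoinHostPort("2001:db8::1", port)) // [2001:db8::1]:16111
}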

View File

@@ -2,6 +2,8 @@ package rpchandlers
import (
"github.com/kaspanet/kaspad/infrastructure/logger"
"github.com/kaspanet/kaspad/util/panics"
)
var log, _ = logger.Get(logger.SubsystemTags.RPCS)
var spawn = panics.GoroutineWrapperFunc(log)

View File

@@ -12,9 +12,7 @@ func HandleNotifyBlockAdded(context *rpccontext.Context, router *router.Router,
if err != nil {
return nil, err
}
listener.SetOnBlockAddedListener(func(notification *appmessage.BlockAddedNotificationMessage) error {
return router.OutgoingRoute().Enqueue(notification)
})
listener.PropagateBlockAddedNotifications()
response := appmessage.NewNotifyBlockAddedResponseMessage()
return response, nil

View File

@@ -18,9 +18,7 @@ func HandleNotifyChainChanged(context *rpccontext.Context, router *router.Router
if err != nil {
return nil, err
}
listener.SetOnChainChangedListener(func(message *appmessage.ChainChangedNotificationMessage) error {
return router.OutgoingRoute().Enqueue(message)
})
listener.PropagateChainChangedNotifications()
response := appmessage.NewNotifyChainChangedResponseMessage()
return response, nil

View File

@@ -12,12 +12,8 @@ func HandleNotifyFinalityConflicts(context *rpccontext.Context, router *router.R
if err != nil {
return nil, err
}
listener.SetOnFinalityConflictListener(func(notification *appmessage.FinalityConflictNotificationMessage) error {
return router.OutgoingRoute().Enqueue(notification)
})
listener.SetOnFinalityConflictResolvedListener(func(notification *appmessage.FinalityConflictResolvedNotificationMessage) error {
return router.OutgoingRoute().Enqueue(notification)
})
listener.PropagateFinalityConflictNotifications()
listener.PropagateFinalityConflictResolvedNotifications()
response := appmessage.NewNotifyFinalityConflictsResponseMessage()
return response, nil

View File

@@ -0,0 +1,25 @@
package rpchandlers
import (
"time"
"github.com/kaspanet/kaspad/app/appmessage"
"github.com/kaspanet/kaspad/app/rpc/rpccontext"
"github.com/kaspanet/kaspad/infrastructure/network/netadapter/router"
)
const pauseBeforeShutDown = time.Second
// HandleShutDown handles the respectively named RPC command
func HandleShutDown(context *rpccontext.Context, _ *router.Router, _ appmessage.Message) (appmessage.Message, error) {
log.Warn("ShutDown RPC called.")
// Wait a second before shutting down, to allow time to return the response to the caller
spawn("HandleShutDown-pauseAndShutDown", func() {
<-time.After(pauseBeforeShutDown)
close(context.ShutDownChan)
})
response := appmessage.NewShutDownResponseMessage()
return response, nil
}
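The handler responds before shutting down: closing the channel is deferred to a goroutine so the RPC response has time to reach the caller. The same pattern in isolation, as a minimal sketch:

package main

import (
    "fmt"
    "time"
)

func main() {
    shutDownChan := make(chan struct{})

    // Respond first, then close the shutdown channel after a short pause
    // so the response has time to flush to the caller.
    go func() {
        <-time.After(time.Second)
        close(shutDownChan)
    }()

    fmt.Println("response sent to caller")
    <-shutDownChan // the app's main loop unblocks here and begins teardown
    fmt.Println("shutting down")
}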

View File

@@ -5,10 +5,11 @@
package main
import (
"github.com/kaspanet/kaspad/infrastructure/logger"
"os"
"runtime"
"github.com/kaspanet/kaspad/infrastructure/logger"
"github.com/kaspanet/kaspad/infrastructure/os/limits"
"github.com/kaspanet/kaspad/util/panics"
)
@@ -77,7 +78,7 @@ func realMain() error {
func main() {
// Use all processor cores and up some limits.
runtime.GOMAXPROCS(runtime.NumCPU())
if err := limits.SetLimits(); err != nil {
if err := limits.SetLimits(nil); err != nil {
os.Exit(1)
}

View File

@@ -107,7 +107,7 @@ func genesisPastUTXO(virtual *virtualBlock) UTXOSet {
// set by creating a diff UTXO set with the virtual UTXO
// set, and adding all of its entries in toRemove
diff := NewUTXODiff()
for outpoint, entry := range virtual.utxoSet.utxoCollection {
for outpoint, entry := range virtual.utxoSet.utxoCache {
diff.toRemove[outpoint] = entry
}
genesisPastUTXO := UTXOSet(NewDiffUTXOSet(virtual.utxoSet, diff))

View File

@@ -78,7 +78,7 @@ func loadUTXOSet(filename string) (UTXOSet, error) {
if err != nil {
return nil, err
}
utxoSet.utxoCollection[appmessage.Outpoint{TxID: txID, Index: index}] = entry
utxoSet.utxoCache[appmessage.Outpoint{TxID: txID, Index: index}] = entry
}
return utxoSet, nil

View File

@@ -9,13 +9,6 @@ import (
// Config is a descriptor which specifies the blockDAG instance configuration.
type Config struct {
// Interrupt specifies a channel the caller can close to signal that
// long running operations, such as catching up indexes or performing
// database migrations, should be interrupted.
//
// This field can be nil if the caller does not desire the behavior.
Interrupt <-chan struct{}
// DAGParams identifies which DAG parameters the DAG is associated
// with.
//
@@ -51,4 +44,8 @@ type Config struct {
// DatabaseContext is the context in which all database queries related to
// this DAG are going to run.
DatabaseContext *dbaccess.DatabaseContext
// MaxUTXOCacheSize is the maximum size, in bytes, of UTXO entries
// loaded from disk into RAM, to support UTXO lazy-loading
MaxUTXOCacheSize uint64
}

View File

@@ -94,7 +94,8 @@ type BlockDAG struct {
recentBlockProcessingTimestamps []mstime.Time
startTime mstime.Time
tips blockSet
maxUTXOCacheSize uint64
tips blockSet
// validTips is a set of blocks with the status "valid", which have no valid descendants.
// Note that some validTips might not be actual tips.
@@ -122,6 +123,7 @@ func New(config *Config) (*BlockDAG, error) {
blockCount: 0,
subnetworkID: config.SubnetworkID,
startTime: mstime.Now(),
maxUTXOCacheSize: config.MaxUTXOCacheSize,
}
dag.virtual = newVirtualBlock(dag, nil)
@@ -410,17 +412,30 @@ func (dag *BlockDAG) isAnyInPastOf(nodes blockSet, other *blockNode) (bool, erro
return false, nil
}
// GetTopHeaders returns the top appmessage.MaxBlockHeadersPerMsg block headers ordered by blue score.
func (dag *BlockDAG) GetTopHeaders(highHash *daghash.Hash, maxHeaders uint64) ([]*appmessage.BlockHeader, error) {
// GetHeaders returns DAG headers ordered by blue score, starting from the given hash in the given direction.
func (dag *BlockDAG) GetHeaders(startHash *daghash.Hash, maxHeaders uint64,
isAscending bool) ([]*appmessage.BlockHeader, error) {
dag.RLock()
defer dag.RUnlock()
if isAscending {
return dag.getHeadersAscending(startHash, maxHeaders)
}
return dag.getHeadersDescending(startHash, maxHeaders)
}
func (dag *BlockDAG) getHeadersDescending(highHash *daghash.Hash, maxHeaders uint64) ([]*appmessage.BlockHeader, error) {
highNode := dag.virtual.blockNode
if highHash != nil {
var ok bool
highNode, ok = dag.index.LookupNode(highHash)
if !ok {
return nil, errors.Errorf("Couldn't find the high hash %s in the dag", highHash)
return nil, errors.Errorf("Couldn't find the start hash %s in the dag", highHash)
}
}
headers := make([]*appmessage.BlockHeader, 0, highNode.blueScore)
headers := make([]*appmessage.BlockHeader, 0, maxHeaders)
queue := newDownHeap()
queue.pushSet(highNode.parents)
@@ -436,6 +451,31 @@ func (dag *BlockDAG) GetTopHeaders(highHash *daghash.Hash, maxHeaders uint64) ([
return headers, nil
}
func (dag *BlockDAG) getHeadersAscending(lowHash *daghash.Hash, maxHeaders uint64) ([]*appmessage.BlockHeader, error) {
lowNode := dag.genesis
if lowHash != nil {
var ok bool
lowNode, ok = dag.index.LookupNode(lowHash)
if !ok {
return nil, errors.Errorf("Couldn't find the start hash %s in the dag", lowHash)
}
}
headers := make([]*appmessage.BlockHeader, 0, maxHeaders)
queue := newUpHeap()
queue.pushSet(lowNode.children)
visited := newBlockSet()
for queue.Len() > 0 && uint64(len(headers)) < maxHeaders {
current := queue.pop()
if !visited.contains(current) {
visited.add(current)
headers = append(headers, current.Header())
queue.pushSet(current.children)
}
}
return headers, nil
}
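getHeadersAscending is a priority-first walk: pop the lowest-blue-score node, emit it once, push its children, and skip already-visited nodes so diamond-shaped ancestries are emitted only once. A self-contained toy version of the same traversal over a diamond DAG, with invented node and score types:

package main

import (
    "container/heap"
    "fmt"
)

type node struct {
    id       string
    score    uint64
    children []*node
}

// upHeap orders nodes by ascending score, like the DAG's up-heap.
type upHeap []*node

func (h upHeap) Len() int            { return len(h) }
func (h upHeap) Less(i, j int) bool  { return h[i].score < h[j].score }
func (h upHeap) Swap(i, j int)       { h[i], h[j] = h[j], h[i] }
func (h *upHeap) Push(x interface{}) { *h = append(*h, x.(*node)) }
func (h *upHeap) Pop() interface{} {
    old := *h
    n := old[len(old)-1]
    *h = old[:len(old)-1]
    return n
}

func main() {
    // Diamond: A -> B, A -> C, B -> D, C -> D.
    d := &node{id: "D", score: 3}
    c := &node{id: "C", score: 2, children: []*node{d}}
    b := &node{id: "B", score: 2, children: []*node{d}}
    a := &node{id: "A", score: 1, children: []*node{b, c}}

    queue := &upHeap{}
    heap.Push(queue, a)
    visited := map[string]bool{}
    const maxNodes = 10
    var order []string
    for queue.Len() > 0 && len(order) < maxNodes {
        current := heap.Pop(queue).(*node)
        if visited[current.id] {
            continue // D is reachable via both B and C; emit it once
        }
        visited[current.id] = true
        order = append(order, current.id)
        for _, child := range current.children {
            heap.Push(queue, child)
        }
    }
    fmt.Println(order) // [A B C D] (or [A C B D]; B and C tie on score)
}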
// ForEachHash runs the given fn on every hash that's currently known to
// the DAG.
//

View File

@@ -17,6 +17,7 @@ import (
"github.com/kaspanet/kaspad/app/appmessage"
"github.com/kaspanet/kaspad/domain/dagconfig"
"github.com/kaspanet/kaspad/domain/txscript"
"github.com/kaspanet/kaspad/infrastructure/config"
"github.com/kaspanet/kaspad/infrastructure/db/dbaccess"
"github.com/kaspanet/kaspad/util"
"github.com/kaspanet/kaspad/util/daghash"
@@ -213,7 +214,7 @@ func TestIsKnownBlock(t *testing.T) {
{hash: "732c891529619d43b5aeb3df42ba25dea483a8c0aded1cf585751ebabea28f29", want: true},
// Random hashes should not be available.
{hash: "123", want: false},
{hash: "1234567812345678123456781234567812345678123456781234567812345678", want: false},
}
for i, test := range tests {
@@ -278,7 +279,10 @@ func TestCalcSequenceLock(t *testing.T) {
// age of 4 blocks.
msgTx := appmessage.NewNativeMsgTx(appmessage.TxVersion, nil, []*appmessage.TxOut{{ScriptPubKey: nil, Value: 10}})
targetTx := util.NewTx(msgTx)
utxoSet := NewFullUTXOSet()
fullUTXOCacheSize := config.DefaultConfig().MaxUTXOCacheSize
db, teardown := prepareDatabaseForTest(t, "TestCalcSequenceLock")
defer teardown()
utxoSet := NewFullUTXOSetFromContext(db, fullUTXOCacheSize)
blueScore := uint64(numBlocksToGenerate) - 4
if isAccepted, err := utxoSet.AddTx(targetTx.MsgTx(), blueScore); err != nil {
t.Fatalf("AddTx unexpectedly failed. Error: %s", err)
@@ -1272,7 +1276,7 @@ func TestUTXOCommitment(t *testing.T) {
// Build a Multiset for block D
multiset := secp256k1.NewMultiset()
for outpoint, entry := range blockDPastDiffUTXOSet.base.utxoCollection {
for outpoint, entry := range blockDPastDiffUTXOSet.base.utxoCache {
var err error
multiset, err = addUTXOToMultiset(multiset, entry, &outpoint)
if err != nil {

View File

@@ -211,12 +211,6 @@ func (dag *BlockDAG) initDAGState() error {
return err
}
log.Debugf("Loading UTXO set...")
fullUTXOCollection, err := dag.initUTXOSet()
if err != nil {
return err
}
log.Debugf("Loading reachability data...")
err = dag.reachabilityTree.init(dag.databaseContext)
if err != nil {
@@ -229,12 +223,6 @@ func (dag *BlockDAG) initDAGState() error {
return err
}
log.Debugf("Applying the loaded utxoCollection to the virtual block...")
dag.virtual.utxoSet, err = newFullUTXOSetFromUTXOCollection(fullUTXOCollection)
if err != nil {
return errors.Wrap(err, "Error loading UTXOSet")
}
log.Debugf("Applying the stored tips to the virtual block...")
err = dag.initTipsAndVirtualParents(dagState)
if err != nil {

View File

@@ -10,7 +10,7 @@ import (
// re-selecting virtual parents in such a way that the given finalityBlock will be in the virtual's selectedParentChain
func (dag *BlockDAG) ResolveFinalityConflict(finalityBlockHash *daghash.Hash) error {
dag.dagLock.Lock()
defer dag.dagLock.RUnlock()
defer dag.dagLock.Unlock()
finalityBlock, ok := dag.index.LookupNode(finalityBlockHash)
if !ok {

View File

@@ -24,12 +24,12 @@ type txValidateItem struct {
// inputs. It provides several channels for communication and a processing
// function that is intended to be run in multiple goroutines.
type txValidator struct {
validateChan chan *txValidateItem
quitChan chan struct{}
resultChan chan error
utxoSet UTXOSet
flags txscript.ScriptFlags
sigCache *txscript.SigCache
validateChan chan *txValidateItem
quitChan chan struct{}
resultChan chan error
referencedUTXOEntries []*UTXOEntry
flags txscript.ScriptFlags
sigCache *txscript.SigCache
}
// sendResult sends the result of a script pair validation on the internal
@@ -51,19 +51,8 @@ out:
for {
select {
case txVI := <-v.validateChan:
// Ensure the referenced input utxo is available.
txIn := txVI.txIn
entry, ok := v.utxoSet.Get(txIn.PreviousOutpoint)
if !ok {
str := fmt.Sprintf("unable to find unspent "+
"output %s referenced from "+
"transaction %s input %d",
txIn.PreviousOutpoint, txVI.tx.ID(),
txVI.txInIndex)
err := ruleError(ErrMissingTxOut, str)
v.sendResult(err)
break out
}
entry := v.referencedUTXOEntries[txVI.txInIndex]
// Create a new script engine for the script pair.
sigScript := txIn.SignatureScript
@@ -164,20 +153,20 @@ func (v *txValidator) Validate(items []*txValidateItem) error {
// newTxValidator returns a new instance of txValidator to be used for
// validating transaction scripts asynchronously.
func newTxValidator(utxoSet UTXOSet, flags txscript.ScriptFlags, sigCache *txscript.SigCache) *txValidator {
func newTxValidator(referencedUTXOEntries []*UTXOEntry, flags txscript.ScriptFlags, sigCache *txscript.SigCache) *txValidator {
return &txValidator{
validateChan: make(chan *txValidateItem),
quitChan: make(chan struct{}),
resultChan: make(chan error),
utxoSet: utxoSet,
sigCache: sigCache,
flags: flags,
validateChan: make(chan *txValidateItem),
quitChan: make(chan struct{}),
resultChan: make(chan error),
referencedUTXOEntries: referencedUTXOEntries,
sigCache: sigCache,
flags: flags,
}
}
// ValidateTransactionScripts validates the scripts for the passed transaction
// using multiple goroutines.
func ValidateTransactionScripts(tx *util.Tx, utxoSet UTXOSet, flags txscript.ScriptFlags, sigCache *txscript.SigCache) error {
func ValidateTransactionScripts(tx *util.Tx, referencedUTXOEntries []*UTXOEntry, flags txscript.ScriptFlags, sigCache *txscript.SigCache) error {
// Collect all of the transaction inputs and required information for
// validation.
txIns := tx.MsgTx().TxIn
@@ -192,6 +181,6 @@ func ValidateTransactionScripts(tx *util.Tx, utxoSet UTXOSet, flags txscript.Scr
}
// Validate all of the inputs.
validator := newTxValidator(utxoSet, flags, sigCache)
validator := newTxValidator(referencedUTXOEntries, flags, sigCache)
return validator.Validate(txValItems)
}
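A self-contained sketch of the same fan-out shape — unbuffered channels feeding a fixed pool of workers, with a quit channel that unwinds everything on the first error. The item and validate types here are invented stand-ins for txValidateItem and the script-engine check:

package main

import "fmt"

type item struct{ index int }

func validateAll(items []*item, workers int, validate func(*item) error) error {
    validateChan := make(chan *item)
    resultChan := make(chan error)
    quitChan := make(chan struct{})
    defer close(quitChan) // unblocks the feeder and all workers on any return

    for w := 0; w < workers; w++ {
        go func() {
            for {
                select {
                case it := <-validateChan:
                    select {
                    case resultChan <- validate(it):
                    case <-quitChan:
                        return
                    }
                case <-quitChan:
                    return
                }
            }
        }()
    }

    // Feed items from a separate goroutine so feeding and collecting
    // cannot deadlock each other.
    go func() {
        for _, it := range items {
            select {
            case validateChan <- it:
            case <-quitChan:
                return
            }
        }
    }()

    // Collect one result per item; the first error wins.
    for range items {
        if err := <-resultChan; err != nil {
            return err
        }
    }
    return nil
}

func main() {
    items := []*item{{0}, {1}, {2}, {3}}
    err := validateAll(items, 2, func(it *item) error {
        fmt.Println("validated item", it.index) // order may vary across runs
        return nil
    })
    fmt.Println("err:", err)
}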

View File

@@ -43,16 +43,26 @@ func (dag *BlockDAG) CalcSequenceLockNoLock(tx *util.Tx, utxoSet UTXOSet) (*Sequ
//
// This function MUST be called with the DAG state lock held (for writes).
func (dag *BlockDAG) calcTxSequenceLock(node *blockNode, tx *util.Tx, utxoSet UTXOSet) (*SequenceLock, error) {
inputsWithUTXOEntries, err := dag.getReferencedUTXOEntries(tx, utxoSet)
referencedUTXOEntries, err := dag.getReferencedUTXOEntries(tx, utxoSet)
if err != nil {
return nil, err
}
return dag.calcTxSequenceLockFromInputsWithUTXOEntries(node, tx, inputsWithUTXOEntries)
return dag.calcTxSequenceLockFromReferencedUTXOEntries(node, tx, referencedUTXOEntries)
}
func (dag *BlockDAG) calcTxSequenceLockFromInputsWithUTXOEntries(
node *blockNode, tx *util.Tx, inputsWithUTXOEntries []*txInputAndUTXOEntry) (*SequenceLock, error) {
// CalcTxSequenceLockFromReferencedUTXOEntries computes the relative lock-times for the passed
// transaction, with the given referenced UTXO entries. See CalcSequenceLock for further details.
func (dag *BlockDAG) CalcTxSequenceLockFromReferencedUTXOEntries(
tx *util.Tx, referencedUTXOEntries []*UTXOEntry) (*SequenceLock, error) {
dag.dagLock.RLock()
defer dag.dagLock.RUnlock()
return dag.calcTxSequenceLockFromReferencedUTXOEntries(dag.selectedTip(), tx, referencedUTXOEntries)
}
func (dag *BlockDAG) calcTxSequenceLockFromReferencedUTXOEntries(
node *blockNode, tx *util.Tx, referencedUTXOEntries []*UTXOEntry) (*SequenceLock, error) {
// A value of -1 for each relative lock type represents a relative time
// lock value that will allow a transaction to be included in a block
@@ -66,9 +76,8 @@ func (dag *BlockDAG) calcTxSequenceLockFromInputsWithUTXOEntries(
return sequenceLock, nil
}
for _, txInAndReferencedUTXOEntry := range inputsWithUTXOEntries {
txIn := txInAndReferencedUTXOEntry.txIn
utxoEntry := txInAndReferencedUTXOEntry.utxoEntry
for i, txIn := range tx.MsgTx().TxIn {
utxoEntry := referencedUTXOEntries[i]
// If the input blue score is set to the mempool blue score, then we
// assume the transaction makes it into the next block when

View File

@@ -14,14 +14,14 @@ import (
"sync"
"testing"
"github.com/kaspanet/kaspad/infrastructure/db/database/ffldb/ldb"
"github.com/kaspanet/kaspad/infrastructure/config"
"github.com/kaspanet/kaspad/infrastructure/db/database/ldb"
"github.com/kaspanet/kaspad/infrastructure/db/dbaccess"
"github.com/kaspanet/kaspad/util"
"github.com/kaspanet/kaspad/util/subnetworkid"
"github.com/pkg/errors"
"github.com/syndtr/goleveldb/leveldb/opt"
"github.com/kaspanet/kaspad/util/subnetworkid"
"github.com/kaspanet/kaspad/app/appmessage"
"github.com/kaspanet/kaspad/domain/txscript"
"github.com/kaspanet/kaspad/util/daghash"
@@ -43,7 +43,7 @@ func FileExists(name string) bool {
// The openDB parameter instructs DAGSetup whether or not to also open the
// database. Setting it to false is useful in tests that handle database
// opening/closing by themselves.
func DAGSetup(dbName string, openDb bool, config Config) (*BlockDAG, func(), error) {
func DAGSetup(dbName string, openDb bool, dagConfig Config) (*BlockDAG, func(), error) {
var teardown func()
// To make sure that the teardown function is not called before any goroutines finish running -
@@ -81,7 +81,7 @@ func DAGSetup(dbName string, openDb bool, config Config) (*BlockDAG, func(), err
return nil, nil, errors.Errorf("error creating db: %s", err)
}
config.DatabaseContext = databaseContext
dagConfig.DatabaseContext = databaseContext
// Setup a teardown function for cleaning up. This function is
// returned to the caller to be invoked when it is done testing.
@@ -99,11 +99,12 @@ func DAGSetup(dbName string, openDb bool, config Config) (*BlockDAG, func(), err
}
}
config.TimeSource = NewTimeSource()
config.SigCache = txscript.NewSigCache(1000)
dagConfig.TimeSource = NewTimeSource()
dagConfig.SigCache = txscript.NewSigCache(1000)
dagConfig.MaxUTXOCacheSize = config.DefaultConfig().MaxUTXOCacheSize
// Create the DAG instance.
dag, err := New(&config)
dag, err := New(&dagConfig)
if err != nil {
teardown()
err := errors.Wrapf(err, "failed to create dag instance")

View File

@@ -1,14 +1,17 @@
package blockdag
import (
"bytes"
"fmt"
"math"
"sort"
"strings"
"unsafe"
"github.com/kaspanet/kaspad/infrastructure/db/dbaccess"
"github.com/pkg/errors"
"github.com/kaspanet/go-secp256k1"
"github.com/kaspanet/kaspad/app/appmessage"
)
@@ -484,29 +487,27 @@ type UTXOSet interface {
// FullUTXOSet represents a full list of transaction outputs and their values
type FullUTXOSet struct {
utxoCollection
utxoCache utxoCollection
dbContext dbaccess.Context
estimatedSize uint64
maxUTXOCacheSize uint64
outpointBuff *bytes.Buffer
}
// NewFullUTXOSet creates a new utxoSet with full list of transaction outputs and their values
func NewFullUTXOSet() *FullUTXOSet {
return &FullUTXOSet{
utxoCollection: utxoCollection{},
utxoCache: utxoCollection{},
}
}
// newFullUTXOSetFromUTXOCollection converts a utxoCollection to a FullUTXOSet
func newFullUTXOSetFromUTXOCollection(collection utxoCollection) (*FullUTXOSet, error) {
var err error
multiset := secp256k1.NewMultiset()
for outpoint, utxoEntry := range collection {
multiset, err = addUTXOToMultiset(multiset, utxoEntry, &outpoint)
if err != nil {
return nil, err
}
}
// NewFullUTXOSetFromContext creates a new utxoSet backed by the given database context, with caching
func NewFullUTXOSetFromContext(context dbaccess.Context, cacheSize uint64) *FullUTXOSet {
return &FullUTXOSet{
utxoCollection: collection,
}, nil
dbContext: context,
maxUTXOCacheSize: cacheSize,
utxoCache: make(utxoCollection),
}
}
// diffFrom returns the difference between this utxoSet and another
@@ -564,15 +565,93 @@ func (fus *FullUTXOSet) containsInputs(tx *appmessage.MsgTx) bool {
return true
}
// contains returns a boolean value indicating whether a UTXO entry is in the set
func (fus *FullUTXOSet) contains(outpoint appmessage.Outpoint) bool {
_, ok := fus.Get(outpoint)
return ok
}
// clone returns a clone of this utxoSet
func (fus *FullUTXOSet) clone() UTXOSet {
return &FullUTXOSet{utxoCollection: fus.utxoCollection.clone()}
return &FullUTXOSet{
utxoCache: fus.utxoCache.clone(),
dbContext: fus.dbContext,
estimatedSize: fus.estimatedSize,
maxUTXOCacheSize: fus.maxUTXOCacheSize,
}
}
// get returns the UTXOEntry associated with the given Outpoint, and a boolean indicating if such entry was found
func (fus *FullUTXOSet) get(outpoint appmessage.Outpoint) (*UTXOEntry, bool) {
return fus.Get(outpoint)
}
// getSizeOfUTXOEntryAndOutpoint returns the estimated size of a UTXOEntry and its Outpoint in bytes
func getSizeOfUTXOEntryAndOutpoint(entry *UTXOEntry) uint64 {
const staticSize = uint64(unsafe.Sizeof(UTXOEntry{}) + unsafe.Sizeof(appmessage.Outpoint{}))
return staticSize + uint64(len(entry.scriptPubKey))
}
// checkAndCleanCachedData checks the FullUTXOSet's estimated size and clears the cache if it exceeds the limit
func (fus *FullUTXOSet) checkAndCleanCachedData() {
if fus.estimatedSize > fus.maxUTXOCacheSize {
fus.utxoCache = make(utxoCollection)
fus.estimatedSize = 0
}
}
// add adds a new UTXO entry to this FullUTXOSet
func (fus *FullUTXOSet) add(outpoint appmessage.Outpoint, entry *UTXOEntry) {
fus.utxoCache[outpoint] = entry
fus.estimatedSize += getSizeOfUTXOEntryAndOutpoint(entry)
fus.checkAndCleanCachedData()
}
// remove removes a UTXO entry from this collection if it exists
func (fus *FullUTXOSet) remove(outpoint appmessage.Outpoint) {
entry, ok := fus.utxoCache.get(outpoint)
if ok {
delete(fus.utxoCache, outpoint)
fus.estimatedSize -= getSizeOfUTXOEntryAndOutpoint(entry)
}
}
// Get returns the UTXOEntry associated with the given Outpoint, and a boolean indicating if such entry was found
// If the UTXOEntry does not exist in memory, it is fetched from the database
func (fus *FullUTXOSet) Get(outpoint appmessage.Outpoint) (*UTXOEntry, bool) {
utxoEntry, ok := fus.utxoCollection[outpoint]
return utxoEntry, ok
utxoEntry, ok := fus.utxoCache[outpoint]
if ok {
return utxoEntry, ok
}
if fus.outpointBuff == nil {
fus.outpointBuff = bytes.NewBuffer(make([]byte, outpointSerializeSize))
}
fus.outpointBuff.Reset()
err := serializeOutpoint(fus.outpointBuff, &outpoint)
if err != nil {
return nil, false
}
key := fus.outpointBuff.Bytes()
value, err := dbaccess.GetFromUTXOSet(fus.dbContext, key)
if err != nil {
return nil, false
}
entry, err := deserializeUTXOEntry(bytes.NewReader(value))
if err != nil {
return nil, false
}
fus.add(outpoint, entry)
return entry, true
}
func (fus *FullUTXOSet) String() string {
return fus.utxoCache.String()
}
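A minimal standalone model of this cache policy, with an invented entry type: the size estimate grows on each add, and crossing the limit drops the whole map (including the entry just inserted), the same flush-all behavior as checkAndCleanCachedData above:

package main

import (
    "fmt"
    "unsafe"
)

type entry struct {
    amount    uint64
    scriptLen int
}

type cache struct {
    entries       map[string]*entry
    estimatedSize uint64
    maxSize       uint64
}

func (c *cache) add(key string, e *entry) {
    c.entries[key] = e
    // Static struct size plus the variable-length script, as in
    // getSizeOfUTXOEntryAndOutpoint.
    c.estimatedSize += uint64(unsafe.Sizeof(entry{})) + uint64(e.scriptLen)
    if c.estimatedSize > c.maxSize {
        // Flush-all rather than selective eviction; the cache refills
        // lazily from the database on subsequent Get calls.
        c.entries = make(map[string]*entry)
        c.estimatedSize = 0
    }
}

func main() {
    c := &cache{entries: make(map[string]*entry), maxSize: 64}
    for i := 0; i < 10; i++ {
        c.add(fmt.Sprintf("outpoint-%d", i), &entry{scriptLen: 25})
        fmt.Println(len(c.entries), c.estimatedSize) // alternates: 1 41, 0 0, ...
    }
}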
// DiffUTXOSet represents a utxoSet with a base fullUTXOSet and a UTXODiff

View File

@@ -1,15 +1,47 @@
package blockdag
import (
"io/ioutil"
"os"
"path/filepath"
"reflect"
"testing"
"github.com/kaspanet/kaspad/infrastructure/config"
"github.com/kaspanet/kaspad/infrastructure/db/dbaccess"
"github.com/kaspanet/kaspad/util/subnetworkid"
"github.com/kaspanet/kaspad/app/appmessage"
"github.com/kaspanet/kaspad/util/daghash"
)
func prepareDatabaseForTest(t *testing.T, testName string) (*dbaccess.DatabaseContext, func()) {
var err error
tmpDir, err := ioutil.TempDir("", "utxoset_test")
if err != nil {
t.Fatalf("error creating temp dir: %s", err)
return nil, nil
}
dbPath := filepath.Join(tmpDir, testName)
_ = os.RemoveAll(dbPath)
databaseContext, err := dbaccess.New(dbPath)
if err != nil {
t.Fatalf("error creating db: %s", err)
return nil, nil
}
// Setup a teardown function for cleaning up. This function is
// returned to the caller to be invoked when it is done testing.
teardown := func() {
databaseContext.Close()
os.RemoveAll(dbPath)
}
return databaseContext, teardown
}
// TestUTXOCollection makes sure that utxoCollection cloning and string representations work as expected.
func TestUTXOCollection(t *testing.T) {
txID0, _ := daghash.NewTxIDFromStr("0000000000000000000000000000000000000000000000000000000000000000")
@@ -619,7 +651,7 @@ func (d *UTXODiff) equal(other *UTXODiff) bool {
}
func (fus *FullUTXOSet) equal(other *FullUTXOSet) bool {
return reflect.DeepEqual(fus.utxoCollection, other.utxoCollection)
return reflect.DeepEqual(fus.utxoCache, other.utxoCache)
}
func (dus *DiffUTXOSet) equal(other *DiffUTXOSet) bool {
@@ -642,7 +674,10 @@ func TestFullUTXOSet(t *testing.T) {
}
// Test fullUTXOSet creation
emptySet := NewFullUTXOSet()
fullUTXOCacheSize := config.DefaultConfig().MaxUTXOCacheSize
db, teardown := prepareDatabaseForTest(t, "TestDiffUTXOSet")
defer teardown()
emptySet := NewFullUTXOSetFromContext(db, fullUTXOCacheSize)
if len(emptySet.collection()) != 0 {
t.Errorf("new set is not empty")
}
@@ -668,7 +703,8 @@ func TestFullUTXOSet(t *testing.T) {
} else if isAccepted {
t.Errorf("addTx unexpectedly succeeded")
}
emptySet = &FullUTXOSet{utxoCollection: utxoCollection{outpoint0: utxoEntry0}}
emptySet = NewFullUTXOSetFromContext(db, fullUTXOCacheSize)
emptySet.add(outpoint0, utxoEntry0)
if isAccepted, err := emptySet.AddTx(transaction0, 0); err != nil {
t.Errorf("addTx unexpectedly failed. Error: %s", err)
} else if !isAccepted {
@@ -676,7 +712,7 @@ func TestFullUTXOSet(t *testing.T) {
}
// Test fullUTXOSet collection
if !reflect.DeepEqual(emptySet.collection(), emptySet.utxoCollection) {
if !reflect.DeepEqual(emptySet.collection(), emptySet.utxoCache) {
t.Errorf("collection does not equal the set's utxoCollection")
}
@@ -704,9 +740,12 @@ func TestDiffUTXOSet(t *testing.T) {
toAdd: utxoCollection{outpoint0: utxoEntry0},
toRemove: utxoCollection{outpoint1: utxoEntry1},
}
fullUTXOCacheSize := config.DefaultConfig().MaxUTXOCacheSize
db, teardown := prepareDatabaseForTest(t, "TestDiffUTXOSet")
defer teardown()
// Test diffUTXOSet creation
emptySet := NewDiffUTXOSet(NewFullUTXOSet(), NewUTXODiff())
emptySet := NewDiffUTXOSet(NewFullUTXOSetFromContext(db, fullUTXOCacheSize), NewUTXODiff())
if collection, err := emptySet.collection(); err != nil {
t.Errorf("Error getting emptySet collection: %s", err)
} else if len(collection) != 0 {
@@ -726,7 +765,7 @@ func TestDiffUTXOSet(t *testing.T) {
if !reflect.DeepEqual(withDiffUTXOSet.base, emptySet.base) || !reflect.DeepEqual(withDiffUTXOSet.UTXODiff, withDiff) {
t.Errorf("WithDiff is of unexpected composition")
}
_, err = NewDiffUTXOSet(NewFullUTXOSet(), diff).WithDiff(diff)
_, err = NewDiffUTXOSet(NewFullUTXOSetFromContext(db, fullUTXOCacheSize), diff).WithDiff(diff)
if err == nil {
t.Errorf("WithDiff unexpectedly succeeded")
}
@@ -748,14 +787,14 @@ func TestDiffUTXOSet(t *testing.T) {
{
name: "empty base, empty diff",
diffSet: &DiffUTXOSet{
base: NewFullUTXOSet(),
base: NewFullUTXOSetFromContext(db, fullUTXOCacheSize),
UTXODiff: &UTXODiff{
toAdd: utxoCollection{},
toRemove: utxoCollection{},
},
},
expectedMeldSet: &DiffUTXOSet{
base: NewFullUTXOSet(),
base: NewFullUTXOSetFromContext(db, fullUTXOCacheSize),
UTXODiff: &UTXODiff{
toAdd: utxoCollection{},
toRemove: utxoCollection{},
@@ -767,14 +806,18 @@ func TestDiffUTXOSet(t *testing.T) {
{
name: "empty base, one member in diff toAdd",
diffSet: &DiffUTXOSet{
base: NewFullUTXOSet(),
base: NewFullUTXOSetFromContext(db, fullUTXOCacheSize),
UTXODiff: &UTXODiff{
toAdd: utxoCollection{outpoint0: utxoEntry0},
toRemove: utxoCollection{},
},
},
expectedMeldSet: &DiffUTXOSet{
base: &FullUTXOSet{utxoCollection: utxoCollection{outpoint0: utxoEntry0}},
base: func() *FullUTXOSet {
futxo := NewFullUTXOSetFromContext(db, fullUTXOCacheSize)
futxo.add(outpoint0, utxoEntry0)
return futxo
}(),
UTXODiff: &UTXODiff{
toAdd: utxoCollection{},
toRemove: utxoCollection{},
@@ -786,7 +829,7 @@ func TestDiffUTXOSet(t *testing.T) {
{
name: "empty base, one member in diff toRemove",
diffSet: &DiffUTXOSet{
base: NewFullUTXOSet(),
base: NewFullUTXOSetFromContext(db, fullUTXOCacheSize),
UTXODiff: &UTXODiff{
toAdd: utxoCollection{},
toRemove: utxoCollection{outpoint0: utxoEntry0},
@@ -800,19 +843,23 @@ func TestDiffUTXOSet(t *testing.T) {
{
name: "one member in base toAdd, one member in diff toAdd",
diffSet: &DiffUTXOSet{
base: &FullUTXOSet{utxoCollection: utxoCollection{outpoint0: utxoEntry0}},
base: func() *FullUTXOSet {
futxo := NewFullUTXOSetFromContext(db, fullUTXOCacheSize)
futxo.add(outpoint0, utxoEntry0)
return futxo
}(),
UTXODiff: &UTXODiff{
toAdd: utxoCollection{outpoint1: utxoEntry1},
toRemove: utxoCollection{},
},
},
expectedMeldSet: &DiffUTXOSet{
base: &FullUTXOSet{
utxoCollection: utxoCollection{
outpoint0: utxoEntry0,
outpoint1: utxoEntry1,
},
},
base: func() *FullUTXOSet {
futxo := NewFullUTXOSetFromContext(db, fullUTXOCacheSize)
futxo.add(outpoint0, utxoEntry0)
futxo.add(outpoint1, utxoEntry1)
return futxo
}(),
UTXODiff: &UTXODiff{
toAdd: utxoCollection{},
toRemove: utxoCollection{},
@@ -827,16 +874,18 @@ func TestDiffUTXOSet(t *testing.T) {
{
name: "one member in base toAdd, same one member in diff toRemove",
diffSet: &DiffUTXOSet{
base: &FullUTXOSet{utxoCollection: utxoCollection{outpoint0: utxoEntry0}},
base: func() *FullUTXOSet {
futxo := NewFullUTXOSetFromContext(db, fullUTXOCacheSize)
futxo.add(outpoint0, utxoEntry0)
return futxo
}(),
UTXODiff: &UTXODiff{
toAdd: utxoCollection{},
toRemove: utxoCollection{outpoint0: utxoEntry0},
},
},
expectedMeldSet: &DiffUTXOSet{
base: &FullUTXOSet{
utxoCollection: utxoCollection{},
},
base: NewFullUTXOSetFromContext(db, fullUTXOCacheSize),
UTXODiff: &UTXODiff{
toAdd: utxoCollection{},
toRemove: utxoCollection{},
@@ -949,6 +998,9 @@ func TestDiffUTXOSet_addTx(t *testing.T) {
txOut0 := &appmessage.TxOut{ScriptPubKey: []byte{0}, Value: 10}
utxoEntry0 := NewUTXOEntry(txOut0, true, 0)
coinbaseTX := appmessage.NewSubnetworkMsgTx(1, []*appmessage.TxIn{}, []*appmessage.TxOut{txOut0}, subnetworkid.SubnetworkIDCoinbase, 0, nil)
fullUTXOCacheSize := config.DefaultConfig().MaxUTXOCacheSize
db, teardown := prepareDatabaseForTest(t, "TestDiffUTXOSet")
defer teardown()
// transaction1 spends coinbaseTX
id1 := coinbaseTX.TxID()
@@ -982,11 +1034,11 @@ func TestDiffUTXOSet_addTx(t *testing.T) {
}{
{
name: "add coinbase transaction to empty set",
startSet: NewDiffUTXOSet(NewFullUTXOSet(), NewUTXODiff()),
startSet: NewDiffUTXOSet(NewFullUTXOSetFromContext(db, fullUTXOCacheSize), NewUTXODiff()),
startHeight: 0,
toAdd: []*appmessage.MsgTx{coinbaseTX},
expectedSet: &DiffUTXOSet{
base: &FullUTXOSet{utxoCollection: utxoCollection{}},
base: NewFullUTXOSetFromContext(db, fullUTXOCacheSize),
UTXODiff: &UTXODiff{
toAdd: utxoCollection{outpoint1: utxoEntry0},
toRemove: utxoCollection{},
@@ -995,11 +1047,11 @@ func TestDiffUTXOSet_addTx(t *testing.T) {
},
{
name: "add regular transaction to empty set",
startSet: NewDiffUTXOSet(NewFullUTXOSet(), NewUTXODiff()),
startSet: NewDiffUTXOSet(NewFullUTXOSetFromContext(db, fullUTXOCacheSize), NewUTXODiff()),
startHeight: 0,
toAdd: []*appmessage.MsgTx{transaction1},
expectedSet: &DiffUTXOSet{
base: &FullUTXOSet{utxoCollection: utxoCollection{}},
base: NewFullUTXOSetFromContext(db, fullUTXOCacheSize),
UTXODiff: &UTXODiff{
toAdd: utxoCollection{},
toRemove: utxoCollection{},
@@ -1009,7 +1061,11 @@ func TestDiffUTXOSet_addTx(t *testing.T) {
{
name: "add transaction to set with its input in base",
startSet: &DiffUTXOSet{
base: &FullUTXOSet{utxoCollection: utxoCollection{outpoint1: utxoEntry0}},
base: func() *FullUTXOSet {
futxo := NewFullUTXOSetFromContext(db, fullUTXOCacheSize)
futxo.add(outpoint1, utxoEntry0)
return futxo
}(),
UTXODiff: &UTXODiff{
toAdd: utxoCollection{},
toRemove: utxoCollection{},
@@ -1018,7 +1074,11 @@ func TestDiffUTXOSet_addTx(t *testing.T) {
startHeight: 1,
toAdd: []*appmessage.MsgTx{transaction1},
expectedSet: &DiffUTXOSet{
base: &FullUTXOSet{utxoCollection: utxoCollection{outpoint1: utxoEntry0}},
base: func() *FullUTXOSet {
futxo := NewFullUTXOSetFromContext(db, fullUTXOCacheSize)
futxo.add(outpoint1, utxoEntry0)
return futxo
}(),
UTXODiff: &UTXODiff{
toAdd: utxoCollection{outpoint2: utxoEntry1},
toRemove: utxoCollection{outpoint1: utxoEntry0},
@@ -1028,7 +1088,7 @@ func TestDiffUTXOSet_addTx(t *testing.T) {
{
name: "add transaction to set with its input in diff toAdd",
startSet: &DiffUTXOSet{
base: NewFullUTXOSet(),
base: NewFullUTXOSetFromContext(db, fullUTXOCacheSize),
UTXODiff: &UTXODiff{
toAdd: utxoCollection{outpoint1: utxoEntry0},
toRemove: utxoCollection{},
@@ -1037,7 +1097,7 @@ func TestDiffUTXOSet_addTx(t *testing.T) {
startHeight: 1,
toAdd: []*appmessage.MsgTx{transaction1},
expectedSet: &DiffUTXOSet{
base: NewFullUTXOSet(),
base: NewFullUTXOSetFromContext(db, fullUTXOCacheSize),
UTXODiff: &UTXODiff{
toAdd: utxoCollection{outpoint2: utxoEntry1},
toRemove: utxoCollection{},
@@ -1047,7 +1107,7 @@ func TestDiffUTXOSet_addTx(t *testing.T) {
{
name: "add transaction to set with its input in diff toAdd and its output in diff toRemove",
startSet: &DiffUTXOSet{
base: NewFullUTXOSet(),
base: NewFullUTXOSetFromContext(db, fullUTXOCacheSize),
UTXODiff: &UTXODiff{
toAdd: utxoCollection{outpoint1: utxoEntry0},
toRemove: utxoCollection{outpoint2: utxoEntry1},
@@ -1056,7 +1116,7 @@ func TestDiffUTXOSet_addTx(t *testing.T) {
startHeight: 1,
toAdd: []*appmessage.MsgTx{transaction1},
expectedSet: &DiffUTXOSet{
base: NewFullUTXOSet(),
base: NewFullUTXOSetFromContext(db, fullUTXOCacheSize),
UTXODiff: &UTXODiff{
toAdd: utxoCollection{},
toRemove: utxoCollection{},
@@ -1066,7 +1126,11 @@ func TestDiffUTXOSet_addTx(t *testing.T) {
{
name: "add two transactions, one spending the other, to set with the first input in base",
startSet: &DiffUTXOSet{
base: &FullUTXOSet{utxoCollection: utxoCollection{outpoint1: utxoEntry0}},
base: func() *FullUTXOSet {
futxo := NewFullUTXOSetFromContext(db, fullUTXOCacheSize)
futxo.add(outpoint1, utxoEntry0)
return futxo
}(),
UTXODiff: &UTXODiff{
toAdd: utxoCollection{},
toRemove: utxoCollection{},
@@ -1075,7 +1139,11 @@ func TestDiffUTXOSet_addTx(t *testing.T) {
startHeight: 1,
toAdd: []*appmessage.MsgTx{transaction1, transaction2},
expectedSet: &DiffUTXOSet{
base: &FullUTXOSet{utxoCollection: utxoCollection{outpoint1: utxoEntry0}},
base: func() *FullUTXOSet {
futxo := NewFullUTXOSetFromContext(db, fullUTXOCacheSize)
futxo.add(outpoint1, utxoEntry0)
return futxo
}(),
UTXODiff: &UTXODiff{
toAdd: utxoCollection{outpoint3: utxoEntry2},
toRemove: utxoCollection{outpoint1: utxoEntry0},
@@ -1108,7 +1176,7 @@ testLoop:
// collection returns a collection of all UTXOs in this set
func (fus *FullUTXOSet) collection() utxoCollection {
return fus.utxoCollection.clone()
return fus.utxoCache.clone()
}
// collection returns a collection of all UTXOs in this set

View File

@@ -353,21 +353,18 @@ func (dag *BlockDAG) checkProofOfWork(header *appmessage.BlockHeader, flags Beha
// ValidateTxMass makes sure that the given transaction's mass does not exceed
// the maximum allowed limit. Currently, it is equivalent to the block mass limit.
// See CalcTxMass for further details.
func ValidateTxMass(tx *util.Tx, utxoSet UTXOSet) error {
txMass, err := CalcTxMassFromUTXOSet(tx, utxoSet)
if err != nil {
return err
}
func ValidateTxMass(tx *util.Tx, referencedUTXOEntries []*UTXOEntry) (txMass uint64, err error) {
txMass = calcTxMassFromReferencedUTXOEntries(tx, referencedUTXOEntries)
if txMass > appmessage.MaxMassAcceptedByBlock {
str := fmt.Sprintf("tx %s has mass %d, which is above the "+
"allowed limit of %d", tx.ID(), txMass, appmessage.MaxMassAcceptedByBlock)
return ruleError(ErrTxMassTooHigh, str)
return 0, ruleError(ErrTxMassTooHigh, str)
}
return nil
return txMass, nil
}
func calcTxMassFromInputsWithUTXOEntries(
tx *util.Tx, inputsWithUTXOEntries []*txInputAndUTXOEntry) uint64 {
func calcTxMassFromReferencedUTXOEntries(
tx *util.Tx, referencedUTXOEntries []*UTXOEntry) uint64 {
if tx.IsCoinBase() {
return calcCoinbaseTxMass(tx)
@@ -375,9 +372,7 @@ func calcTxMassFromInputsWithUTXOEntries(
previousScriptPubKeys := make([][]byte, 0, len(tx.MsgTx().TxIn))
for _, inputWithUTXOEntry := range inputsWithUTXOEntries {
utxoEntry := inputWithUTXOEntry.utxoEntry
for _, utxoEntry := range referencedUTXOEntries {
previousScriptPubKeys = append(previousScriptPubKeys, utxoEntry.ScriptPubKey())
}
return CalcTxMass(tx, previousScriptPubKeys)
@@ -928,7 +923,7 @@ func checkTxIsNotDuplicate(tx *util.Tx, utxoSet UTXOSet) error {
// NOTE: The transaction MUST have already been sanity checked with the
// CheckTransactionSanity function prior to calling this function.
func CheckTransactionInputsAndCalulateFee(
tx *util.Tx, txBlueScore uint64, utxoSet UTXOSet, dagParams *dagconfig.Params, fastAdd bool) (
tx *util.Tx, txBlueScore uint64, referencedUTXOEntries []*UTXOEntry, dagParams *dagconfig.Params, fastAdd bool) (
txFeeInSompi uint64, err error) {
// Coinbase transactions have no standard inputs to validate.
@@ -938,10 +933,7 @@ func CheckTransactionInputsAndCalulateFee(
var totalSompiIn uint64
for txInIndex, txIn := range tx.MsgTx().TxIn {
entry, err := findReferencedOutput(tx, utxoSet, txIn, txInIndex)
if err != nil {
return 0, err
}
entry := referencedUTXOEntries[txInIndex]
if !fastAdd {
if err = validateCoinbaseMaturity(dagParams, entry, txBlueScore, txIn); err != nil {
@@ -968,11 +960,10 @@ func checkEntryAmounts(entry *UTXOEntry, totalSompiInBefore uint64) (totalSompiI
// The total of all outputs must not be more than the max
// allowed per transaction. Also, we could potentially overflow
// the accumulator so check for overflow.
lastSompiIn := totalSompiInBefore
originTxSompi := entry.Amount()
totalSompiInAfter = totalSompiInBefore + originTxSompi
if totalSompiInBefore < lastSompiIn ||
totalSompiInBefore > util.MaxSompi {
if totalSompiInAfter < totalSompiInBefore ||
totalSompiInAfter > util.MaxSompi {
str := fmt.Sprintf("total value of all transaction "+
"inputs is %d which is higher than max "+
"allowed value of %d", totalSompiInBefore,
@@ -982,20 +973,6 @@ func checkEntryAmounts(entry *UTXOEntry, totalSompiInBefore uint64) (totalSompiI
return totalSompiInAfter, nil
}
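The replaced check compared totalSompiInBefore against a copy of itself, so it could never fire; the new form uses the standard unsigned-overflow idiom (after < before). The idiom in isolation, with an illustrative limit rather than the real util.MaxSompi:

package main

import (
    "fmt"
    "math"
)

// addChecked adds amount to before, rejecting wrap-around and totals above limit.
func addChecked(before, amount, limit uint64) (uint64, error) {
    after := before + amount
    // Unsigned addition wrapped around exactly when the sum is smaller
    // than either operand.
    if after < before || after > limit {
        return 0, fmt.Errorf("input total overflows: %d + %d", before, amount)
    }
    return after, nil
}

func main() {
    if _, err := addChecked(math.MaxUint64-10, 20, math.MaxUint64); err != nil {
        fmt.Println(err) // wraps around: caught by `after < before`
    }
    total, _ := addChecked(100, 20, 1000)
    fmt.Println(total) // 120
}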
func findReferencedOutput(
tx *util.Tx, utxoSet UTXOSet, txIn *appmessage.TxIn, txInIndex int) (*UTXOEntry, error) {
entry, ok := utxoSet.Get(txIn.PreviousOutpoint)
if !ok {
str := fmt.Sprintf("output %s referenced from "+
"transaction %s input %d either does not exist or "+
"has already been spent", txIn.PreviousOutpoint,
tx.ID(), txInIndex)
return nil, ruleError(ErrMissingTxOut, str)
}
return entry, nil
}
func validateCoinbaseMaturity(dagParams *dagconfig.Params, entry *UTXOEntry, txBlueScore uint64, txIn *appmessage.TxIn) error {
// Ensure the transaction is not spending coins which have not
// yet reached the required coinbase maturity.
@@ -1039,11 +1016,6 @@ func (dag *BlockDAG) checkConnectBlockToPastUTXO(
return nil
}
type txInputAndUTXOEntry struct {
txIn *appmessage.TxIn
utxoEntry *UTXOEntry
}
func (dag *BlockDAG) checkConnectTransactionToPastUTXO(
node *blockNode, tx *util.Tx, pastUTXO UTXOSet, accumulatedMassBefore uint64, selectedParentMedianTime mstime.Time) (
txFee uint64, accumulatedMassAfter uint64, err error) {
@@ -1053,22 +1025,22 @@ func (dag *BlockDAG) checkConnectTransactionToPastUTXO(
return 0, 0, err
}
inputsWithUTXOEntries, err := dag.getReferencedUTXOEntries(tx, pastUTXO)
referencedUTXOEntries, err := dag.getReferencedUTXOEntries(tx, pastUTXO)
if err != nil {
return 0, 0, err
}
accumulatedMassAfter, err = dag.checkTxMass(tx, inputsWithUTXOEntries, accumulatedMassBefore)
accumulatedMassAfter, err = dag.checkTxMass(tx, referencedUTXOEntries, accumulatedMassBefore)
if err != nil {
return 0, 0, err
}
err = dag.checkTxCoinbaseMaturity(node, inputsWithUTXOEntries)
err = dag.checkTxCoinbaseMaturity(node, tx, referencedUTXOEntries)
if err != nil {
return 0, 0, err
}
totalSompiIn, err := dag.checkTxInputAmounts(inputsWithUTXOEntries)
totalSompiIn, err := dag.checkTxInputAmounts(referencedUTXOEntries)
if err != nil {
return 0, 0, err
}
@@ -1080,12 +1052,12 @@ func (dag *BlockDAG) checkConnectTransactionToPastUTXO(
txFee = totalSompiIn - totalSompiOut
err = dag.checkTxSequenceLock(node, tx, inputsWithUTXOEntries, selectedParentMedianTime)
err = dag.checkTxSequenceLock(node, tx, referencedUTXOEntries, selectedParentMedianTime)
if err != nil {
return 0, 0, err
}
err = ValidateTransactionScripts(tx, pastUTXO, txscript.ScriptNoFlags, dag.sigCache)
err = ValidateTransactionScripts(tx, referencedUTXOEntries, txscript.ScriptNoFlags, dag.sigCache)
if err != nil {
return 0, 0, err
}
@@ -1094,12 +1066,12 @@ func (dag *BlockDAG) checkConnectTransactionToPastUTXO(
}
func (dag *BlockDAG) checkTxSequenceLock(node *blockNode, tx *util.Tx,
inputsWithUTXOEntries []*txInputAndUTXOEntry, medianTime mstime.Time) error {
referencedUTXOEntries []*UTXOEntry, medianTime mstime.Time) error {
// A transaction can only be included within a block
// once the sequence locks of *all* its inputs are
// active.
sequenceLock, err := dag.calcTxSequenceLockFromInputsWithUTXOEntries(node, tx, inputsWithUTXOEntries)
sequenceLock, err := dag.calcTxSequenceLockFromReferencedUTXOEntries(node, tx, referencedUTXOEntries)
if err != nil {
return err
}
@@ -1133,12 +1105,11 @@ func checkTxOutputAmounts(tx *util.Tx, totalSompiIn uint64) (uint64, error) {
}
func (dag *BlockDAG) checkTxInputAmounts(
inputsWithUTXOEntries []*txInputAndUTXOEntry) (totalSompiIn uint64, err error) {
inputUTXOEntries []*UTXOEntry) (totalSompiIn uint64, err error) {
totalSompiIn = 0
for _, txInAndReferencedUTXOEntry := range inputsWithUTXOEntries {
utxoEntry := txInAndReferencedUTXOEntry.utxoEntry
for _, utxoEntry := range inputUTXOEntries {
// Ensure the transaction amounts are in range. Each of the
// output values of the input transactions must not be negative
@@ -1156,11 +1127,10 @@ func (dag *BlockDAG) checkTxInputAmounts(
}
func (dag *BlockDAG) checkTxCoinbaseMaturity(
node *blockNode, inputsWithUTXOEntries []*txInputAndUTXOEntry) error {
node *blockNode, tx *util.Tx, referencedUTXOEntries []*UTXOEntry) error {
txBlueScore := node.blueScore
for _, txInAndReferencedUTXOEntry := range inputsWithUTXOEntries {
txIn := txInAndReferencedUTXOEntry.txIn
utxoEntry := txInAndReferencedUTXOEntry.utxoEntry
for i, txIn := range tx.MsgTx().TxIn {
utxoEntry := referencedUTXOEntries[i]
if utxoEntry.IsCoinbase() {
originBlueScore := utxoEntry.BlockBlueScore()
@@ -1181,10 +1151,10 @@ func (dag *BlockDAG) checkTxCoinbaseMaturity(
return nil
}
func (dag *BlockDAG) checkTxMass(tx *util.Tx, inputsWithUTXOEntries []*txInputAndUTXOEntry,
func (dag *BlockDAG) checkTxMass(tx *util.Tx, referencedUTXOEntries []*UTXOEntry,
accumulatedMassBefore uint64) (accumulatedMassAfter uint64, err error) {
txMass := calcTxMassFromInputsWithUTXOEntries(tx, inputsWithUTXOEntries)
txMass := calcTxMassFromReferencedUTXOEntries(tx, referencedUTXOEntries)
accumulatedMassAfter = accumulatedMassBefore + txMass
@@ -1199,11 +1169,10 @@ func (dag *BlockDAG) checkTxMass(tx *util.Tx, inputsWithUTXOEntries []*txInputAn
return accumulatedMassAfter, nil
}
func (dag *BlockDAG) getReferencedUTXOEntries(tx *util.Tx, utxoSet UTXOSet) (
[]*txInputAndUTXOEntry, error) {
func (dag *BlockDAG) getReferencedUTXOEntries(tx *util.Tx, utxoSet UTXOSet) ([]*UTXOEntry, error) {
txIns := tx.MsgTx().TxIn
inputsWithUTXOEntries := make([]*txInputAndUTXOEntry, 0, len(txIns))
referencedUTXOEntries := make([]*UTXOEntry, 0, len(txIns))
for txInIndex, txIn := range txIns {
utxoEntry, ok := utxoSet.Get(txIn.PreviousOutpoint)
@@ -1215,13 +1184,10 @@ func (dag *BlockDAG) getReferencedUTXOEntries(tx *util.Tx, utxoSet UTXOSet) (
return nil, ruleError(ErrMissingTxOut, str)
}
inputsWithUTXOEntries = append(inputsWithUTXOEntries, &txInputAndUTXOEntry{
txIn: txIn,
utxoEntry: utxoEntry,
})
referencedUTXOEntries = append(referencedUTXOEntries, utxoEntry)
}
return inputsWithUTXOEntries, nil
return referencedUTXOEntries, nil
}
func (dag *BlockDAG) checkTotalFee(totalFees uint64, txFee uint64) (uint64, error) {

View File

@@ -32,7 +32,7 @@ func newVirtualBlock(dag *BlockDAG, parents blockSet) *virtualBlock {
// The mutex is intentionally not held since this is a constructor.
var virtual virtualBlock
virtual.dag = dag
virtual.utxoSet = NewFullUTXOSet()
virtual.utxoSet = NewFullUTXOSetFromContext(dag.databaseContext, dag.maxUTXOCacheSize)
virtual.selectedParentChainSet = newBlockSet()
virtual.selectedParentChainSlice = nil
virtual.blockNode, _ = dag.newBlockNode(nil, parents)

View File

@@ -48,16 +48,13 @@ type Config struct {
// to policy.
Policy Policy
// CalcSequenceLockNoLock defines the function to use in order to generate
// the current sequence lock for the given transaction using the passed
// utxo set.
CalcSequenceLockNoLock func(*util.Tx, blockdag.UTXOSet) (*blockdag.SequenceLock, error)
// SigCache defines a signature cache to use.
SigCache *txscript.SigCache
// DAG is the BlockDAG we want to use (mainly for UTXO checks)
DAG *blockdag.BlockDAG
CalcTxSequenceLockFromReferencedUTXOEntries func(tx *util.Tx, referencedUTXOEntries []*blockdag.UTXOEntry) (*blockdag.SequenceLock, error)
}
// Policy houses the policy (configuration parameters) which is used to
@@ -92,7 +89,7 @@ type Policy struct {
type TxDesc struct {
mining.TxDesc
// depCount is not 0 for dependent transaction. Dependent transaction is
// depCount is not 0 for a chained transaction. A chained transaction is
// one that is accepted into the pool but cannot be mined in the next block because it
// depends on the outputs of an accepted, but not yet mined, transaction
depCount int
@@ -113,22 +110,24 @@ type TxPool struct {
// The following variables must only be used atomically.
lastUpdated int64 // last time pool was updated
mtx sync.RWMutex
cfg Config
pool map[daghash.TxID]*TxDesc
depends map[daghash.TxID]*TxDesc
dependsByPrev map[appmessage.Outpoint]map[daghash.TxID]*TxDesc
mtx sync.RWMutex
cfg Config
pool map[daghash.TxID]*TxDesc
chainedTransactions map[daghash.TxID]*TxDesc
chainedTransactionByPreviousOutpoint map[appmessage.Outpoint]*TxDesc
orphans map[daghash.TxID]*orphanTx
orphansByPrev map[appmessage.Outpoint]map[daghash.TxID]*util.Tx
outpoints map[appmessage.Outpoint]*util.Tx
mempoolUTXOSet *mempoolUTXOSet
// nextExpireScan is the time after which the orphan pool will be
// scanned in order to evict orphans. This is NOT a hard deadline as
// the scan will only run when an orphan is added to the pool as opposed
// to on an unconditional timer.
nextExpireScan mstime.Time
mpUTXOSet blockdag.UTXOSet
}
// Ensure the TxPool type implements the mining.TxSource interface.
@@ -341,7 +340,7 @@ func (mp *TxPool) IsTransactionInPool(hash *daghash.TxID) bool {
//
// This function MUST be called with the mempool lock held (for reads).
func (mp *TxPool) isInDependPool(hash *daghash.TxID) bool {
if _, exists := mp.depends[*hash]; exists {
if _, exists := mp.chainedTransactions[*hash]; exists {
return true
}
@@ -405,221 +404,129 @@ func (mp *TxPool) HaveTransaction(txID *daghash.TxID) bool {
return haveTx
}
// removeTransactions is the internal function which implements the public
// RemoveTransactions. See the comment for RemoveTransactions for more details.
//
// This method, in contrast to removeTransaction (singular), creates one utxoDiff
// and calls removeTransactionWithDiff on it for every transaction. This is an
// optimization to save us a good amount of allocations (specifically in
// UTXODiff.WithDiff) every time we accept a block.
// removeBlockTransactionsFromPool removes the transactions that are found in the block
// from the mempool, and moves their chained mempool transactions (if any) to the main pool.
//
// This function MUST be called with the mempool lock held (for writes).
func (mp *TxPool) removeTransactions(txs []*util.Tx) error {
diff := blockdag.NewUTXODiff()
for _, tx := range txs {
func (mp *TxPool) removeBlockTransactionsFromPool(block *util.Block) error {
for _, tx := range block.Transactions()[util.CoinbaseTransactionIndex+1:] {
txID := tx.ID()
if _, exists := mp.fetchTxDesc(txID); !exists {
continue
}
err := mp.removeTransactionWithDiff(tx, diff, false)
err := mp.cleanTransactionFromSets(tx)
if err != nil {
return err
}
mp.updateBlockTransactionChainedTransactions(tx)
}
var err error
mp.mpUTXOSet, err = mp.mpUTXOSet.WithDiff(diff)
if err != nil {
return err
}
atomic.StoreInt64(&mp.lastUpdated, mstime.Now().UnixMilliseconds())
return nil
}
// removeTransaction is the internal function which implements the public
// RemoveTransaction. See the comment for RemoveTransaction for more details.
// removeTransactionAndItsChainedTransactions removes a transaction and all of its chained transactions from the mempool.
//
// This function MUST be called with the mempool lock held (for writes).
func (mp *TxPool) removeTransaction(tx *util.Tx, removeDependants bool, restoreInputs bool) error {
func (mp *TxPool) removeTransactionAndItsChainedTransactions(tx *util.Tx) error {
txID := tx.ID()
if removeDependants {
// Remove any transactions which rely on this one.
for i := uint32(0); i < uint32(len(tx.MsgTx().TxOut)); i++ {
prevOut := appmessage.Outpoint{TxID: *txID, Index: i}
if txRedeemer, exists := mp.outpoints[prevOut]; exists {
err := mp.removeTransaction(txRedeemer, true, false)
if err != nil {
return err
}
}
}
}
if _, exists := mp.fetchTxDesc(txID); !exists {
return nil
}
diff := blockdag.NewUTXODiff()
err := mp.removeTransactionWithDiff(tx, diff, restoreInputs)
if err != nil {
return err
}
mp.mpUTXOSet, err = mp.mpUTXOSet.WithDiff(diff)
if err != nil {
return err
}
atomic.StoreInt64(&mp.lastUpdated, mstime.Now().UnixMilliseconds())
return nil
}
// removeTransactionWithDiff removes the transaction tx from the mempool while
// updating the UTXODiff diff with appropriate changes. diff is later meant to
// be withDiff'd against the mempool UTXOSet to update it.
//
// This method assumes that tx exists in the mempool.
func (mp *TxPool) removeTransactionWithDiff(tx *util.Tx, diff *blockdag.UTXODiff, restoreInputs bool) error {
txID := tx.ID()
err := mp.removeTransactionUTXOEntriesFromDiff(tx, diff)
if err != nil {
return errors.Errorf("could not remove UTXOEntry from diff: %s", err)
}
err = mp.markTransactionOutputsUnspent(tx, diff, restoreInputs)
if err != nil {
return errors.Errorf("could not mark transaction output as unspent: %s", err)
}
txDesc, _ := mp.fetchTxDesc(txID)
if txDesc.depCount == 0 {
delete(mp.pool, *txID)
} else {
delete(mp.depends, *txID)
}
mp.processRemovedTransactionDependencies(tx)
return nil
}
// removeTransactionUTXOEntriesFromDiff removes tx's UTXOEntries from the diff
func (mp *TxPool) removeTransactionUTXOEntriesFromDiff(tx *util.Tx, diff *blockdag.UTXODiff) error {
for idx := range tx.MsgTx().TxOut {
outpoint := *appmessage.NewOutpoint(tx.ID(), uint32(idx))
entry, exists := mp.mpUTXOSet.Get(outpoint)
if exists {
err := diff.RemoveEntry(outpoint, entry)
// Remove any transactions which rely on this one.
for i := uint32(0); i < uint32(len(tx.MsgTx().TxOut)); i++ {
prevOut := appmessage.Outpoint{TxID: *txID, Index: i}
if txRedeemer, exists := mp.mempoolUTXOSet.poolTransactionBySpendingOutpoint(prevOut); exists {
err := mp.removeTransactionAndItsChainedTransactions(txRedeemer)
if err != nil {
return err
}
}
}
return nil
}
// markTransactionOutputsUnspent updates the mempool so that tx's TXOs are unspent
// Iff restoreInputs is true then the inputs are restored back into the supplied diff
func (mp *TxPool) markTransactionOutputsUnspent(tx *util.Tx, diff *blockdag.UTXODiff, restoreInputs bool) error {
for _, txIn := range tx.MsgTx().TxIn {
if restoreInputs {
if prevTxDesc, exists := mp.pool[txIn.PreviousOutpoint.TxID]; exists {
prevOut := prevTxDesc.Tx.MsgTx().TxOut[txIn.PreviousOutpoint.Index]
entry := blockdag.NewUTXOEntry(prevOut, false, blockdag.UnacceptedBlueScore)
err := diff.AddEntry(txIn.PreviousOutpoint, entry)
if err != nil {
return err
}
}
if prevTxDesc, exists := mp.depends[txIn.PreviousOutpoint.TxID]; exists {
prevOut := prevTxDesc.Tx.MsgTx().TxOut[txIn.PreviousOutpoint.Index]
entry := blockdag.NewUTXOEntry(prevOut, false, blockdag.UnacceptedBlueScore)
err := diff.AddEntry(txIn.PreviousOutpoint, entry)
if err != nil {
return err
}
}
}
delete(mp.outpoints, txIn.PreviousOutpoint)
if _, exists := mp.chainedTransactions[*tx.ID()]; exists {
mp.removeChainTransaction(tx)
}
err := mp.cleanTransactionFromSets(tx)
if err != nil {
return err
}
atomic.StoreInt64(&mp.lastUpdated, mstime.Now().UnixMilliseconds())
return nil
}
// processRemovedTransactionDependencies processes the dependencies of a
// transaction tx that was just now removed from the mempool
func (mp *TxPool) processRemovedTransactionDependencies(tx *util.Tx) {
// cleanTransactionFromSets removes the transaction from all mempool-related transaction sets.
// It assumes that any chained transactions have already been cleaned from the mempool.
//
// This function MUST be called with the mempool lock held (for writes).
func (mp *TxPool) cleanTransactionFromSets(tx *util.Tx) error {
err := mp.mempoolUTXOSet.removeTx(tx)
if err != nil {
return err
}
txID := *tx.ID()
delete(mp.pool, txID)
delete(mp.chainedTransactions, txID)
return nil
}
// updateBlockTransactionChainedTransactions processes the dependencies of a
// transaction that was included in a block and was just now removed from the mempool.
//
// This function MUST be called with the mempool lock held (for writes).
func (mp *TxPool) updateBlockTransactionChainedTransactions(tx *util.Tx) {
prevOut := appmessage.Outpoint{TxID: *tx.ID()}
for txOutIdx := range tx.MsgTx().TxOut {
// Skip this output if no chained transaction spends it.
prevOut.Index = uint32(txOutIdx)
depends, exists := mp.dependsByPrev[prevOut]
txDesc, exists := mp.chainedTransactionByPreviousOutpoint[prevOut]
if !exists {
continue
}
// Move independent transactions into main pool
for _, txD := range depends {
txD.depCount--
if txD.depCount == 0 {
// The transaction may have already been removed by recursive calls, if removeRedeemers is true,
// so avoid moving it into the main pool
if _, ok := mp.depends[*txD.Tx.ID()]; ok {
delete(mp.depends, *txD.Tx.ID())
mp.pool[*txD.Tx.ID()] = txD
}
txDesc.depCount--
// If the transaction is not chained anymore, move it into the main pool
if txDesc.depCount == 0 {
// The transaction may have already been removed by recursive calls, if removeRedeemers is true,
// so avoid moving it into the main pool
if _, ok := mp.chainedTransactions[*txDesc.Tx.ID()]; ok {
delete(mp.chainedTransactions, *txDesc.Tx.ID())
mp.pool[*txDesc.Tx.ID()] = txDesc
}
}
delete(mp.dependsByPrev, prevOut)
delete(mp.chainedTransactionByPreviousOutpoint, prevOut)
}
}
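In isolation, the promotion logic reduces to a counter per chained transaction: a transaction enters the chained set with depCount equal to its number of in-mempool parents and moves to the main pool when the count reaches zero. A toy model with string IDs and invented types:

package main

import "fmt"

type txDesc struct {
    id       string
    depCount int
}

type pool struct {
    main    map[string]*txDesc
    chained map[string]*txDesc
}

// parentRemoved is called once per (parent, child) edge when a parent
// is mined and leaves the mempool.
func (p *pool) parentRemoved(child *txDesc) {
    child.depCount--
    // Once the transaction is no longer chained, promote it.
    if child.depCount == 0 {
        if _, ok := p.chained[child.id]; ok {
            delete(p.chained, child.id)
            p.main[child.id] = child
        }
    }
}

func main() {
    child := &txDesc{id: "child", depCount: 2}
    p := &pool{
        main:    map[string]*txDesc{},
        chained: map[string]*txDesc{child.id: child},
    }
    p.parentRemoved(child) // first parent mined: still chained
    fmt.Println(len(p.main), len(p.chained)) // 0 1
    p.parentRemoved(child) // second parent mined: promoted
    fmt.Println(len(p.main), len(p.chained)) // 1 0
}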
// RemoveTransaction removes the passed transaction from the mempool. When the
// removeDependants flag is set, any transactions that depend on the removed
// transaction (that is to say, redeem outputs from it) will also be removed
// recursively from the mempool, as they would otherwise become orphans.
// removeChainTransaction removes a chained transaction and all of its references as a result of a double spend.
//
// This function is safe for concurrent access.
func (mp *TxPool) RemoveTransaction(tx *util.Tx, removeDependants bool, restoreInputs bool) error {
// Protect concurrent access.
mp.mtx.Lock()
defer mp.mtx.Unlock()
return mp.removeTransaction(tx, removeDependants, restoreInputs)
// This function MUST be called with the mempool lock held (for writes).
func (mp *TxPool) removeChainTransaction(tx *util.Tx) {
delete(mp.chainedTransactions, *tx.ID())
for _, txIn := range tx.MsgTx().TxIn {
delete(mp.chainedTransactionByPreviousOutpoint, txIn.PreviousOutpoint)
}
}
// RemoveTransactions removes the passed transactions from the mempool.
//
// This function is safe for concurrent access.
func (mp *TxPool) RemoveTransactions(txs []*util.Tx) error {
// Protect concurrent access.
mp.mtx.Lock()
defer mp.mtx.Unlock()
return mp.removeTransactions(txs)
}
// RemoveDoubleSpends removes all transactions which spend outputs spent by the
// removeDoubleSpends removes all transactions which spend outputs spent by the
// passed transaction from the memory pool. Removing those transactions then
// leads to removing all transactions which rely on them, recursively. This is
// necessary when a block is connected to the DAG because the block may
// contain transactions which were previously unknown to the memory pool.
//
// This function is safe for concurrent access.
func (mp *TxPool) RemoveDoubleSpends(tx *util.Tx) error {
// Protect concurrent access.
mp.mtx.Lock()
defer mp.mtx.Unlock()
return mp.removeDoubleSpends(tx)
}
// This function MUST be called with the mempool lock held (for writes).
func (mp *TxPool) removeDoubleSpends(tx *util.Tx) error {
for _, txIn := range tx.MsgTx().TxIn {
if txRedeemer, ok := mp.mempoolUTXOSet.poolTransactionBySpendingOutpoint(txIn.PreviousOutpoint); ok {
if !txRedeemer.ID().IsEqual(tx.ID()) {
err := mp.removeTransactionAndItsChainedTransactions(txRedeemer)
if err != nil {
return err
}
@@ -634,13 +541,9 @@ func (mp *TxPool) removeDoubleSpends(tx *util.Tx) error {
// helper for maybeAcceptTransaction.
//
// This function MUST be called with the mempool lock held (for writes).
func (mp *TxPool) addTransaction(tx *util.Tx, mass uint64, fee uint64, parentsInPool []*appmessage.Outpoint) (*TxDesc, error) {
// Add the transaction to the pool and mark the referenced outpoints
// as spent by the pool.
txD := &TxDesc{
TxDesc: mining.TxDesc{
Tx: tx,
@@ -654,23 +557,17 @@ func (mp *TxPool) addTransaction(tx *util.Tx, fee uint64, parentsInPool []*appme
if len(parentsInPool) == 0 {
mp.pool[*tx.ID()] = txD
} else {
mp.chainedTransactions[*tx.ID()] = txD
for _, previousOutpoint := range parentsInPool {
mp.chainedTransactionByPreviousOutpoint[*previousOutpoint] = txD
}
}
err := mp.mempoolUTXOSet.addTx(tx)
if err != nil {
return nil, err
}
atomic.StoreInt64(&mp.lastUpdated, mstime.Now().UnixMilliseconds())
return txD, nil
@@ -684,7 +581,7 @@ func (mp *TxPool) addTransaction(tx *util.Tx, fee uint64, parentsInPool []*appme
// This function MUST be called with the mempool lock held (for reads).
func (mp *TxPool) checkPoolDoubleSpend(tx *util.Tx) error {
for _, txIn := range tx.MsgTx().TxIn {
if txR, exists := mp.mempoolUTXOSet.poolTransactionBySpendingOutpoint(txIn.PreviousOutpoint); exists {
str := fmt.Sprintf("output %s already spent by "+
"transaction %s in the memory pool",
txIn.PreviousOutpoint, txR.ID())
@@ -695,22 +592,11 @@ func (mp *TxPool) checkPoolDoubleSpend(tx *util.Tx) error {
return nil
}
// This function MUST be called with the mempool lock held (for reads).
func (mp *TxPool) fetchTxDesc(txID *daghash.TxID) (*TxDesc, bool) {
txDesc, exists := mp.pool[*txID]
if !exists {
txDesc, exists = mp.chainedTransactions[*txID]
}
return txDesc, exists
}
@@ -885,7 +771,7 @@ func (mp *TxPool) maybeAcceptTransaction(tx *util.Tx, rejectDupOrphans bool) ([]
prevOut := appmessage.Outpoint{TxID: *txID}
for txOutIdx := range tx.MsgTx().TxOut {
prevOut.Index = uint32(txOutIdx)
_, _, ok := mp.mempoolUTXOSet.utxoEntryByOutpoint(prevOut)
if ok {
return nil, nil, txRuleError(RejectDuplicate,
"transaction already exists")
@@ -896,21 +782,7 @@ func (mp *TxPool) maybeAcceptTransaction(tx *util.Tx, rejectDupOrphans bool) ([]
// don't exist or are already spent. Adding orphans to the orphan pool
// is not handled by this function, and the caller should use
// maybeAddOrphan if this behavior is desired.
spentUTXOEntries, parentsInPool, missingParents := mp.mempoolUTXOSet.transactionRelatedUTXOEntries(tx)
if len(missingParents) > 0 {
return missingParents, nil, nil
}
@@ -918,7 +790,7 @@ func (mp *TxPool) maybeAcceptTransaction(tx *util.Tx, rejectDupOrphans bool) ([]
// Don't allow the transaction into the mempool unless its sequence
// lock is active, meaning that it'll be allowed into the next block
// with respect to its defined relative lock times.
sequenceLock, err := mp.cfg.CalcTxSequenceLockFromReferencedUTXOEntries(tx, spentUTXOEntries)
if err != nil {
var dagRuleErr blockdag.RuleError
if ok := errors.As(err, &dagRuleErr); ok {
@@ -934,7 +806,7 @@ func (mp *TxPool) maybeAcceptTransaction(tx *util.Tx, rejectDupOrphans bool) ([]
// Don't allow transactions that exceed the maximum allowed
// transaction mass.
mass, err := blockdag.ValidateTxMass(tx, spentUTXOEntries)
if err != nil {
var ruleError blockdag.RuleError
if ok := errors.As(err, &ruleError); ok {
@@ -948,7 +820,7 @@ func (mp *TxPool) maybeAcceptTransaction(tx *util.Tx, rejectDupOrphans bool) ([]
// Also returns the fees associated with the transaction which will be
// used later.
txFee, err := blockdag.CheckTransactionInputsAndCalulateFee(tx, nextBlockBlueScore,
spentUTXOEntries, mp.cfg.DAG.Params, false)
if err != nil {
var dagRuleErr blockdag.RuleError
if ok := errors.As(err, &dagRuleErr); ok {
@@ -960,7 +832,7 @@ func (mp *TxPool) maybeAcceptTransaction(tx *util.Tx, rejectDupOrphans bool) ([]
// Don't allow transactions with non-standard inputs if the network
// parameters forbid their acceptance.
if !mp.cfg.Policy.AcceptNonStd {
err := checkInputsStandard(tx, spentUTXOEntries)
if err != nil {
// Attempt to extract a reject code from the error so
// it can be retained. When not possible, fall back to
@@ -1008,7 +880,7 @@ func (mp *TxPool) maybeAcceptTransaction(tx *util.Tx, rejectDupOrphans bool) ([]
// Verify crypto signatures for each input and reject the transaction if
// any don't verify.
err = blockdag.ValidateTransactionScripts(tx, spentUTXOEntries,
txscript.StandardVerifyFlags, mp.cfg.SigCache)
if err != nil {
var dagRuleErr blockdag.RuleError
@@ -1019,7 +891,7 @@ func (mp *TxPool) maybeAcceptTransaction(tx *util.Tx, rejectDupOrphans bool) ([]
}
// Add to transaction pool.
txDesc, err := mp.addTransaction(tx, mass, txFee, parentsInPool)
if err != nil {
return nil, nil, err
}
@@ -1027,7 +899,7 @@ func (mp *TxPool) maybeAcceptTransaction(tx *util.Tx, rejectDupOrphans bool) ([]
log.Debugf("Accepted transaction %s (pool size: %d)", txID,
len(mp.pool))
return nil, txDesc, nil
}
// processOrphans is the internal function which implements the public
@@ -1124,8 +996,6 @@ func (mp *TxPool) processOrphans(acceptedTx *util.Tx) []*TxDesc {
//
// This function is safe for concurrent access.
func (mp *TxPool) ProcessOrphans(acceptedTx *util.Tx) []*TxDesc {
mp.mtx.Lock()
defer mp.mtx.Unlock()
acceptedTxns := mp.processOrphans(acceptedTx)
@@ -1148,8 +1018,6 @@ func (mp *TxPool) ProcessTransaction(tx *util.Tx, allowOrphan bool) ([]*TxDesc,
log.Tracef("Processing transaction %s", tx.ID())
// Protect concurrent access.
mp.mtx.Lock()
defer mp.mtx.Unlock()
@@ -1210,14 +1078,14 @@ func (mp *TxPool) Count() int {
return count
}
// ChainedCount returns the number of chained transactions in the mempool. It does not
// include the orphan pool.
//
// This function is safe for concurrent access.
func (mp *TxPool) ChainedCount() int {
mp.mtx.RLock()
defer mp.mtx.RUnlock()
return len(mp.chainedTransactions)
}
// TxIDs returns a slice of IDs for all of the transactions in the memory
@@ -1287,13 +1155,9 @@ func (mp *TxPool) LastUpdated() mstime.Time {
// transaction that is already in the DAG
func (mp *TxPool) HandleNewBlock(block *util.Block) ([]*util.Tx, error) {
// Protect concurrent access.
mp.mtx.Lock()
defer mp.mtx.Unlock()
// Remove all of the transactions (except the coinbase) in the
// connected block from the transaction pool. Secondly, remove any
// transactions which are now double spends as a result of these
@@ -1301,9 +1165,8 @@ func (mp *TxPool) HandleNewBlock(block *util.Block) ([]*util.Tx, error) {
// no longer an orphan. Transactions which depend on a confirmed
// transaction are NOT removed recursively because they are still
// valid.
err := mp.removeBlockTransactionsFromPool(block)
if err != nil {
return nil, err
}
acceptedTxs := make([]*util.Tx, 0)
@@ -1324,17 +1187,14 @@ func (mp *TxPool) HandleNewBlock(block *util.Block) ([]*util.Tx, error) {
// New returns a new memory pool for validating and storing standalone
// transactions until they are mined into a block.
func New(cfg *Config) *TxPool {
return &TxPool{
cfg: *cfg,
pool: make(map[daghash.TxID]*TxDesc),
chainedTransactions: make(map[daghash.TxID]*TxDesc),
chainedTransactionByPreviousOutpoint: make(map[appmessage.Outpoint]*TxDesc),
orphans: make(map[daghash.TxID]*orphanTx),
orphansByPrev: make(map[appmessage.Outpoint]map[daghash.TxID]*util.Tx),
nextExpireScan: mstime.Now().Add(orphanExpireScanInterval),
mempoolUTXOSet: newMempoolUTXOSet(cfg.DAG),
}
}
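As a sketch of how the reworked constructor is wired together (illustrative only: the import paths and exact Config field types are assumptions; the field names come from this diff and the test harness further below):

package main

import (
    "github.com/kaspanet/kaspad/domain/blockdag"
    "github.com/kaspanet/kaspad/domain/mempool"
    "github.com/kaspanet/kaspad/util"
)

// newPoolForDAG builds a TxPool around an existing DAG, with a test-style
// sequence-lock stub that always reports the lock as active.
func newPoolForDAG(dag *blockdag.BlockDAG) *mempool.TxPool {
    return mempool.New(&mempool.Config{
        DAG: dag,
        Policy: mempool.Policy{
            MinRelayTxFee: 1000, // 1 sompi per byte, matching the test harness
            MaxTxVersion:  1,
        },
        CalcTxSequenceLockFromReferencedUTXOEntries: func(tx *util.Tx,
            referencedUTXOEntries []*blockdag.UTXOEntry) (*blockdag.SequenceLock, error) {
            return &blockdag.SequenceLock{Milliseconds: -1}, nil
        },
        SigCache: nil,
    })
}

func main() { _ = newPoolForDAG }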

View File

@@ -69,8 +69,7 @@ func (s *fakeDAG) SetMedianTimePast(mtp mstime.Time) {
s.medianTimePast = mtp
}
func calcTxSequenceLockFromReferencedUTXOEntries(tx *util.Tx, referencedUTXOEntries []*blockdag.UTXOEntry) (*blockdag.SequenceLock, error) {
return &blockdag.SequenceLock{
Milliseconds: -1,
@@ -339,8 +338,8 @@ func newPoolHarness(t *testing.T, dagParams *dagconfig.Params, numOutputs uint32
MinRelayTxFee: 1000, // 1 sompi per byte
MaxTxVersion: 1,
},
CalcTxSequenceLockFromReferencedUTXOEntries: calcTxSequenceLockFromReferencedUTXOEntries,
SigCache: nil,
}),
}
@@ -628,10 +627,7 @@ func TestProcessTransaction(t *testing.T) {
t.Fatalf("Script: error creating wrappedP2shNonSigScript: %v", err)
}
dummyPrevOutTxID := &daghash.TxID{}
dummyPrevOut := appmessage.Outpoint{TxID: *dummyPrevOutTxID, Index: 1}
dummySigScript := bytes.Repeat([]byte{0x00}, 65)
@@ -646,10 +642,8 @@ func TestProcessTransaction(t *testing.T) {
t.Fatalf("PayToAddrScript: unexpected error: %v", err)
}
p2shTx := util.NewTx(appmessage.NewNativeMsgTx(1, nil, []*appmessage.TxOut{{Value: 5000000000, ScriptPubKey: p2shScriptPubKey}}))
if err := harness.txPool.mempoolUTXOSet.addTx(p2shTx); err != nil {
t.Fatalf("AddTx unexpectedly failed. Error: %s", err)
}
txIns := []*appmessage.TxIn{{
@@ -691,8 +685,7 @@ func TestProcessTransaction(t *testing.T) {
}
// Checks that transactions get rejected from mempool if sequence lock is not active
harness.txPool.cfg.CalcTxSequenceLockFromReferencedUTXOEntries = func(tx *util.Tx, referencedUTXOEntries []*blockdag.UTXOEntry) (*blockdag.SequenceLock, error) {
return &blockdag.SequenceLock{
Milliseconds: math.MaxInt64,
@@ -714,7 +707,7 @@ func TestProcessTransaction(t *testing.T) {
if err.Error() != expectedErrStr {
t.Errorf("Unexpected error message. Expected \"%s\" but got \"%s\"", expectedErrStr, err.Error())
}
harness.txPool.cfg.CalcTxSequenceLockFromReferencedUTXOEntries = calcTxSequenceLockFromReferencedUTXOEntries
// Transaction should be rejected from mempool because it has low fee, and its priority is above mining.MinHighPriority
tx, err = harness.createTx(spendableOuts[4], 0, 1000)
@@ -796,7 +789,7 @@ func TestDoubleSpends(t *testing.T) {
// Then we assume tx3 is already in the DAG, so we need to remove
// transactions that spend the same outpoints from the mempool
harness.txPool.removeDoubleSpends(tx3)
// Ensures that only the transaction that double spends the same
// funds as tx3 is removed, and the other one remains unaffected
testPoolMembership(tc, tx1, false, false, false)
@@ -1132,10 +1125,10 @@ func TestRemoveTransaction(t *testing.T) {
testPoolMembership(tc, chainedTxns[3], false, true, true)
testPoolMembership(tc, chainedTxns[4], false, true, true)
// Checks that all of the transactions that depend on it get removed as well
err = harness.txPool.removeTransactionAndItsChainedTransactions(chainedTxns[1])
if err != nil {
t.Fatalf("removeTransactionAndItsChainedTransactions: %v", err)
}
testPoolMembership(tc, chainedTxns[1], false, false, false)
testPoolMembership(tc, chainedTxns[2], false, false, false)
@@ -1429,9 +1422,9 @@ func TestMultiInputOrphanDoubleSpend(t *testing.T) {
testPoolMembership(tc, doubleSpendTx, false, false, false)
}
// TestPoolTransactionBySpendingOutpoint tests that poolTransactionBySpendingOutpoint returns the expected spends found in
// the mempool.
func TestPoolTransactionBySpendingOutpoint(t *testing.T) {
tc, outputs, teardownFunc, err := newPoolHarness(t, &dagconfig.SimnetParams, 1, "TestCheckSpend")
if err != nil {
t.Fatalf("unable to create test pool: %v", err)
@@ -1442,8 +1435,8 @@ func TestCheckSpend(t *testing.T) {
// The mempool is empty, so none of the spendable outputs should have a
// spend there.
for _, op := range outputs {
spend, ok := harness.txPool.mempoolUTXOSet.poolTransactionBySpendingOutpoint(op.outpoint)
if ok {
t.Fatalf("Unexpeced spend found in pool: %v", spend)
}
}
@@ -1466,7 +1459,7 @@ func TestCheckSpend(t *testing.T) {
// The first tx in the chain should be the spend of the spendable
// output.
op := outputs[0].outpoint
spend, _ := harness.txPool.mempoolUTXOSet.poolTransactionBySpendingOutpoint(op)
if spend != chainedTxns[0] {
t.Fatalf("expected %v to be spent by %v, instead "+
"got %v", op, chainedTxns[0], spend)
@@ -1479,7 +1472,7 @@ func TestCheckSpend(t *testing.T) {
Index: 0,
}
expSpend := chainedTxns[i+1]
spend, _ = harness.txPool.mempoolUTXOSet.poolTransactionBySpendingOutpoint(op)
if spend != expSpend {
t.Fatalf("expected %v to be spent by %v, instead "+
"got %v", op, expSpend, spend)
@@ -1491,7 +1484,7 @@ func TestCheckSpend(t *testing.T) {
TxID: *chainedTxns[txChainLength-1].ID(),
Index: 0,
}
spend, _ = harness.txPool.mempoolUTXOSet.poolTransactionBySpendingOutpoint(op)
if spend != nil {
t.Fatalf("Unexpeced spend found in pool: %v", spend)
}
@@ -1518,16 +1511,21 @@ func TestCount(t *testing.T) {
if err != nil {
t.Errorf("ProcessTransaction: unexpected error: %v", err)
}
if harness.txPool.Count()+harness.txPool.ChainedCount() != i+1 {
t.Errorf("TestCount: txPool expected to have %v transactions but got %v", i+1, harness.txPool.Count())
}
}
// Mimic a situation where the first transaction is found in a block
fakeBlock := appmessage.NewMsgBlock(&appmessage.BlockHeader{})
fakeCoinbase := &appmessage.MsgTx{}
fakeBlock.AddTransaction(fakeCoinbase)
fakeBlock.AddTransaction(chainedTxns[0].MsgTx())
err = harness.txPool.removeBlockTransactionsFromPool(util.NewBlock(fakeBlock))
if err != nil {
t.Fatalf("removeBlockTransactionsFromPool: unexpected error: %v", err)
}
if harness.txPool.Count()+harness.txPool.ChainedCount() != 2 {
t.Errorf("TestCount: txPool expected to have 2 transactions but got %v", harness.txPool.Count())
}
}
@@ -1636,82 +1634,15 @@ func TestHandleNewBlock(t *testing.T) {
if err != nil {
t.Fatalf("unable to create transaction 1: %v", err)
}
block := blockdag.PrepareAndProcessBlockForTest(t, harness.txPool.cfg.DAG, harness.txPool.cfg.DAG.TipHashes(), []*appmessage.MsgTx{blockTx1.MsgTx(), blockTx2.MsgTx()})
// Handle new block by pool
_, err = harness.txPool.HandleNewBlock(util.NewBlock(block))
// ensure that orphan transaction moved to main pool
testPoolMembership(tc, orphanTx, false, true, false)
}
func TestTransactionGas(t *testing.T) {
params := dagconfig.SimnetParams
params.BlockCoinbaseMaturity = 0

View File

@@ -0,0 +1,115 @@
package mempool
import (
"github.com/kaspanet/kaspad/app/appmessage"
"github.com/kaspanet/kaspad/domain/blockdag"
"github.com/kaspanet/kaspad/util"
"github.com/kaspanet/kaspad/util/daghash"
"github.com/pkg/errors"
)
func newMempoolUTXOSet(dag *blockdag.BlockDAG) *mempoolUTXOSet {
return &mempoolUTXOSet{
transactionByPreviousOutpoint: make(map[appmessage.Outpoint]*util.Tx),
poolUnspentOutputs: make(map[appmessage.Outpoint]*blockdag.UTXOEntry),
dag: dag,
}
}
type mempoolUTXOSet struct {
transactionByPreviousOutpoint map[appmessage.Outpoint]*util.Tx
poolUnspentOutputs map[appmessage.Outpoint]*blockdag.UTXOEntry
dag *blockdag.BlockDAG
}
func (mpus *mempoolUTXOSet) utxoEntryByOutpoint(outpoint appmessage.Outpoint) (entry *blockdag.UTXOEntry, isInPool bool, exists bool) {
entry, exists = mpus.dag.GetUTXOEntry(outpoint)
if !exists {
entry, exists := mpus.poolUnspentOutputs[outpoint]
if !exists {
return nil, false, false
}
return entry, true, true
}
return entry, false, true
}
// addTx adds a transaction to the mempool UTXO set. It assumes that the transaction
// doesn't double spend another transaction in the mempool, and that its outputs don't
// already exist in the mempool UTXO set; it returns an error otherwise.
func (mpus *mempoolUTXOSet) addTx(tx *util.Tx) error {
msgTx := tx.MsgTx()
for _, txIn := range msgTx.TxIn {
if existingTx, exists := mpus.transactionByPreviousOutpoint[txIn.PreviousOutpoint]; exists {
return errors.Errorf("outpoint %s is already used by %s", txIn.PreviousOutpoint, existingTx.ID())
}
mpus.transactionByPreviousOutpoint[txIn.PreviousOutpoint] = tx
}
for i, txOut := range msgTx.TxOut {
outpoint := appmessage.NewOutpoint(tx.ID(), uint32(i))
if _, exists := mpus.poolUnspentOutputs[*outpoint]; exists {
return errors.Errorf("outpoint %s already exists", outpoint)
}
mpus.poolUnspentOutputs[*outpoint] = blockdag.NewUTXOEntry(txOut, false, blockdag.UnacceptedBlueScore)
}
return nil
}
// removeTx removes a transaction from the mempool UTXO set.
// Note: it doesn't re-add its previous outputs to the mempool UTXO set.
func (mpus *mempoolUTXOSet) removeTx(tx *util.Tx) error {
msgTx := tx.MsgTx()
for _, txIn := range msgTx.TxIn {
if _, exists := mpus.transactionByPreviousOutpoint[txIn.PreviousOutpoint]; !exists {
return errors.Errorf("outpoint %s doesn't exist", txIn.PreviousOutpoint)
}
delete(mpus.transactionByPreviousOutpoint, txIn.PreviousOutpoint)
}
for i := range msgTx.TxOut {
outpoint := appmessage.NewOutpoint(tx.ID(), uint32(i))
if _, exists := mpus.poolUnspentOutputs[*outpoint]; !exists {
return errors.Errorf("outpoint %s doesn't exist", outpoint)
}
delete(mpus.poolUnspentOutputs, *outpoint)
}
return nil
}
func (mpus *mempoolUTXOSet) poolTransactionBySpendingOutpoint(outpoint appmessage.Outpoint) (*util.Tx, bool) {
tx, exists := mpus.transactionByPreviousOutpoint[outpoint]
return tx, exists
}
func (mpus *mempoolUTXOSet) transactionRelatedUTXOEntries(tx *util.Tx) (spentUTXOEntries []*blockdag.UTXOEntry, parentsInPool []*appmessage.Outpoint, missingParents []*daghash.TxID) {
msgTx := tx.MsgTx()
spentUTXOEntries = make([]*blockdag.UTXOEntry, len(msgTx.TxIn))
missingParents = make([]*daghash.TxID, 0)
parentsInPool = make([]*appmessage.Outpoint, 0)
isOrphan := false
for i, txIn := range msgTx.TxIn {
entry, isInPool, exists := mpus.utxoEntryByOutpoint(txIn.PreviousOutpoint)
if !exists {
isOrphan = true
missingParents = append(missingParents, &txIn.PreviousOutpoint.TxID)
}
if isOrphan {
continue
}
if isInPool {
parentsInPool = append(parentsInPool, &txIn.PreviousOutpoint)
}
spentUTXOEntries[i] = entry
}
if isOrphan {
return nil, nil, missingParents
}
return spentUTXOEntries, parentsInPool, nil
}
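To make the contract of this new type concrete, here is a small exercise of it — a sketch that would sit in this same package alongside the tests; how the *blockdag.BlockDAG is obtained is assumed (for example, from the existing test harness), and it reuses the github.com/pkg/errors import already present above.

func exerciseMempoolUTXOSet(dag *blockdag.BlockDAG) error {
    mpus := newMempoolUTXOSet(dag)

    // A transaction with no inputs and a single output, built the same way
    // the tests above build theirs.
    tx := util.NewTx(appmessage.NewNativeMsgTx(1, nil,
        []*appmessage.TxOut{{Value: 100, ScriptPubKey: nil}}))

    // addTx registers tx's output as spendable from the pool's point of view.
    if err := mpus.addTx(tx); err != nil {
        return err
    }

    // The output is now visible through utxoEntryByOutpoint and flagged as in-pool.
    outpoint := *appmessage.NewOutpoint(tx.ID(), 0)
    if _, isInPool, exists := mpus.utxoEntryByOutpoint(outpoint); !exists || !isInPool {
        return errors.Errorf("expected %s to be tracked by the pool", outpoint)
    }

    // removeTx drops the output again; note that it does not restore
    // anything tx spent.
    return mpus.removeTx(tx)
}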

View File

@@ -80,7 +80,7 @@ func calcMinRequiredTxRelayFee(serializedSize int64, minRelayTxFee util.Amount)
// context of this function is one whose referenced public key script is of a
// standard form and, for pay-to-script-hash, does not have more than
// maxStandardP2SHSigOps signature operations.
func checkInputsStandard(tx *util.Tx, referencedUTXOEntries []*blockdag.UTXOEntry) error {
// NOTE: The reference implementation also does a coinbase check here,
// but coinbases have already been rejected prior to calling this
// function so no need to recheck.
@@ -89,7 +89,7 @@ func checkInputsStandard(tx *util.Tx, utxoSet blockdag.UTXOSet) error {
// It is safe to elide existence and index checks here since
// they have already been checked prior to calling this
// function.
entry := referencedUTXOEntries[i]
originScriptPubKey := entry.ScriptPubKey()
switch txscript.GetScriptClass(originScriptPubKey) {
case txscript.ScriptHashTy:

View File

@@ -168,10 +168,7 @@ func TestDust(t *testing.T) {
// TestCheckTransactionStandard tests the checkTransactionStandard API.
func TestCheckTransactionStandard(t *testing.T) {
// Create some dummy, but otherwise standard, data for transactions.
prevOutTxID := &daghash.TxID{}
dummyPrevOut := appmessage.Outpoint{TxID: *prevOutTxID, Index: 1}
dummySigScript := bytes.Repeat([]byte{0x00}, 65)
dummyTxIn := appmessage.TxIn{

View File

@@ -79,9 +79,6 @@ type BlockTemplate struct {
// coinbase, the first entry (offset 0) will contain the negative of the
// sum of the fees of all other transactions.
Fees []uint64
}
// BlkTmplGenerator provides a type that can be used to generate block templates
@@ -176,10 +173,17 @@ func NewBlkTmplGenerator(policy *Policy,
// | <= policy.BlockMinSize) | |
// ----------------------------------- --
func (g *BlkTmplGenerator) NewBlockTemplate(payToAddress util.Address, extraNonce uint64) (*BlockTemplate, error) {
mempoolTransactions := g.txSource.MiningDescs()
// The DAG lock is acquired only after calling MiningDescs() to avoid a potential
// deadlock: MiningDescs() takes the TxPool's read lock, while TxPool.ProcessTransaction
// takes the DAG's read lock, so taking the DAG lock before MiningDescs() returns
// could deadlock.
g.dag.Lock()
defer g.dag.Unlock()
txsForBlockTemplate, err := g.selectTxs(mempoolTransactions, payToAddress, extraNonce)
if err != nil {
return nil, errors.Errorf("failed to select transactions: %s", err)
}
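The constraint in the deadlock note above generalizes to any pair of locks taken in opposite orders. A self-contained model (not kaspad code) of why the snapshot-then-lock order is safe:

package main

import "sync"

type node struct {
    dagLock     sync.RWMutex
    mempoolLock sync.RWMutex
    mempoolTxs  []string
}

// miningDescs stands in for txSource.MiningDescs(): it needs only the mempool lock.
func (n *node) miningDescs() []string {
    n.mempoolLock.RLock()
    defer n.mempoolLock.RUnlock()
    return append([]string(nil), n.mempoolTxs...)
}

// newBlockTemplate snapshots the mempool first and only then takes the DAG lock,
// so it never holds dagLock while waiting for mempoolLock.
func (n *node) newBlockTemplate() []string {
    txs := n.miningDescs()
    n.dagLock.Lock()
    defer n.dagLock.Unlock()
    return txs // transaction selection would happen here, under the DAG lock
}

// processTransaction takes the DAG lock first and the mempool lock second,
// which is why the reverse order in newBlockTemplate would risk a deadlock.
func (n *node) processTransaction(tx string) {
    n.dagLock.RLock()
    defer n.dagLock.RUnlock()
    n.mempoolLock.Lock()
    defer n.mempoolLock.Unlock()
    n.mempoolTxs = append(n.mempoolTxs, tx)
}

func main() {
    n := &node{}
    n.processTransaction("tx1")
    _ = n.newBlockTemplate()
}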

View File

@@ -65,9 +65,8 @@ type txsForBlockTemplate struct {
// Once the sum of probabilities of marked transactions is greater than
// rebalanceThreshold percent of the sum of probabilities of all transactions,
// rebalance.
func (g *BlkTmplGenerator) selectTxs(mempoolTransactions []*TxDesc, payToAddress util.Address,
extraNonce uint64) (*txsForBlockTemplate, error) {
// Create a new txsForBlockTemplate struct, onto which all selectedTxs
// will be appended.
@@ -78,7 +77,7 @@ func (g *BlkTmplGenerator) selectTxs(payToAddress util.Address, extraNonce uint6
// Collect candidateTxs while excluding txs that will certainly not
// be selected.
candidateTxs := g.collectCandidatesTxs(mempoolTransactions)
log.Debugf("Considering %d transactions for inclusion to new block",
len(candidateTxs))

View File

@@ -51,10 +51,11 @@ const (
defaultMinRelayTxFee = 1e-5 // 1 sompi per byte
defaultMaxOrphanTransactions = 100
// DefaultMaxOrphanTxSize is the default maximum size for an orphan transaction
DefaultMaxOrphanTxSize = 100000
defaultSigCacheMaxSize = 100000
sampleConfigFilename = "sample-kaspad.conf"
defaultAcceptanceIndex = false
defaultMaxUTXOCacheSize = 5000000000
)
var (
@@ -106,7 +107,6 @@ type Flags struct {
ProxyPass string `long:"proxypass" default-mask:"-" description:"Password for proxy server"`
DbType string `long:"dbtype" description:"Database backend to use for the Block DAG"`
Profile string `long:"profile" description:"Enable HTTP profiling on given port -- NOTE port must be between 1024 and 65536"`
CPUProfile string `long:"cpuprofile" description:"Write CPU profile to the specified file"`
DebugLevel string `short:"d" long:"debuglevel" description:"Logging level for all subsystems {trace, debug, info, warn, error, critical} -- You may also specify <subsystem>=<level>,<subsystem2>=<level>,... to set the log level for individual subsystems -- Use show to list available subsystems"`
Upnp bool `long:"upnp" description:"Use UPnP to map our listening port outside of NAT"`
MinRelayTxFee float64 `long:"minrelaytxfee" description:"The minimum transaction fee in KAS/kB to be considered a non-zero fee."`
@@ -121,7 +121,9 @@ type Flags struct {
RelayNonStd bool `long:"relaynonstd" description:"Relay non-standard transactions regardless of the default settings for the active network."`
RejectNonStd bool `long:"rejectnonstd" description:"Reject non-standard transactions regardless of the default settings for the active network."`
ResetDatabase bool `long:"reset-db" description:"Reset database before starting node. It's needed when switching between subnetworks."`
MaxUTXOCacheSize uint64 `long:"maxutxocachesize" description:"Max size of loaded UTXO into ram from the disk in bytes"`
NetworkFlags
ServiceOptions *ServiceOptions
}
// Config defines the configuration options for kaspad.
@@ -137,9 +139,9 @@ type Config struct {
SubnetworkID *subnetworkid.SubnetworkID // nil in full nodes
}
// ServiceOptions defines the configuration options for the daemon as a service on
// Windows.
type ServiceOptions struct {
ServiceCommand string `short:"s" long:"service" description:"Service command {install, remove, start, stop}"`
}
@@ -158,10 +160,10 @@ func cleanAndExpandPath(path string) string {
}
// newConfigParser returns a new command line flags parser.
func newConfigParser(cfgFlags *Flags, options flags.Options) *flags.Parser {
parser := flags.NewParser(cfgFlags, options)
if runtime.GOOS == "windows" {
parser.AddGroup("Service Options", "Service Options", cfgFlags.ServiceOptions)
}
return parser
}
@@ -186,6 +188,8 @@ func defaultFlags() *Flags {
SigCacheMaxSize: defaultSigCacheMaxSize,
MinRelayTxFee: defaultMinRelayTxFee,
AcceptanceIndex: defaultAcceptanceIndex,
MaxUTXOCacheSize: defaultMaxUTXOCacheSize,
ServiceOptions: &ServiceOptions{},
}
}
@@ -208,24 +212,20 @@ func DefaultConfig() *Config {
// The above results in kaspad functioning properly without any config settings
// while still allowing the user to override settings with config files and
// command line options. Command line options always take precedence.
func LoadConfig() (*Config, error) {
cfgFlags := defaultFlags()
// Pre-parse the command line options to see if an alternative config
// file or the version flag was specified. Any errors aside from the
// help message error can be ignored here since they will be caught by
// the final parse below.
preCfg := cfgFlags
preParser := newConfigParser(preCfg, flags.HelpFlag)
_, err := preParser.Parse()
if err != nil {
var flagsErr *flags.Error
if ok := errors.As(err, &flagsErr); ok && flagsErr.Type == flags.ErrHelp {
fmt.Fprintln(os.Stderr, err)
return nil, err
}
}
@@ -239,21 +239,10 @@ func LoadConfig() (cfg *Config, remainingArgs []string, err error) {
os.Exit(0)
}
// Load additional config from file.
var configFileError error
parser := newConfigParser(cfgFlags, flags.Default)
cfg := &Config{
Flags: cfgFlags,
}
if !preCfg.Simnet || preCfg.ConfigFile !=
@@ -262,31 +251,27 @@ func LoadConfig() (cfg *Config, remainingArgs []string, err error) {
if _, err := os.Stat(preCfg.ConfigFile); os.IsNotExist(err) {
err := createDefaultConfigFile(preCfg.ConfigFile)
if err != nil {
return nil, errors.Wrap(err, "Error creating a default config file")
}
}
err := flags.NewIniParser(parser).ParseFile(preCfg.ConfigFile)
if err != nil {
if pErr := &(os.PathError{}); !errors.As(err, &pErr) {
return nil, errors.Wrapf(err, "Error parsing config file: %s\n\n%s", err, usageMessage)
}
configFileError = err
}
}
// Parse command line options again to ensure they take precedence.
_, err = parser.Parse()
if err != nil {
var flagsErr *flags.Error
if ok := errors.As(err, &flagsErr); !ok || flagsErr.Type != flags.ErrHelp {
return nil, errors.Wrapf(err, "Error parsing command line arguments: %s\n\n%s", err, usageMessage)
}
return nil, err
}
// Create the home directory if it doesn't already exist.
@@ -306,13 +291,12 @@ func LoadConfig() (cfg *Config, remainingArgs []string, err error) {
str := "%s: Failed to create home directory: %s"
err := errors.Errorf(str, funcName, err)
fmt.Fprintln(os.Stderr, err)
return nil, err
}
err = cfg.ResolveNetwork(parser)
if err != nil {
return nil, err
}
// Set the default policy for relaying non-standard transactions
@@ -327,7 +311,7 @@ func LoadConfig() (cfg *Config, remainingArgs []string, err error) {
err := errors.Errorf(str, funcName)
fmt.Fprintln(os.Stderr, err)
fmt.Fprintln(os.Stderr, usageMessage)
return nil, err
case cfg.RejectNonStd:
relayNonStd = false
case cfg.RelayNonStd:
@@ -364,7 +348,7 @@ func LoadConfig() (cfg *Config, remainingArgs []string, err error) {
err := errors.Errorf("%s: %s", funcName, err.Error())
fmt.Fprintln(os.Stderr, err)
fmt.Fprintln(os.Stderr, usageMessage)
return nil, err
}
// Validate profile port number
@@ -375,7 +359,7 @@ func LoadConfig() (cfg *Config, remainingArgs []string, err error) {
err := errors.Errorf(str, funcName)
fmt.Fprintln(os.Stderr, err)
fmt.Fprintln(os.Stderr, usageMessage)
return nil, err
}
}
@@ -385,7 +369,7 @@ func LoadConfig() (cfg *Config, remainingArgs []string, err error) {
err := errors.Errorf(str, funcName, cfg.BanDuration)
fmt.Fprintln(os.Stderr, err)
fmt.Fprintln(os.Stderr, usageMessage)
return nil, err
}
// Validate any given whitelisted IP addresses and networks.
@@ -402,7 +386,7 @@ func LoadConfig() (cfg *Config, remainingArgs []string, err error) {
err = errors.Errorf(str, funcName, addr)
fmt.Fprintln(os.Stderr, err)
fmt.Fprintln(os.Stderr, usageMessage)
return nil, err
}
var bits int
if ip.To4() == nil {
@@ -427,7 +411,7 @@ func LoadConfig() (cfg *Config, remainingArgs []string, err error) {
err := errors.Errorf(str, funcName)
fmt.Fprintln(os.Stderr, err)
fmt.Fprintln(os.Stderr, usageMessage)
return nil, err
}
// --proxy or --connect without --listen disables listening.
@@ -470,7 +454,7 @@ func LoadConfig() (cfg *Config, remainingArgs []string, err error) {
err := errors.Errorf(str, funcName, cfg.RPCMaxConcurrentReqs)
fmt.Fprintln(os.Stderr, err)
fmt.Fprintln(os.Stderr, usageMessage)
return nil, err
}
// Validate the minrelaytxfee.
@@ -480,7 +464,7 @@ func LoadConfig() (cfg *Config, remainingArgs []string, err error) {
err := errors.Errorf(str, funcName, err)
fmt.Fprintln(os.Stderr, err)
fmt.Fprintln(os.Stderr, usageMessage)
return nil, err
}
// Disallow 0 and negative min tx fees.
@@ -489,7 +473,7 @@ func LoadConfig() (cfg *Config, remainingArgs []string, err error) {
err := errors.Errorf(str, funcName, cfg.MinRelayTxFee)
fmt.Fprintln(os.Stderr, err)
fmt.Fprintln(os.Stderr, usageMessage)
return nil, err
}
// Limit the max block mass to a sane value.
@@ -502,7 +486,7 @@ func LoadConfig() (cfg *Config, remainingArgs []string, err error) {
blockMaxMassMax, cfg.BlockMaxMass)
fmt.Fprintln(os.Stderr, err)
fmt.Fprintln(os.Stderr, usageMessage)
return nil, err
}
// Limit the max orphan count to a sane value.
@@ -512,7 +496,7 @@ func LoadConfig() (cfg *Config, remainingArgs []string, err error) {
err := errors.Errorf(str, funcName, cfg.MaxOrphanTxs)
fmt.Fprintln(os.Stderr, err)
fmt.Fprintln(os.Stderr, usageMessage)
return nil, err
}
// Look for illegal characters in the user agent comments.
@@ -523,7 +507,7 @@ func LoadConfig() (cfg *Config, remainingArgs []string, err error) {
funcName)
fmt.Fprintln(os.Stderr, err)
fmt.Fprintln(os.Stderr, usageMessage)
return nil, err
}
}
@@ -534,7 +518,7 @@ func LoadConfig() (cfg *Config, remainingArgs []string, err error) {
funcName)
fmt.Fprintln(os.Stderr, err)
fmt.Fprintln(os.Stderr, usageMessage)
return nil, err
}
// Add default port to all listener addresses if needed and remove
@@ -542,7 +526,7 @@ func LoadConfig() (cfg *Config, remainingArgs []string, err error) {
cfg.Listeners, err = network.NormalizeAddresses(cfg.Listeners,
cfg.NetParams().DefaultPort)
if err != nil {
return nil, err
}
// Add default port to all rpc listener addresses if needed and remove
@@ -550,7 +534,7 @@ func LoadConfig() (cfg *Config, remainingArgs []string, err error) {
cfg.RPCListeners, err = network.NormalizeAddresses(cfg.RPCListeners,
cfg.NetParams().RPCPort)
if err != nil {
return nil, err
}
// Disallow --addpeer and --connect used together
@@ -559,7 +543,7 @@ func LoadConfig() (cfg *Config, remainingArgs []string, err error) {
err := errors.Errorf(str, funcName)
fmt.Fprintln(os.Stderr, err)
fmt.Fprintln(os.Stderr, usageMessage)
return nil, err
}
// Add default port to all added peer addresses if needed and remove
@@ -567,13 +551,13 @@ func LoadConfig() (cfg *Config, remainingArgs []string, err error) {
cfg.AddPeers, err = network.NormalizeAddresses(cfg.AddPeers,
cfg.NetParams().DefaultPort)
if err != nil {
return nil, err
}
cfg.ConnectPeers, err = network.NormalizeAddresses(cfg.ConnectPeers,
cfg.NetParams().DefaultPort)
if err != nil {
return nil, err
}
// Setup dial and DNS resolution (lookup) functions depending on the
@@ -590,7 +574,7 @@ func LoadConfig() (cfg *Config, remainingArgs []string, err error) {
err := errors.Errorf(str, funcName, cfg.Proxy, err)
fmt.Fprintln(os.Stderr, err)
fmt.Fprintln(os.Stderr, usageMessage)
return nil, err
}
proxy := &socks.Proxy{
@@ -608,7 +592,7 @@ func LoadConfig() (cfg *Config, remainingArgs []string, err error) {
log.Warnf("%s", configFileError)
}
return cfg, nil
}
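The signature change ripples out to call sites, which now receive two values instead of three. A hedged sketch of the new calling pattern (the import path and the caller shown here are assumptions for illustration):

package main

import (
    "fmt"
    "os"

    "github.com/kaspanet/kaspad/infrastructure/config"
)

func main() {
    cfg, err := config.LoadConfig()
    if err != nil {
        fmt.Fprintln(os.Stderr, err)
        os.Exit(1)
    }
    _ = cfg // hand cfg to the app layer from here
}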
// createDefaultConfigFile copies the file sample-kaspad.conf to the given destination path,

View File

@@ -2,10 +2,11 @@ package database_test
import (
"fmt"
"io/ioutil"
"testing"
"github.com/kaspanet/kaspad/infrastructure/db/database"
"github.com/kaspanet/kaspad/infrastructure/db/database/ldb"
)
type databasePrepareFunc func(t *testing.T, testName string) (db database.Database, name string, teardownFunc func())
@@ -14,17 +15,17 @@ type databasePrepareFunc func(t *testing.T, testName string) (db database.Databa
// prepares a separate database type for testing.
// See testForAllDatabaseTypes for further details.
var databasePrepareFuncs = []databasePrepareFunc{
prepareLDBForTest,
}
func prepareLDBForTest(t *testing.T, testName string) (db database.Database, name string, teardownFunc func()) {
// Create a temp db to run tests against
path, err := ioutil.TempDir("", testName)
if err != nil {
t.Fatalf("%s: TempDir unexpectedly "+
"failed: %s", testName, err)
}
db, err = ldb.NewLevelDB(path)
if err != nil {
t.Fatalf("%s: Open unexpectedly "+
"failed: %s", testName, err)
@@ -36,7 +37,7 @@ func prepareFFLDBForTest(t *testing.T, testName string) (db database.Database, n
"failed: %s", testName, err)
}
}
return db, "ldb", teardownFunc
}
// testForAllDatabaseTypes runs the given testFunc for every database

View File

@@ -19,18 +19,6 @@ type DataAccessor interface {
// return an error if the key doesn't exist.
Delete(key *Key) error
// Cursor begins a new cursor over the given bucket.
Cursor(bucket *Bucket) (Cursor, error)
}

View File

@@ -7,8 +7,9 @@ package database_test
import (
"bytes"
"testing"
"github.com/kaspanet/kaspad/infrastructure/db/database"
)
func TestDatabasePut(t *testing.T) {
@@ -166,42 +167,3 @@ func testDatabaseDelete(t *testing.T, db database.Database, testName string) {
"unexpectedly returned that the value exists", testName)
}
}

View File

@@ -1,231 +0,0 @@
package ff
import (
"container/list"
"encoding/binary"
"fmt"
"github.com/pkg/errors"
"hash/crc32"
"os"
"path/filepath"
"sync"
)
const (
// maxOpenFiles is the max number of open files to maintain in each store's
// cache. Note that this does not include the current/write file, so there
// will typically be one more than this value open.
maxOpenFiles = 25
)
var (
// maxFileSize is the maximum size for each file used to store data.
//
// NOTE: The current code uses uint32 for all offsets, so this value
// must be less than 2^32 (4 GiB).
// NOTE: This is a var rather than a const for testing purposes.
maxFileSize uint32 = 512 * 1024 * 1024 // 512 MiB
)
var (
// byteOrder is the preferred byte order used through the flat files.
// Sometimes big endian will be used to allow ordered byte sortable
// integer values.
byteOrder = binary.LittleEndian
// crc32ByteOrder is the byte order used for CRC-32 checksums.
crc32ByteOrder = binary.BigEndian
// crc32ChecksumLength is the length in bytes of a CRC-32 checksum.
crc32ChecksumLength = 4
// dataLengthLength is the length in bytes of the "data length" section
// of a serialized entry in a flat file store.
dataLengthLength = 4
// castagnoli houses the Castagnoli polynomial used for CRC-32 checksums.
castagnoli = crc32.MakeTable(crc32.Castagnoli)
)
// flatFileStore houses information used to handle reading and writing data
// into flat files with support for multiple concurrent readers.
type flatFileStore struct {
// basePath is the base path used for the flat files.
basePath string
// storeName is the name of this flat-file store.
storeName string
// The following fields are related to the flat files which hold the
// actual data. The number of open files is limited by maxOpenFiles.
//
// openFilesMutex protects concurrent access to the openFiles map. It
// is a RWMutex so multiple readers can simultaneously access open
// files.
//
// openFiles houses the open file handles for existing files which have
// been opened read-only along with an individual RWMutex. This scheme
// allows multiple concurrent readers to the same file while preventing
// the file from being closed out from under them.
//
// lruMutex protects concurrent access to the least recently used list
// and lookup map.
//
// openFilesLRU tracks how the open files are referenced by pushing the
// most recently used files to the front of the list thereby trickling
// the least recently used files to end of the list. When a file needs
// to be closed due to exceeding the max number of allowed open
// files, the one at the end of the list is closed.
//
// fileNumberToLRUElement is a mapping between a specific file number and
// the associated list element on the least recently used list.
//
// Thus, with the combination of these fields, the database supports
// concurrent non-blocking reads across multiple and individual files
// along with intelligently limiting the number of open file handles by
// closing the least recently used files as needed.
//
// NOTE: The locking order used throughout is well-defined and MUST be
// followed. Failure to do so could lead to deadlocks. In particular,
// the locking order is as follows:
// 1) openFilesMutex
// 2) lruMutex
// 3) writeCursor mutex
// 4) specific file mutexes
//
// None of the mutexes are required to be locked at the same time, and
// often aren't. However, if they are to be locked simultaneously, they
// MUST be locked in the order previously specified.
//
// Due to the high performance and multi-read concurrency requirements,
// write locks should only be held for the minimum time necessary.
openFilesMutex sync.RWMutex
openFiles map[uint32]*lockableFile
lruMutex sync.Mutex
openFilesLRU *list.List // Contains uint32 file numbers.
fileNumberToLRUElement map[uint32]*list.Element
// writeCursor houses the state for the current file and location that
// new data is written to.
writeCursor *writeCursor
// isClosed is true when the store is closed. Any operations on a closed
// store will fail.
isClosed bool
}
// writeCursor represents the current file and offset of the flat file on disk
// for performing all writes. It also contains a read-write mutex to support
// multiple concurrent readers which can reuse the file handle.
type writeCursor struct {
sync.RWMutex
// currentFile is the current file that will be appended to when writing
// new data.
currentFile *lockableFile
// currentFileNumber is the current file number and is used to allow
// readers to use the same open file handle.
currentFileNumber uint32
// currentOffset is the offset in the current file where the next new
// data will be written.
currentOffset uint32
}
// openFlatFileStore returns a new flat file store with the current file number
// and offset set and all fields initialized.
func openFlatFileStore(basePath string, storeName string) (*flatFileStore, error) {
// Look for the end of the latest file to determine what the write cursor
// position is from the viewpoint of the flat files on disk.
fileNumber, fileOffset, err := findCurrentLocation(basePath, storeName)
if err != nil {
return nil, err
}
store := &flatFileStore{
basePath: basePath,
storeName: storeName,
openFiles: make(map[uint32]*lockableFile),
openFilesLRU: list.New(),
fileNumberToLRUElement: make(map[uint32]*list.Element),
writeCursor: &writeCursor{
currentFile: &lockableFile{},
currentFileNumber: fileNumber,
currentOffset: fileOffset,
},
isClosed: false,
}
return store, nil
}
func (s *flatFileStore) Close() error {
if s.isClosed {
return errors.Errorf("cannot close a closed store %s",
s.storeName)
}
s.isClosed = true
// Close the write cursor. We lock the write cursor here
// to let any in-progress write finish.
s.writeCursor.Lock()
defer s.writeCursor.Unlock()
err := s.writeCursor.currentFile.Close()
if err != nil {
return err
}
// Close all open files
for _, openFile := range s.openFiles {
err := openFile.Close()
if err != nil {
return err
}
}
return nil
}
func (s *flatFileStore) currentLocation() *flatFileLocation {
return &flatFileLocation{
fileNumber: s.writeCursor.currentFileNumber,
fileOffset: s.writeCursor.currentOffset,
dataLength: 0,
}
}
// findCurrentLocation searches the database directory for all flat files for a given
// store to find the end of the most recent file. This position is considered
// the current write cursor.
func findCurrentLocation(dbPath string, storeName string) (fileNumber uint32, fileLength uint32, err error) {
currentFileNumber := uint32(0)
currentFileLength := uint32(0)
for {
currentFilePath := flatFilePath(dbPath, storeName, currentFileNumber)
stat, err := os.Stat(currentFilePath)
if err != nil {
if !os.IsNotExist(err) {
return 0, 0, errors.WithStack(err)
}
if currentFileNumber > 0 {
fileNumber = currentFileNumber - 1
}
fileLength = currentFileLength
break
}
currentFileLength = uint32(stat.Size())
currentFileNumber++
}
log.Tracef("Scan for store '%s' found latest file #%d with length %d",
storeName, fileNumber, fileLength)
return fileNumber, fileLength, nil
}
// flatFilePath returns the file path for the provided store's flat file number.
func flatFilePath(dbPath string, storeName string, fileNumber uint32) string {
// Choose 9 digits of precision for the filenames. 9 digits provide
// 10^9 files at 512 MiB each, a total of ~476.84 PiB.
fileName := fmt.Sprintf("%s-%09d.fdb", storeName, fileNumber)
return filepath.Join(dbPath, fileName)
}
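A quick self-contained check of the naming scheme and the capacity arithmetic in the comment above (illustration only): 10^9 files at 512 MiB each is 10^9 * 512 * 2^20 bytes, which is about 476.84 PiB.

package main

import "fmt"

func main() {
    // The naming scheme used by flatFilePath:
    fmt.Printf("%s-%09d.fdb\n", "blocks", 123456789) // blocks-123456789.fdb

    // The capacity claim: 10^9 files at 512 MiB each.
    totalPiB := 1e9 * float64(512<<20) / float64(1<<50)
    fmt.Printf("~%.2f PiB\n", totalPiB) // ~476.84 PiB
}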

View File

@@ -1,175 +0,0 @@
package ff
import (
"bytes"
"github.com/kaspanet/kaspad/infrastructure/db/database"
"io/ioutil"
"os"
"reflect"
"testing"
)
func prepareStoreForTest(t *testing.T, testName string) (store *flatFileStore, teardownFunc func()) {
// Create a temp db to run tests against
path, err := ioutil.TempDir("", testName)
if err != nil {
t.Fatalf("%s: TempDir unexpectedly "+
"failed: %s", testName, err)
}
name := "test"
store, err = openFlatFileStore(path, name)
if err != nil {
t.Fatalf("%s: openFlatFileStore "+
"unexpectedly failed: %s", testName, err)
}
teardownFunc = func() {
err = store.Close()
if err != nil {
t.Fatalf("%s: Close unexpectedly "+
"failed: %s", testName, err)
}
}
return store, teardownFunc
}
func TestFlatFileStoreSanity(t *testing.T) {
store, teardownFunc := prepareStoreForTest(t, "TestFlatFileStoreSanity")
defer teardownFunc()
// Write something to the store
writeData := []byte("Hello world!")
location, err := store.write(writeData)
if err != nil {
t.Fatalf("TestFlatFileStoreSanity: Write returned "+
"unexpected error: %s", err)
}
// Read from the location previously written to
readData, err := store.read(location)
if err != nil {
t.Fatalf("TestFlatFileStoreSanity: read returned "+
"unexpected error: %s", err)
}
// Make sure that the written data and the read data are equal
if !reflect.DeepEqual(readData, writeData) {
t.Fatalf("TestFlatFileStoreSanity: read data and "+
"write data are not equal. Wrote: %s, read: %s",
string(writeData), string(readData))
}
}
func TestFlatFilePath(t *testing.T) {
tests := []struct {
dbPath string
storeName string
fileNumber uint32
expectedPath string
}{
{
dbPath: "path",
storeName: "store",
fileNumber: 0,
expectedPath: "path/store-000000000.fdb",
},
{
dbPath: "path/to/database",
storeName: "blocks",
fileNumber: 123456789,
expectedPath: "path/to/database/blocks-123456789.fdb",
},
}
for _, test := range tests {
path := flatFilePath(test.dbPath, test.storeName, test.fileNumber)
if path != test.expectedPath {
t.Errorf("TestFlatFilePath: unexpected path. Want: %s, got: %s",
test.expectedPath, path)
}
}
}
func TestFlatFileMultiFileRollback(t *testing.T) {
store, teardownFunc := prepareStoreForTest(t, "TestFlatFileMultiFileRollback")
defer teardownFunc()
// Set the maxFileSize to 16 bytes so that we don't have to write
// an enormous amount of data to disk to get multiple files, all
// for the sake of this test.
currentMaxFileSize := maxFileSize
maxFileSize = 16
defer func() {
maxFileSize = currentMaxFileSize
}()
// Write five 8 byte chunks and keep the last location written to
var lastWriteLocation1 *flatFileLocation
for i := byte(0); i < 5; i++ {
writeData := []byte{i, i, i, i, i, i, i, i}
var err error
lastWriteLocation1, err = store.write(writeData)
if err != nil {
t.Fatalf("TestFlatFileMultiFileRollback: write returned "+
"unexpected error: %s", err)
}
}
// Grab the current location and the current file number
currentLocation := store.currentLocation()
fileNumberBeforeWriting := store.writeCursor.currentFileNumber
// Write (2 * maxFileSize) more 8 byte chunks and keep the last location written to
var lastWriteLocation2 *flatFileLocation
for i := byte(0); i < byte(2*maxFileSize); i++ {
writeData := []byte{0, 1, 2, 3, 4, 5, 6, 7}
var err error
lastWriteLocation2, err = store.write(writeData)
if err != nil {
t.Fatalf("TestFlatFileMultiFileRollback: write returned "+
"unexpected error: %s", err)
}
}
// Grab the file number again to later make sure its file no longer exists
fileNumberAfterWriting := store.writeCursor.currentFileNumber
// Rollback
err := store.rollback(currentLocation)
if err != nil {
t.Fatalf("TestFlatFileMultiFileRollback: rollback returned "+
"unexpected error: %s", err)
}
// Make sure that lastWriteLocation1 still exists
expectedData := []byte{4, 4, 4, 4, 4, 4, 4, 4}
data, err := store.read(lastWriteLocation1)
if err != nil {
t.Fatalf("TestFlatFileMultiFileRollback: read returned "+
"unexpected error: %s", err)
}
if !bytes.Equal(data, expectedData) {
t.Fatalf("TestFlatFileMultiFileRollback: read returned "+
"unexpected data. Want: %s, got: %s", string(expectedData),
string(data))
}
// Make sure that lastWriteLocation2 does NOT exist
_, err = store.read(lastWriteLocation2)
if err == nil {
t.Fatalf("TestFlatFileMultiFileRollback: read " +
"unexpectedly succeeded")
}
if !database.IsNotFoundError(err) {
t.Fatalf("TestFlatFileMultiFileRollback: read "+
"returned unexpected error: %s", err)
}
// Make sure that all the appropriate files have been deleted
for i := fileNumberAfterWriting; i > fileNumberBeforeWriting; i-- {
filePath := flatFilePath(store.basePath, store.storeName, i)
if _, err := os.Stat(filePath); err == nil || !os.IsNotExist(err) {
t.Fatalf("TestFlatFileMultiFileRollback: file "+
"unexpectedly still exists: %s", filePath)
}
}
}

View File

@@ -1,103 +0,0 @@
package ff
// FlatFileDB is a flat-file database. It supports opening
// multiple flat-file stores. See flatFileStore for further
// details.
type FlatFileDB struct {
path string
flatFileStores map[string]*flatFileStore
}
// NewFlatFileDB opens the flat-file database defined by
// the given path.
func NewFlatFileDB(path string) *FlatFileDB {
return &FlatFileDB{
path: path,
flatFileStores: make(map[string]*flatFileStore),
}
}
// Close closes the flat-file database.
func (ffdb *FlatFileDB) Close() error {
for _, store := range ffdb.flatFileStores {
err := store.Close()
if err != nil {
return err
}
}
return nil
}
// Write appends the specified data bytes to the specified store.
// It returns a serialized location handle that's meant to be
// stored and later used when querying the data that has just now
// been inserted.
// See flatFileStore.write() for further details.
func (ffdb *FlatFileDB) Write(storeName string, data []byte) ([]byte, error) {
store, err := ffdb.store(storeName)
if err != nil {
return nil, err
}
location, err := store.write(data)
if err != nil {
return nil, err
}
return serializeLocation(location), nil
}
// Read reads data from the specified flat file store at the
// location specified by the given serialized location handle.
// It returns ErrNotFound if the location does not exist.
// See flatFileStore.read() for further details.
func (ffdb *FlatFileDB) Read(storeName string, serializedLocation []byte) ([]byte, error) {
store, err := ffdb.store(storeName)
if err != nil {
return nil, err
}
location, err := deserializeLocation(serializedLocation)
if err != nil {
return nil, err
}
return store.read(location)
}
// CurrentLocation returns the serialized location handle to
// the current location within the flat file store defined by
// storeName. It is mainly to be used to roll back flat-file
// stores in case of data inconsistency.
func (ffdb *FlatFileDB) CurrentLocation(storeName string) ([]byte, error) {
store, err := ffdb.store(storeName)
if err != nil {
return nil, err
}
currentLocation := store.currentLocation()
return serializeLocation(currentLocation), nil
}
// Rollback truncates the flat-file store defined by the given
// storeName to the location defined by the given serialized
// location handle.
func (ffdb *FlatFileDB) Rollback(storeName string, serializedLocation []byte) error {
store, err := ffdb.store(storeName)
if err != nil {
return err
}
location, err := deserializeLocation(serializedLocation)
if err != nil {
return err
}
return store.rollback(location)
}
func (ffdb *FlatFileDB) store(storeName string) (*flatFileStore, error) {
store, ok := ffdb.flatFileStores[storeName]
if !ok {
var err error
store, err = openFlatFileStore(ffdb.path, storeName)
if err != nil {
return nil, err
}
ffdb.flatFileStores[storeName] = store
}
return store, nil
}
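// A minimal usage sketch of the API above (hypothetical path and store
// name, not part of the package): write data, keep the serialized location
// handle, read it back, and roll back to a previously captured location.
func flatFileDBUsageSketch() error {
ffdb := NewFlatFileDB("/tmp/example-flat-file-db")
defer ffdb.Close()
savedLocation, err := ffdb.CurrentLocation("blocks")
if err != nil {
return err
}
location, err := ffdb.Write("blocks", []byte("some block bytes"))
if err != nil {
return err
}
data, err := ffdb.Read("blocks", location)
if err != nil {
return err
}
_ = data // equal to the written bytes
// Undo the write by truncating the store back to the captured location.
return ffdb.Rollback("blocks", savedLocation)
}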

View File

@@ -1,44 +0,0 @@
package ff
import "github.com/pkg/errors"
// flatFileLocationSerializedSize is the size in bytes of a serialized flat
// file location. See serializeLocation for further details.
const flatFileLocationSerializedSize = 12
// flatFileLocation identifies a particular flat file location.
type flatFileLocation struct {
fileNumber uint32
fileOffset uint32
dataLength uint32
}
// serializeLocation returns the serialization of the passed flat file location
// of certain data. This is later used to retrieve said data.
// The serialized location format is:
//
// [0:4] File Number (4 bytes)
// [4:8] File offset (4 bytes)
// [8:12] Data length (4 bytes)
func serializeLocation(location *flatFileLocation) []byte {
var serializedLocation [flatFileLocationSerializedSize]byte
byteOrder.PutUint32(serializedLocation[0:4], location.fileNumber)
byteOrder.PutUint32(serializedLocation[4:8], location.fileOffset)
byteOrder.PutUint32(serializedLocation[8:12], location.dataLength)
return serializedLocation[:]
}
// deserializeLocation deserializes the passed serialized flat file location.
// See serializeLocation for further details.
func deserializeLocation(serializedLocation []byte) (*flatFileLocation, error) {
if len(serializedLocation) != flatFileLocationSerializedSize {
return nil, errors.Errorf("unexpected serializedLocation length: %d",
len(serializedLocation))
}
location := &flatFileLocation{
fileNumber: byteOrder.Uint32(serializedLocation[0:4]),
fileOffset: byteOrder.Uint32(serializedLocation[4:8]),
dataLength: byteOrder.Uint32(serializedLocation[8:12]),
}
return location, nil
}
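// A worked example as a minimal sketch (not part of the package), assuming
// byteOrder is little-endian as the serialization test below demonstrates:
// a location in file 1 at offset 2 holding 3 bytes round-trips through a
// 12-byte handle.
func locationRoundTripSketch() (*flatFileLocation, error) {
location := &flatFileLocation{fileNumber: 1, fileOffset: 2, dataLength: 3}
serialized := serializeLocation(location)
// serialized is now {1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
return deserializeLocation(serialized)
}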

View File

@@ -1,62 +0,0 @@
package ff
import (
"bytes"
"encoding/hex"
"reflect"
"strings"
"testing"
)
func TestFlatFileLocationSerialization(t *testing.T) {
location := &flatFileLocation{
fileNumber: 1,
fileOffset: 2,
dataLength: 3,
}
serializedLocation := serializeLocation(location)
expectedSerializedLocation := []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
if !bytes.Equal(serializedLocation, expectedSerializedLocation) {
t.Fatalf("TestFlatFileLocationSerialization: serializeLocation "+
"returned unexpected bytes. Want: %s, got: %s",
hex.EncodeToString(expectedSerializedLocation), hex.EncodeToString(serializedLocation))
}
deserializedLocation, err := deserializeLocation(serializedLocation)
if err != nil {
t.Fatalf("TestFlatFileLocationSerialization: deserializeLocation "+
"unexpectedly failed: %s", err)
}
if !reflect.DeepEqual(deserializedLocation, location) {
t.Fatalf("TestFlatFileLocationSerialization: original "+
"location and deserialized location aren't the same. Want: %v, "+
"got: %v", location, deserializedLocation)
}
}
func TestFlatFileLocationDeserializationErrors(t *testing.T) {
expectedError := "unexpected serializedLocation length"
tooShortSerializedLocation := []byte{0, 1, 2, 3, 4, 5}
_, err := deserializeLocation(tooShortSerializedLocation)
if err == nil {
t.Fatalf("TestFlatFileLocationSerialization: deserializeLocation " +
"unexpectedly succeeded")
}
if !strings.Contains(err.Error(), expectedError) {
t.Fatalf("TestFlatFileLocationSerialization: deserializeLocation "+
"returned unexpected error. Want: %s, got: %s", expectedError, err)
}
tooLongSerializedLocation := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}
_, err = deserializeLocation(tooLongSerializedLocation)
if err == nil {
t.Fatalf("TestFlatFileLocationSerialization: deserializeLocation " +
"unexpectedly succeeded")
}
if !strings.Contains(err.Error(), expectedError) {
t.Fatalf("TestFlatFileLocationSerialization: deserializeLocation "+
"returned unexpected error. Want: %s, got: %s", expectedError, err)
}
}

View File

@@ -1,44 +0,0 @@
package ff
import (
"github.com/pkg/errors"
"io"
"sync"
)
// lockableFile represents a flat file on disk that has been opened for either
// read or read/write access. It also contains a read-write mutex to support
// multiple concurrent readers.
type lockableFile struct {
sync.RWMutex
file
isClosed bool
}
// file is an interface which acts very similarly to a *os.File and is typically
// implemented by it. It exists so the test code can provide mock files for
// properly testing corruption and file system issues.
type file interface {
io.Closer
io.WriterAt
io.ReaderAt
Truncate(size int64) error
Sync() error
}
func (lf *lockableFile) Close() error {
if lf.isClosed {
return errors.Errorf("cannot close an already closed file")
}
lf.isClosed = true
lf.Lock()
defer lf.Unlock()
if lf.file == nil {
return nil
}
return errors.WithStack(lf.file.Close())
}
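// A minimal sketch (not part of the package) of the locking discipline this
// type enables: readers hold the read lock around ReadAt, so Close, which
// takes the write lock, can't pull the file out from under an active reader.
func readLockedSketch(lf *lockableFile, buffer []byte, offset int64) (int, error) {
lf.RLock()
defer lf.RUnlock()
return lf.file.ReadAt(buffer, offset)
}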

View File

@@ -1,5 +0,0 @@
package ff
import "github.com/kaspanet/kaspad/infrastructure/logger"
var log, _ = logger.Get(logger.SubsystemTags.KSDB)

View File

@@ -1,153 +0,0 @@
package ff
import (
"github.com/kaspanet/kaspad/infrastructure/db/database"
"github.com/pkg/errors"
"hash/crc32"
"os"
)
// read reads the specified flat file record and returns the data. It ensures
// the integrity of the data by comparing the calculated checksum against the
// one stored in the flat file. This function also automatically handles all
// file management such as opening and closing files as necessary to stay
// within the maximum allowed open files limit. It returns ErrNotFound if the
// location does not exist.
//
// Format: <data length><data><checksum>
func (s *flatFileStore) read(location *flatFileLocation) ([]byte, error) {
if s.isClosed {
return nil, errors.Errorf("cannot read from a closed store %s",
s.storeName)
}
// Return not-found if the location is greater than or equal to
// the current write cursor.
if s.writeCursor.currentFileNumber < location.fileNumber ||
(s.writeCursor.currentFileNumber == location.fileNumber && s.writeCursor.currentOffset <= location.fileOffset) {
return nil, database.ErrNotFound
}
// Get the referenced flat file.
flatFile, err := s.flatFile(location.fileNumber)
if err != nil {
return nil, err
}
flatFile.RLock()
defer flatFile.RUnlock()
data := make([]byte, location.dataLength)
n, err := flatFile.file.ReadAt(data, int64(location.fileOffset))
if err != nil {
return nil, errors.Wrapf(err, "failed to read data in store '%s' "+
"from file %d, offset %d", s.storeName, location.fileNumber,
location.fileOffset)
}
// Calculate the checksum of the read data and ensure it matches the
// serialized checksum.
serializedChecksum := crc32ByteOrder.Uint32(data[n-crc32ChecksumLength:])
calculatedChecksum := crc32.Checksum(data[:n-crc32ChecksumLength], castagnoli)
if serializedChecksum != calculatedChecksum {
return nil, errors.Errorf("data in store '%s' does not match "+
"checksum - got %x, want %x", s.storeName, calculatedChecksum,
serializedChecksum)
}
// The returned data excludes the length prefix and the checksum.
return data[dataLengthLength : n-crc32ChecksumLength], nil
}
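// A minimal, self-contained sketch (not part of the store; assumes the
// encoding/binary import in addition to the ones above) of the
// <data length><data><checksum> record layout that read handles. The 4-byte
// length prefix is assumed little-endian, matching the location
// serialization elsewhere in this package, and the trailing 4 bytes are the
// Castagnoli CRC-32 over the length prefix and the data, appended the same
// way write appends hasher.Sum(nil).
func buildRecordSketch(data []byte) []byte {
record := make([]byte, 4+len(data), 4+len(data)+4)
binary.LittleEndian.PutUint32(record[0:4], uint32(len(data)))
copy(record[4:], data)
hasher := crc32.New(crc32.MakeTable(crc32.Castagnoli))
_, _ = hasher.Write(record) // the checksum covers <data length><data>
return hasher.Sum(record)   // Sum appends the CRC-32 to the record
}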
// flatFile attempts to return an existing file handle for the passed flat file
// number if it is already open as well as marking it as most recently used. It
// will also open the file when it's not already open, subject to the rules
// described in openFile. Also handles closing files as needed to avoid going
// over the max allowed open files.
func (s *flatFileStore) flatFile(fileNumber uint32) (*lockableFile, error) {
// When the requested flat file is open for writes, return it.
s.writeCursor.RLock()
defer s.writeCursor.RUnlock()
if fileNumber == s.writeCursor.currentFileNumber && s.writeCursor.currentFile.file != nil {
openFile := s.writeCursor.currentFile
return openFile, nil
}
// Try to return an open file under the overall files read lock.
s.openFilesMutex.RLock()
if openFile, ok := s.openFiles[fileNumber]; ok {
s.lruMutex.Lock()
s.openFilesLRU.MoveToFront(s.fileNumberToLRUElement[fileNumber])
s.lruMutex.Unlock()
s.openFilesMutex.RUnlock()
return openFile, nil
}
s.openFilesMutex.RUnlock()
// Since the file isn't open already, check the open files map again under
// the write lock in case multiple readers got here and a separate one has
// already opened the file. Note that openFile MUST be called with this
// write lock held.
s.openFilesMutex.Lock()
defer s.openFilesMutex.Unlock()
if openFlatFile, ok := s.openFiles[fileNumber]; ok {
return openFlatFile, nil
}
// The file isn't open, so open it while potentially closing the least
// recently used one as needed.
openFile, err := s.openFile(fileNumber)
if err != nil {
return nil, err
}
return openFile, nil
}
// openFile returns a read-only file handle for the passed flat file number.
// The function also keeps track of the open files, performs least recently
// used tracking, and limits the number of open files to maxOpenFiles by closing
// the least recently used file as needed.
//
// This function MUST be called with the open files mutex (s.openFilesMutex)
// locked for WRITES.
func (s *flatFileStore) openFile(fileNumber uint32) (*lockableFile, error) {
// Open the appropriate file as read-only.
filePath := flatFilePath(s.basePath, s.storeName, fileNumber)
file, err := os.Open(filePath)
if err != nil {
return nil, errors.WithStack(err)
}
flatFile := &lockableFile{file: file}
// Close the least recently used file if opening this one brings the
// number of open files above the maximum allowed. This is not done until
// after the file has been opened so that, if the open fails, no files
// need to be closed.
//
// A write lock is required on the LRU list here to protect against
// modifications happening as already open files are read from and
// shuffled to the front of the list.
//
// Also, add the file that was just opened to the front of the least
// recently used list to indicate it is the most recently used file and
// therefore should be closed last.
s.lruMutex.Lock()
defer s.lruMutex.Unlock()
lruList := s.openFilesLRU
if lruList.Len() >= maxOpenFiles {
lruFileNumber := lruList.Remove(lruList.Back()).(uint32)
oldFile := s.openFiles[lruFileNumber]
// Close the old file under the write lock for the file in case
// any readers are currently reading from it so it's not closed
// out from under them.
oldFile.Lock()
defer oldFile.Unlock()
_ = oldFile.file.Close()
delete(s.openFiles, lruFileNumber)
delete(s.fileNumberToLRUElement, lruFileNumber)
}
s.fileNumberToLRUElement[fileNumber] = lruList.PushFront(fileNumber)
// Store a reference to it in the open files map.
s.openFiles[fileNumber] = flatFile
return flatFile, nil
}
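// A minimal, self-contained sketch (not part of the store; assumes the
// container/list import) of the LRU bookkeeping used above: recently used
// file numbers move to the front of the list and eviction pops the back.
func lruSketch() uint32 {
lru := list.New()
elements := make(map[uint32]*list.Element)
touch := func(fileNumber uint32) {
if element, ok := elements[fileNumber]; ok {
lru.MoveToFront(element)
return
}
elements[fileNumber] = lru.PushFront(fileNumber)
}
touch(0)
touch(1)
touch(0)                                   // file 0 becomes the most recently used
evicted := lru.Remove(lru.Back()).(uint32) // evicts file 1
delete(elements, evicted)
return evicted
}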

View File

@@ -1,135 +0,0 @@
package ff
import (
"github.com/pkg/errors"
"os"
)
// rollback rolls the flat files on disk back to the provided file number
// and offset. This involves potentially deleting and truncating the files that
// were partially written.
//
// There are effectively two scenarios to consider here:
// 1) Transient write failures from which recovery is possible
// 2) More permanent failures such as hard disk death and/or removal
//
// In either case, the write cursor will be repositioned to the old flat file
// offset regardless of any other errors that occur while attempting to undo
// writes.
//
// For the first scenario, this will lead to any data which failed to be undone
// being overwritten and thus behaves as desired as the system continues to run.
//
// For the second scenario, the metadata which stores the current write cursor
// position within the flat files will not have been updated yet and thus if
// the system eventually recovers (perhaps the hard drive is reconnected), it
// will also lead to any data which failed to be undone being overwritten and
// thus behaves as desired.
func (s *flatFileStore) rollback(targetLocation *flatFileLocation) error {
if s.isClosed {
return errors.Errorf("cannot rollback a closed store %s",
s.storeName)
}
// Grab the write cursor mutex since it is modified throughout this
// function.
s.writeCursor.Lock()
defer s.writeCursor.Unlock()
// Nothing to do if the rollback point is the same as the current write
// cursor.
targetFileNumber := targetLocation.fileNumber
targetFileOffset := targetLocation.fileOffset
if s.writeCursor.currentFileNumber == targetFileNumber && s.writeCursor.currentOffset == targetFileOffset {
return nil
}
// If the rollback point is greater than the current write cursor then
// something has gone very wrong, e.g. database corruption.
if s.writeCursor.currentFileNumber < targetFileNumber ||
(s.writeCursor.currentFileNumber == targetFileNumber && s.writeCursor.currentOffset < targetFileOffset) {
return errors.Errorf("targetLocation is greater than the " +
"current write cursor")
}
// Regardless of any failures that happen below, reposition the write
// cursor to the target flat file and offset.
defer func() {
s.writeCursor.currentFileNumber = targetFileNumber
s.writeCursor.currentOffset = targetFileOffset
}()
log.Warnf("ROLLBACK: Rolling back to file %d, offset %d",
targetFileNumber, targetFileOffset)
// Close the current write file if it needs to be deleted.
if s.writeCursor.currentFileNumber > targetFileNumber {
s.closeCurrentWriteCursorFile()
}
// Delete all files that are newer than the provided rollback file
// while also moving the write cursor file backwards accordingly.
s.lruMutex.Lock()
defer s.lruMutex.Unlock()
s.openFilesMutex.Lock()
defer s.openFilesMutex.Unlock()
for s.writeCursor.currentFileNumber > targetFileNumber {
err := s.deleteFile(s.writeCursor.currentFileNumber)
if err != nil {
return errors.Wrapf(err, "ROLLBACK: Failed to delete file "+
"number %d in store '%s'", s.writeCursor.currentFileNumber,
s.storeName)
}
s.writeCursor.currentFileNumber--
}
// Open the file for the current write cursor if needed.
s.writeCursor.currentFile.Lock()
defer s.writeCursor.currentFile.Unlock()
if s.writeCursor.currentFile.file == nil {
openFile, err := s.openWriteFile(s.writeCursor.currentFileNumber)
if err != nil {
return err
}
s.writeCursor.currentFile.file = openFile
}
// Truncate the file to the provided target offset.
err := s.writeCursor.currentFile.file.Truncate(int64(targetFileOffset))
if err != nil {
return errors.Wrapf(err, "ROLLBACK: Failed to truncate file %d "+
"in store '%s'", s.writeCursor.currentFileNumber, s.storeName)
}
// Sync the file to disk.
err = s.writeCursor.currentFile.file.Sync()
if err != nil {
return errors.Wrapf(err, "ROLLBACK: Failed to sync file %d in "+
"store '%s'", s.writeCursor.currentFileNumber, s.storeName)
}
return nil
}
// deleteFile removes the file for the passed flat file number.
// This function MUST be called with the lruMutex and the openFilesMutex
// held for writes.
func (s *flatFileStore) deleteFile(fileNumber uint32) error {
// Clean up the file before deleting it
if file, ok := s.openFiles[fileNumber]; ok {
file.Lock()
defer file.Unlock()
err := file.Close()
if err != nil {
return err
}
lruElement := s.fileNumberToLRUElement[fileNumber]
s.openFilesLRU.Remove(lruElement)
delete(s.openFiles, fileNumber)
delete(s.fileNumberToLRUElement, fileNumber)
}
// Delete the file from disk
filePath := flatFilePath(s.basePath, s.storeName, fileNumber)
return errors.WithStack(os.Remove(filePath))
}
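// A minimal sketch (not part of the package) of the save/restore pattern
// that rollback enables, mirroring how the ffldb driver below uses it:
// capture the cursor before writing, then roll back to it on failure.
func writeOrRollbackSketch(store *flatFileStore, data []byte) (*flatFileLocation, error) {
savedLocation := store.currentLocation()
location, err := store.write(data)
if err != nil {
if rollbackErr := store.rollback(savedLocation); rollbackErr != nil {
return nil, rollbackErr
}
return nil, err
}
return location, nil
}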

View File

@@ -1,176 +0,0 @@
package ff
import (
"github.com/kaspanet/kaspad/util/panics"
"github.com/pkg/errors"
"hash/crc32"
"os"
"syscall"
)
// write appends the specified data bytes to the store's write cursor location
// and increments it accordingly. When the data would exceed the max file size
// for the current flat file, this function will close the current file, create
// the next file, update the write cursor, and write the data to the new file.
//
// The write cursor will also be advanced the number of bytes actually written
// in the event of failure.
//
// Format: <data length><data><checksum>
func (s *flatFileStore) write(data []byte) (*flatFileLocation, error) {
if s.isClosed {
return nil, errors.Errorf("cannot write to a closed store %s",
s.storeName)
}
// Compute how many bytes will be written.
// 4 bytes for data length + length of the data + 4 bytes for checksum.
dataLength := uint32(len(data))
fullLength := uint32(dataLengthLength) + dataLength + uint32(crc32ChecksumLength)
// Move to the next file if adding the new data would exceed the max
// allowed size for the current flat file. Also detect overflow; even
// though it isn't currently possible, the constants might change in
// the future to make it possible.
//
// NOTE: The writeCursor.currentOffset field isn't protected by the
// mutex since it's only read/changed during this function which can
// only be called during a write transaction, of which there can be
// only one at a time.
cursor := s.writeCursor
finalOffset := cursor.currentOffset + fullLength
if finalOffset < cursor.currentOffset || finalOffset > maxFileSize {
// This is done under the write cursor lock since the currentFileNumber
// field is accessed elsewhere by readers.
//
// Close the current write file to force a read-only reopen
// with LRU tracking. The close is done under the write lock
// for the file to prevent it from being closed out from under
// any readers currently reading from it.
func() {
cursor.Lock()
defer cursor.Unlock()
s.closeCurrentWriteCursorFile()
// Start writes into next file.
cursor.currentFileNumber++
cursor.currentOffset = 0
}()
}
// All writes are done under the write lock for the file to ensure any
// readers are finished and blocked first.
cursor.currentFile.Lock()
defer cursor.currentFile.Unlock()
// Open the current file if needed. This will typically only be the
// case when moving to the next file to write to or on initial database
// load. However, it might also be the case if rollbacks happened after
// file writes started during a transaction commit.
if cursor.currentFile.file == nil {
file, err := s.openWriteFile(cursor.currentFileNumber)
if err != nil {
return nil, err
}
cursor.currentFile.file = file
}
originalOffset := cursor.currentOffset
hasher := crc32.New(castagnoli)
var scratch [4]byte
// Data length.
byteOrder.PutUint32(scratch[:], dataLength)
err := s.writeData(scratch[:], "data length")
if err != nil {
return nil, err
}
_, _ = hasher.Write(scratch[:])
// Data.
err = s.writeData(data, "data")
if err != nil {
return nil, err
}
_, _ = hasher.Write(data)
// Castagnoli CRC-32 as a checksum of all the previous.
err = s.writeData(hasher.Sum(nil), "checksum")
if err != nil {
return nil, err
}
// Sync the file to disk.
err = cursor.currentFile.file.Sync()
if err != nil {
return nil, errors.Wrapf(err, "failed to sync file %d "+
"in store '%s'", cursor.currentFileNumber, s.storeName)
}
location := &flatFileLocation{
fileNumber: cursor.currentFileNumber,
fileOffset: originalOffset,
dataLength: fullLength,
}
return location, nil
}
// openWriteFile returns a file handle for the passed flat file number in
// read/write mode. The file will be created if needed. It is typically used
// for the current file that will have all new data appended. Unlike openFile,
// this function does not keep track of the open file and it is not subject to
// the maxOpenFiles limit.
func (s *flatFileStore) openWriteFile(fileNumber uint32) (file, error) {
// The current flat file needs to be read-write so it is possible to
// append to it. Also, it shouldn't be subject to the least-recently-used
// file tracking.
filePath := flatFilePath(s.basePath, s.storeName, fileNumber)
file, err := os.OpenFile(filePath, os.O_RDWR|os.O_CREATE, 0666)
if err != nil {
return nil, errors.Wrapf(err, "failed to open file %q",
filePath)
}
return file, nil
}
// writeData is a helper function for write which writes the provided data at
// the current write offset and updates the write cursor accordingly. The field
// name parameter is only used when there is an error to provide a nicer error
// message.
//
// The write cursor will be advanced the number of bytes actually written in the
// event of failure.
//
// NOTE: This function MUST be called with the write cursor current file lock
// held and must only be called during a write transaction so it is effectively
// locked for writes. Also, the write cursor current file must NOT be nil.
func (s *flatFileStore) writeData(data []byte, fieldName string) error {
cursor := s.writeCursor
n, err := cursor.currentFile.file.WriteAt(data, int64(cursor.currentOffset))
cursor.currentOffset += uint32(n)
if err != nil {
var pathErr *os.PathError
if ok := errors.As(err, &pathErr); ok && pathErr.Err == syscall.ENOSPC {
panics.Exit(log, "No space left on the hard disk.")
}
return errors.Wrapf(err, "failed to write %s in store %s to file %d "+
"at offset %d", fieldName, s.storeName, cursor.currentFileNumber,
cursor.currentOffset-uint32(n))
}
return nil
}
// closeCurrentWriteCursorFile closes the currently open writeCursor file if
// it's open.
// This method MUST be called with the writeCursor lock held for writes.
func (s *flatFileStore) closeCurrentWriteCursorFile() {
s.writeCursor.currentFile.Lock()
defer s.writeCursor.currentFile.Unlock()
if s.writeCursor.currentFile.file != nil {
_ = s.writeCursor.currentFile.file.Close()
s.writeCursor.currentFile.file = nil
}
}

View File

@@ -1,178 +0,0 @@
package ffldb
import (
"github.com/kaspanet/kaspad/infrastructure/db/database"
"github.com/kaspanet/kaspad/infrastructure/db/database/ffldb/ff"
"github.com/kaspanet/kaspad/infrastructure/db/database/ffldb/ldb"
"github.com/pkg/errors"
)
var (
// flatFilesBucket keeps an index of flat-file stores and their
// current locations. Among other things, it is used to repair
// the database in case a corruption occurs.
flatFilesBucket = database.MakeBucket([]byte("flat-files"))
)
// ffldb is a database utilizing LevelDB for key-value data and
// flat-files for raw data storage.
type ffldb struct {
flatFileDB *ff.FlatFileDB
levelDB *ldb.LevelDB
}
// Open opens a new ffldb with the given path.
func Open(path string) (database.Database, error) {
flatFileDB := ff.NewFlatFileDB(path)
levelDB, err := ldb.NewLevelDB(path)
if err != nil {
return nil, err
}
db := &ffldb{
flatFileDB: flatFileDB,
levelDB: levelDB,
}
err = db.initialize()
if err != nil {
return nil, err
}
return db, nil
}
// Close closes the database.
// This method is part of the Database interface.
func (db *ffldb) Close() error {
err := db.flatFileDB.Close()
if err != nil {
ldbCloseErr := db.levelDB.Close()
if ldbCloseErr != nil {
return errors.Wrapf(err, "err occurred during leveldb close: %s", ldbCloseErr)
}
return err
}
return db.levelDB.Close()
}
// Put sets the value for the given key. It overwrites
// any previous value for that key.
// This method is part of the DataAccessor interface.
func (db *ffldb) Put(key *database.Key, value []byte) error {
return db.levelDB.Put(key, value)
}
// Get gets the value for the given key. It returns
// ErrNotFound if the given key does not exist.
// This method is part of the DataAccessor interface.
func (db *ffldb) Get(key *database.Key) ([]byte, error) {
return db.levelDB.Get(key)
}
// Has returns true if the database contains the
// given key.
// This method is part of the DataAccessor interface.
func (db *ffldb) Has(key *database.Key) (bool, error) {
return db.levelDB.Has(key)
}
// Delete deletes the value for the given key. Will not
// return an error if the key doesn't exist.
// This method is part of the DataAccessor interface.
func (db *ffldb) Delete(key *database.Key) error {
return db.levelDB.Delete(key)
}
// AppendToStore appends the given data to the flat
// file store defined by storeName. This function
// returns a serialized location handle that's meant
// to be stored and later used when querying the data
// that has just now been inserted.
// This method is part of the DataAccessor interface.
func (db *ffldb) AppendToStore(storeName string, data []byte) ([]byte, error) {
return appendToStore(db, db.flatFileDB, storeName, data)
}
func appendToStore(accessor database.DataAccessor, ffdb *ff.FlatFileDB, storeName string, data []byte) ([]byte, error) {
// Save a reference to the current location in case
// we fail and need to rollback.
previousLocation, err := ffdb.CurrentLocation(storeName)
if err != nil {
return nil, err
}
rollback := func() error {
return ffdb.Rollback(storeName, previousLocation)
}
// Append the data to the store and rollback in case of an error.
location, err := ffdb.Write(storeName, data)
if err != nil {
rollbackErr := rollback()
if rollbackErr != nil {
return nil, errors.Wrapf(err, "error occurred during rollback: %s", rollbackErr)
}
return nil, err
}
// Get the new location. If this fails we won't be able to update
// the current store location, in which case we roll back.
currentLocation, err := ffdb.CurrentLocation(storeName)
if err != nil {
rollbackErr := rollback()
if rollbackErr != nil {
return nil, errors.Wrapf(err, "error occurred during rollback: %s", rollbackErr)
}
return nil, err
}
// Set the current store location and roll back in case of an error.
err = setCurrentStoreLocation(accessor, storeName, currentLocation)
if err != nil {
rollbackErr := rollback()
if rollbackErr != nil {
return nil, errors.Wrapf(err, "error occurred during rollback: %s", rollbackErr)
}
return nil, err
}
return location, err
}
func setCurrentStoreLocation(accessor database.DataAccessor, storeName string, location []byte) error {
locationKey := flatFilesBucket.Key([]byte(storeName))
return accessor.Put(locationKey, location)
}
// RetrieveFromStore retrieves data from the store defined by
// storeName using the given serialized location handle. It
// returns ErrNotFound if the location does not exist. See
// AppendToStore for further details.
// This method is part of the DataAccessor interface.
func (db *ffldb) RetrieveFromStore(storeName string, location []byte) ([]byte, error) {
return db.flatFileDB.Read(storeName, location)
}
// Cursor begins a new cursor over the given bucket.
// This method is part of the DataAccessor interface.
func (db *ffldb) Cursor(bucket *database.Bucket) (database.Cursor, error) {
ldbCursor := db.levelDB.Cursor(bucket)
return ldbCursor, nil
}
// Begin begins a new ffldb transaction.
// This method is part of the Database interface.
func (db *ffldb) Begin() (database.Transaction, error) {
ldbTx, err := db.levelDB.Begin()
if err != nil {
return nil, err
}
transaction := &transaction{
ldbTx: ldbTx,
ffdb: db.flatFileDB,
isClosed: false,
}
return transaction, nil
}
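// A minimal usage sketch of the driver above (hypothetical path, bucket,
// and store names, not part of the package): key-value access goes through
// LevelDB while raw data goes through the flat-file store.
func ffldbUsageSketch() error {
db, err := Open("/tmp/example-ffldb")
if err != nil {
return err
}
defer db.Close()
key := database.MakeBucket([]byte("bucket")).Key([]byte("key"))
if err := db.Put(key, []byte("value")); err != nil {
return err
}
location, err := db.AppendToStore("blocks", []byte("raw block bytes"))
if err != nil {
return err
}
_, err = db.RetrieveFromStore("blocks", location)
return err
}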

View File

@@ -1,131 +0,0 @@
package ffldb
import (
"github.com/kaspanet/kaspad/infrastructure/db/database"
"io/ioutil"
"reflect"
"testing"
)
func prepareDatabaseForTest(t *testing.T, testName string) (db database.Database, teardownFunc func()) {
// Create a temp db to run tests against
path, err := ioutil.TempDir("", testName)
if err != nil {
t.Fatalf("%s: TempDir unexpectedly "+
"failed: %s", testName, err)
}
db, err = Open(path)
if err != nil {
t.Fatalf("%s: Open unexpectedly "+
"failed: %s", testName, err)
}
teardownFunc = func() {
err = db.Close()
if err != nil {
t.Fatalf("%s: Close unexpectedly "+
"failed: %s", testName, err)
}
}
return db, teardownFunc
}
func TestRepairFlatFiles(t *testing.T) {
// Create a temp db to run tests against
path, err := ioutil.TempDir("", "TestRepairFlatFiles")
if err != nil {
t.Fatalf("TestRepairFlatFiles: TempDir unexpectedly "+
"failed: %s", err)
}
db, err := Open(path)
if err != nil {
t.Fatalf("TestRepairFlatFiles: Open unexpectedly "+
"failed: %s", err)
}
isOpen := true
defer func() {
if isOpen {
err := db.Close()
if err != nil {
t.Fatalf("TestRepairFlatFiles: Close unexpectedly "+
"failed: %s", err)
}
}
}()
// Cast to ffldb since we're going to be messing with its internals
ffldbInstance, ok := db.(*ffldb)
if !ok {
t.Fatalf("TestRepairFlatFiles: unexpectedly can't cast " +
"db to ffldb")
}
// Append data to the same store
storeName := "test"
_, err = ffldbInstance.AppendToStore(storeName, []byte("data1"))
if err != nil {
t.Fatalf("TestRepairFlatFiles: AppendToStore unexpectedly "+
"failed: %s", err)
}
// Grab the current location to test against later
oldCurrentLocation, err := ffldbInstance.flatFileDB.CurrentLocation(storeName)
if err != nil {
t.Fatalf("TestRepairFlatFiles: CurrentStoreLocation "+
"unexpectedly failed: %s", err)
}
// Append more data to the same store. We expect this to disappear later.
location2, err := ffldbInstance.AppendToStore(storeName, []byte("data2"))
if err != nil {
t.Fatalf("TestRepairFlatFiles: AppendToStore unexpectedly "+
"failed: %s", err)
}
// Manually update the current location to point to the first piece of data
err = setCurrentStoreLocation(ffldbInstance, storeName, oldCurrentLocation)
if err != nil {
t.Fatalf("TestRepairFlatFiles: setCurrentStoreLocation "+
"unexpectedly failed: %s", err)
}
// Reopen the database
err = ffldbInstance.Close()
if err != nil {
t.Fatalf("TestRepairFlatFiles: Close unexpectedly "+
"failed: %s", err)
}
isOpen = false
db, err = Open(path)
if err != nil {
t.Fatalf("TestRepairFlatFiles: Open unexpectedly "+
"failed: %s", err)
}
isOpen = true
ffldbInstance, ok = db.(*ffldb)
if !ok {
t.Fatalf("TestRepairFlatFiles: unexpectedly can't cast " +
"db to ffldb")
}
// Make sure that the current location rolled back as expected
currentLocation, err := ffldbInstance.flatFileDB.CurrentLocation(storeName)
if err != nil {
t.Fatalf("TestRepairFlatFiles: CurrentStoreLocation "+
"unexpectedly failed: %s", err)
}
if !reflect.DeepEqual(oldCurrentLocation, currentLocation) {
t.Fatalf("TestRepairFlatFiles: currentLocation did " +
"not roll back")
}
// Make sure that we can't get data that no longer exists
_, err = ffldbInstance.RetrieveFromStore(storeName, location2)
if err == nil {
t.Fatalf("TestRepairFlatFiles: RetrieveFromStore " +
"unexpectedly succeeded")
}
if !database.IsNotFoundError(err) {
t.Fatalf("TestRepairFlatFiles: RetrieveFromStore "+
"returned wrong error: %s", err)
}
}

View File

@@ -1,55 +0,0 @@
package ffldb
// initialize initializes the database. If this function fails then the
// database is irrecoverably corrupted.
func (db *ffldb) initialize() error {
flatFiles, err := db.flatFiles()
if err != nil {
return err
}
for storeName, currentLocation := range flatFiles {
err := db.tryRepair(storeName, currentLocation)
if err != nil {
return err
}
}
return nil
}
func (db *ffldb) flatFiles() (map[string][]byte, error) {
flatFilesCursor := db.levelDB.Cursor(flatFilesBucket)
defer func() {
err := flatFilesCursor.Close()
if err != nil {
log.Warnf("cursor failed to close")
}
}()
flatFiles := make(map[string][]byte)
for flatFilesCursor.Next() {
storeNameKey, err := flatFilesCursor.Key()
if err != nil {
return nil, err
}
storeName := string(storeNameKey.Suffix())
currentLocation, err := flatFilesCursor.Value()
if err != nil {
return nil, err
}
flatFiles[storeName] = currentLocation
}
return flatFiles, nil
}
// tryRepair attempts to sync the store with the current location value.
// Possible scenarios:
// a. currentLocation and the store are synced. Rollback does nothing.
// b. currentLocation is smaller than the store's location. Rollback truncates
// the store.
// c. currentLocation is greater than the store's location. Rollback returns an
// error. This indicates definite database corruption and is irrecoverable.
func (db *ffldb) tryRepair(storeName string, currentLocation []byte) error {
return db.flatFileDB.Rollback(storeName, currentLocation)
}

View File

@@ -1,5 +0,0 @@
package ffldb
import "github.com/kaspanet/kaspad/infrastructure/logger"
var log, _ = logger.Get(logger.SubsystemTags.KSDB)

View File

@@ -1,137 +0,0 @@
package ffldb
import (
"github.com/kaspanet/kaspad/infrastructure/db/database"
"github.com/kaspanet/kaspad/infrastructure/db/database/ffldb/ff"
"github.com/kaspanet/kaspad/infrastructure/db/database/ffldb/ldb"
"github.com/pkg/errors"
)
// transaction is an ffldb transaction.
//
// Note: Transactions provide data consistency over the state of
// the database as it was when the transaction started. There is
// NO guarantee that if one puts data into the transaction then
// it will be available to get within the same transaction.
type transaction struct {
ldbTx *ldb.LevelDBTransaction
ffdb *ff.FlatFileDB
isClosed bool
}
// Put sets the value for the given key. It overwrites
// any previous value for that key.
// This method is part of the DataAccessor interface.
func (tx *transaction) Put(key *database.Key, value []byte) error {
if tx.isClosed {
return errors.New("cannot put into a closed transaction")
}
return tx.ldbTx.Put(key, value)
}
// Get gets the value for the given key. It returns
// ErrNotFound if the given key does not exist.
// This method is part of the DataAccessor interface.
func (tx *transaction) Get(key *database.Key) ([]byte, error) {
if tx.isClosed {
return nil, errors.New("cannot get from a closed transaction")
}
return tx.ldbTx.Get(key)
}
// Has returns true if the database contains the
// given key.
// This method is part of the DataAccessor interface.
func (tx *transaction) Has(key *database.Key) (bool, error) {
if tx.isClosed {
return false, errors.New("cannot has from a closed transaction")
}
return tx.ldbTx.Has(key)
}
// Delete deletes the value for the given key. Will not
// return an error if the key doesn't exist.
// This method is part of the DataAccessor interface.
func (tx *transaction) Delete(key *database.Key) error {
if tx.isClosed {
return errors.New("cannot delete from a closed transaction")
}
return tx.ldbTx.Delete(key)
}
// AppendToStore appends the given data to the flat
// file store defined by storeName. This function
// returns a serialized location handle that's meant
// to be stored and later used when querying the data
// that has just now been inserted.
// This method is part of the DataAccessor interface.
func (tx *transaction) AppendToStore(storeName string, data []byte) ([]byte, error) {
if tx.isClosed {
return nil, errors.New("cannot append to store on a closed transaction")
}
return appendToStore(tx, tx.ffdb, storeName, data)
}
// RetrieveFromStore retrieves data from the store defined by
// storeName using the given serialized location handle. It
// returns ErrNotFound if the location does not exist. See
// AppendToStore for further details.
// This method is part of the DataAccessor interface.
func (tx *transaction) RetrieveFromStore(storeName string, location []byte) ([]byte, error) {
if tx.isClosed {
return nil, errors.New("cannot retrieve from store on a closed transaction")
}
return tx.ffdb.Read(storeName, location)
}
// Cursor begins a new cursor over the given bucket.
// This method is part of the DataAccessor interface.
func (tx *transaction) Cursor(bucket *database.Bucket) (database.Cursor, error) {
if tx.isClosed {
return nil, errors.New("cannot open a cursor from a closed transaction")
}
return tx.ldbTx.Cursor(bucket)
}
// Rollback rolls back whatever changes were made to the
// database within this transaction.
// This method is part of the Transaction interface.
func (tx *transaction) Rollback() error {
if tx.isClosed {
return errors.New("cannot rollback a closed transaction")
}
tx.isClosed = true
return tx.ldbTx.Rollback()
}
// Commit commits whatever changes were made to the database
// within this transaction.
// This method is part of the Transaction interface.
func (tx *transaction) Commit() error {
if tx.isClosed {
return errors.New("cannot commit a closed transaction")
}
tx.isClosed = true
return tx.ldbTx.Commit()
}
// RollbackUnlessClosed rolls back changes that were made to
// the database within the transaction, unless the transaction
// had already been closed using either Rollback or Commit.
func (tx *transaction) RollbackUnlessClosed() error {
if tx.isClosed {
return nil
}
tx.isClosed = true
return tx.ldbTx.RollbackUnlessClosed()
}
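// A minimal sketch (not part of the package) of the consistency caveat
// noted above: a Get inside the same transaction may not observe an earlier
// Put; the write only becomes reliably visible once Commit succeeds.
func consistencyCaveatSketch(db database.Database) error {
dbTx, err := db.Begin()
if err != nil {
return err
}
defer func() {
_ = dbTx.RollbackUnlessClosed()
}()
key := database.MakeBucket().Key([]byte("key"))
if err := dbTx.Put(key, []byte("value")); err != nil {
return err
}
// A dbTx.Get(key) here may still return ErrNotFound.
return dbTx.Commit()
}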

View File

@@ -1,500 +0,0 @@
package ffldb
import (
"bytes"
"github.com/kaspanet/kaspad/infrastructure/db/database"
"strings"
"testing"
)
func TestTransactionCommitForLevelDBMethods(t *testing.T) {
db, teardownFunc := prepareDatabaseForTest(t, "TestTransactionCommitForLevelDBMethods")
defer teardownFunc()
// Put a value into the database
key1 := database.MakeBucket().Key([]byte("key1"))
value1 := []byte("value1")
err := db.Put(key1, value1)
if err != nil {
t.Fatalf("TestTransactionCommitForLevelDBMethods: Put "+
"unexpectedly failed: %s", err)
}
// Begin a new transaction
dbTx, err := db.Begin()
if err != nil {
t.Fatalf("TestTransactionCommitForLevelDBMethods: Begin "+
"unexpectedly failed: %s", err)
}
defer func() {
err := dbTx.RollbackUnlessClosed()
if err != nil {
t.Fatalf("TestTransactionCommitForLevelDBMethods: RollbackUnlessClosed "+
"unexpectedly failed: %s", err)
}
}()
// Make sure that Has returns that the original value exists
exists, err := dbTx.Has(key1)
if err != nil {
t.Fatalf("TestTransactionCommitForLevelDBMethods: Has "+
"unexpectedly failed: %s", err)
}
if !exists {
t.Fatalf("TestTransactionCommitForLevelDBMethods: Has " +
"unexpectedly returned that the value does not exist")
}
// Get the existing value and make sure it's equal to the original
existingValue, err := dbTx.Get(key1)
if err != nil {
t.Fatalf("TestTransactionCommitForLevelDBMethods: Get "+
"unexpectedly failed: %s", err)
}
if !bytes.Equal(existingValue, value1) {
t.Fatalf("TestTransactionCommitForLevelDBMethods: Get "+
"returned unexpected value. Want: %s, got: %s",
string(value1), string(existingValue))
}
// Delete the existing value
err = dbTx.Delete(key1)
if err != nil {
t.Fatalf("TestTransactionCommitForLevelDBMethods: Delete "+
"unexpectedly failed: %s", err)
}
// Try to get a value that does not exist and make sure it returns ErrNotFound
_, err = dbTx.Get(database.MakeBucket().Key([]byte("doesn't exist")))
if err == nil {
t.Fatalf("TestTransactionCommitForLevelDBMethods: Get " +
"unexpectedly succeeded")
}
if !database.IsNotFoundError(err) {
t.Fatalf("TestTransactionCommitForLevelDBMethods: Get "+
"returned unexpected error: %s", err)
}
// Put a new value
key2 := database.MakeBucket().Key([]byte("key2"))
value2 := []byte("value2")
err = dbTx.Put(key2, value2)
if err != nil {
t.Fatalf("TestTransactionCommitForLevelDBMethods: Put "+
"unexpectedly failed: %s", err)
}
// Commit the transaction
err = dbTx.Commit()
if err != nil {
t.Fatalf("TestTransactionCommitForLevelDBMethods: Commit "+
"unexpectedly failed: %s", err)
}
// Make sure that Has returns that the original value does NOT exist
exists, err = db.Has(key1)
if err != nil {
t.Fatalf("TestTransactionCommitForLevelDBMethods: Has "+
"unexpectedly failed: %s", err)
}
if exists {
t.Fatalf("TestTransactionCommitForLevelDBMethods: Has " +
"unexpectedly returned that the value exists")
}
// Try to Get the existing value and make sure an ErrNotFound is returned
_, err = db.Get(key1)
if err == nil {
t.Fatalf("TestTransactionCommitForLevelDBMethods: Get " +
"unexpectedly succeeded")
}
if !database.IsNotFoundError(err) {
t.Fatalf("TestTransactionCommitForLevelDBMethods: Get "+
"returned unexpected err: %s", err)
}
// Make sure that Has returns that the new value exists
exists, err = db.Has(key2)
if err != nil {
t.Fatalf("TestTransactionCommitForLevelDBMethods: Has "+
"unexpectedly failed: %s", err)
}
if !exists {
t.Fatalf("TestTransactionCommitForLevelDBMethods: Has " +
"unexpectedly returned that the value does not exist")
}
// Get the new value and make sure it's equal to the original
existingValue, err = db.Get(key2)
if err != nil {
t.Fatalf("TestTransactionCommitForLevelDBMethods: Get "+
"unexpectedly failed: %s", err)
}
if !bytes.Equal(existingValue, value2) {
t.Fatalf("TestTransactionCommitForLevelDBMethods: Get "+
"returned unexpected value. Want: %s, got: %s",
string(value2), string(existingValue))
}
}
func TestTransactionRollbackForLevelDBMethods(t *testing.T) {
db, teardownFunc := prepareDatabaseForTest(t, "TestTransactionRollbackForLevelDBMethods")
defer teardownFunc()
// Put a value into the database
key1 := database.MakeBucket().Key([]byte("key1"))
value1 := []byte("value1")
err := db.Put(key1, value1)
if err != nil {
t.Fatalf("TestTransactionRollbackForLevelDBMethods: Put "+
"unexpectedly failed: %s", err)
}
// Begin a new transaction
dbTx, err := db.Begin()
if err != nil {
t.Fatalf("TestTransactionRollbackForLevelDBMethods: Begin "+
"unexpectedly failed: %s", err)
}
defer func() {
err := dbTx.RollbackUnlessClosed()
if err != nil {
t.Fatalf("TestTransactionRollbackForLevelDBMethods: RollbackUnlessClosed "+
"unexpectedly failed: %s", err)
}
}()
// Make sure that Has returns that the original value exists
exists, err := dbTx.Has(key1)
if err != nil {
t.Fatalf("TestTransactionRollbackForLevelDBMethods: Has "+
"unexpectedly failed: %s", err)
}
if !exists {
t.Fatalf("TestTransactionRollbackForLevelDBMethods: Has " +
"unexpectedly returned that the value does not exist")
}
// Get the existing value and make sure it's equal to the original
existingValue, err := dbTx.Get(key1)
if err != nil {
t.Fatalf("TestTransactionRollbackForLevelDBMethods: Get "+
"unexpectedly failed: %s", err)
}
if !bytes.Equal(existingValue, value1) {
t.Fatalf("TestTransactionRollbackForLevelDBMethods: Get "+
"returned unexpected value. Want: %s, got: %s",
string(value1), string(existingValue))
}
// Delete the existing value
err = dbTx.Delete(key1)
if err != nil {
t.Fatalf("TestTransactionRollbackForLevelDBMethods: Delete "+
"unexpectedly failed: %s", err)
}
// Put a new value
key2 := database.MakeBucket().Key([]byte("key2"))
value2 := []byte("value2")
err = dbTx.Put(key2, value2)
if err != nil {
t.Fatalf("TestTransactionRollbackForLevelDBMethods: Put "+
"unexpectedly failed: %s", err)
}
// Rollback the transaction
err = dbTx.Rollback()
if err != nil {
t.Fatalf("TestTransactionRollbackForLevelDBMethods: Rollback "+
"unexpectedly failed: %s", err)
}
// Make sure that Has returns that the original value still exists
exists, err = db.Has(key1)
if err != nil {
t.Fatalf("TestTransactionRollbackForLevelDBMethods: Has "+
"unexpectedly failed: %s", err)
}
if !exists {
t.Fatalf("TestTransactionRollbackForLevelDBMethods: Has " +
"unexpectedly returned that the value does not exist")
}
// Get the existing value and make sure it is still returned
existingValue, err = db.Get(key1)
if err != nil {
t.Fatalf("TestTransactionRollbackForLevelDBMethods: Get "+
"unexpectedly failed: %s", err)
}
if !bytes.Equal(existingValue, value1) {
t.Fatalf("TestTransactionRollbackForLevelDBMethods: Get "+
"returned unexpected value. Want: %s, got: %s",
string(value1), string(existingValue))
}
// Make sure that Has returns that the new value does NOT exist
exists, err = db.Has(key2)
if err != nil {
t.Fatalf("TestTransactionRollbackForLevelDBMethods: Has "+
"unexpectedly failed: %s", err)
}
if exists {
t.Fatalf("TestTransactionRollbackForLevelDBMethods: Has " +
"unexpectedly returned that the value exists")
}
// Try to Get the new value and make sure it returns an ErrNotFound
_, err = db.Get(key2)
if err == nil {
t.Fatalf("TestTransactionRollbackForLevelDBMethods: Get " +
"unexpectedly succeeded")
}
if !database.IsNotFoundError(err) {
t.Fatalf("TestTransactionRollbackForLevelDBMethods: Get "+
"returned unexpected error: %s", err)
}
}
func TestTransactionCloseErrors(t *testing.T) {
tests := []struct {
name string
function func(dbTx database.Transaction) error
shouldReturnError bool
}{
{
name: "Put",
function: func(dbTx database.Transaction) error {
return dbTx.Put(database.MakeBucket().Key([]byte("key")), []byte("value"))
},
shouldReturnError: true,
},
{
name: "Get",
function: func(dbTx database.Transaction) error {
_, err := dbTx.Get(database.MakeBucket().Key([]byte("key")))
return err
},
shouldReturnError: true,
},
{
name: "Has",
function: func(dbTx database.Transaction) error {
_, err := dbTx.Has(database.MakeBucket().Key([]byte("key")))
return err
},
shouldReturnError: true,
},
{
name: "Delete",
function: func(dbTx database.Transaction) error {
return dbTx.Delete(database.MakeBucket().Key([]byte("key")))
},
shouldReturnError: true,
},
{
name: "Cursor",
function: func(dbTx database.Transaction) error {
_, err := dbTx.Cursor(database.MakeBucket([]byte("bucket")))
return err
},
shouldReturnError: true,
},
{
name: "AppendToStore",
function: func(dbTx database.Transaction) error {
_, err := dbTx.AppendToStore("store", []byte("data"))
return err
},
shouldReturnError: true,
},
{
name: "RetrieveFromStore",
function: func(dbTx database.Transaction) error {
_, err := dbTx.RetrieveFromStore("store", []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})
return err
},
shouldReturnError: true,
},
{
name: "Rollback",
function: func(dbTx database.Transaction) error {
return dbTx.Rollback()
},
shouldReturnError: true,
},
{
name: "Commit",
function: func(dbTx database.Transaction) error {
return dbTx.Commit()
},
shouldReturnError: true,
},
{
name: "RollbackUnlessClosed",
function: func(dbTx database.Transaction) error {
return dbTx.RollbackUnlessClosed()
},
shouldReturnError: false,
},
}
for _, test := range tests {
func() {
db, teardownFunc := prepareDatabaseForTest(t, "TestTransactionCloseErrors")
defer teardownFunc()
// Begin a new transaction to test Commit
commitTx, err := db.Begin()
if err != nil {
t.Fatalf("TestTransactionCloseErrors: Begin "+
"unexpectedly failed: %s", err)
}
defer func() {
err := commitTx.RollbackUnlessClosed()
if err != nil {
t.Fatalf("TestTransactionCloseErrors: RollbackUnlessClosed "+
"unexpectedly failed: %s", err)
}
}()
// Commit the Commit test transaction
err = commitTx.Commit()
if err != nil {
t.Fatalf("TestTransactionCloseErrors: Commit "+
"unexpectedly failed: %s", err)
}
// Begin a new transaction to test Rollback
rollbackTx, err := db.Begin()
if err != nil {
t.Fatalf("TestTransactionCloseErrors: Begin "+
"unexpectedly failed: %s", err)
}
defer func() {
err := rollbackTx.RollbackUnlessClosed()
if err != nil {
t.Fatalf("TestTransactionCloseErrors: RollbackUnlessClosed "+
"unexpectedly failed: %s", err)
}
}()
// Rollback the Rollback test transaction
err = rollbackTx.Rollback()
if err != nil {
t.Fatalf("TestTransactionCloseErrors: Rollback "+
"unexpectedly failed: %s", err)
}
expectedErrContainsString := "closed transaction"
// Make sure that the test function returns a "closed transaction" error
// for both the commitTx and the rollbackTx
for _, closedTx := range []database.Transaction{commitTx, rollbackTx} {
err = test.function(closedTx)
if test.shouldReturnError {
if err == nil {
t.Fatalf("TestTransactionCloseErrors: %s "+
"unexpectedly succeeded", test.name)
}
if !strings.Contains(err.Error(), expectedErrContainsString) {
t.Fatalf("TestTransactionCloseErrors: %s "+
"returned wrong error. Want: %s, got: %s",
test.name, expectedErrContainsString, err)
}
} else {
if err != nil {
t.Fatalf("TestTransactionCloseErrors: %s "+
"unexpectedly failed: %s", test.name, err)
}
}
}
}()
}
}
func TestTransactionRollbackUnlessClosed(t *testing.T) {
db, teardownFunc := prepareDatabaseForTest(t, "TestTransactionRollbackUnlessClosed")
defer teardownFunc()
// Begin a new transaction
dbTx, err := db.Begin()
if err != nil {
t.Fatalf("TestTransactionRollbackUnlessClosed: Begin "+
"unexpectedly failed: %s", err)
}
// Roll it back
err = dbTx.RollbackUnlessClosed()
if err != nil {
t.Fatalf("TestTransactionRollbackUnlessClosed: RollbackUnlessClosed "+
"unexpectedly failed: %s", err)
}
}
func TestTransactionCommitForFlatFileMethods(t *testing.T) {
db, teardownFunc := prepareDatabaseForTest(t, "TestTransactionCommitForFlatFileMethods")
defer teardownFunc()
// Put a value into the database
store := "store"
value1 := []byte("value1")
location1, err := db.AppendToStore(store, value1)
if err != nil {
t.Fatalf("TestTransactionCommitForFlatFileMethods: AppendToStore "+
"unexpectedly failed: %s", err)
}
// Begin a new transaction
dbTx, err := db.Begin()
if err != nil {
t.Fatalf("TestTransactionCommitForFlatFileMethods: Begin "+
"unexpectedly failed: %s", err)
}
defer func() {
err := dbTx.RollbackUnlessClosed()
if err != nil {
t.Fatalf("TestTransactionCommitForFlatFileMethods: RollbackUnlessClosed "+
"unexpectedly failed: %s", err)
}
}()
// Retrieve the existing value and make sure it's equal to the original
existingValue, err := dbTx.RetrieveFromStore(store, location1)
if err != nil {
t.Fatalf("TestTransactionCommitForFlatFileMethods: RetrieveFromStore "+
"unexpectedly failed: %s", err)
}
if !bytes.Equal(existingValue, value1) {
t.Fatalf("TestTransactionCommitForFlatFileMethods: RetrieveFromStore "+
"returned unexpected value. Want: %s, got: %s",
string(value1), string(existingValue))
}
// Put a new value
value2 := []byte("value2")
location2, err := dbTx.AppendToStore(store, value2)
if err != nil {
t.Fatalf("TestTransactionCommitForFlatFileMethods: AppendToStore "+
"unexpectedly failed: %s", err)
}
// Commit the transaction
err = dbTx.Commit()
if err != nil {
t.Fatalf("TestTransactionCommitForFlatFileMethods: Commit "+
"unexpectedly failed: %s", err)
}
// Retrieve the new value and make sure it's equal to the original
newValue, err := db.RetrieveFromStore(store, location2)
if err != nil {
t.Fatalf("TestTransactionCommitForFlatFileMethods: RetrieveFromStore "+
"unexpectedly failed: %s", err)
}
if !bytes.Equal(newValue, value2) {
t.Fatalf("TestTransactionCommitForFlatFileMethods: RetrieveFromStore "+
"returned unexpected value. Want: %s, got: %s",
string(value2), string(newValue))
}
}

View File

@@ -18,13 +18,14 @@ type LevelDBCursor struct {
}
// Cursor begins a new cursor over the given prefix.
func (db *LevelDB) Cursor(bucket *database.Bucket) *LevelDBCursor {
func (db *LevelDB) Cursor(bucket *database.Bucket) (database.Cursor, error) {
ldbIterator := db.ldb.NewIterator(util.BytesPrefix(bucket.Path()), nil)
return &LevelDBCursor{
ldbIterator: ldbIterator,
bucket: bucket,
isClosed: false,
}
}, nil
}
// Next moves the iterator to the next key/value pair. It returns whether the

View File

@@ -3,13 +3,14 @@ package ldb
import (
"bytes"
"fmt"
"github.com/kaspanet/kaspad/infrastructure/db/database"
"reflect"
"strings"
"testing"
"github.com/kaspanet/kaspad/infrastructure/db/database"
)
func validateCurrentCursorKeyAndValue(t *testing.T, testName string, cursor *LevelDBCursor,
func validateCurrentCursorKeyAndValue(t *testing.T, testName string, cursor database.Cursor,
expectedKey *database.Key, expectedValue []byte) {
cursorKey, err := cursor.Key()
@@ -70,7 +71,11 @@ func TestCursorSanity(t *testing.T) {
}
// Open a new cursor
cursor := ldb.Cursor(bucket)
cursor, err := ldb.Cursor(bucket)
if err != nil {
t.Fatalf("TestCursorSanity: ldb.Cursor "+
"unexpectedly failed: %s", err)
}
defer func() {
err := cursor.Close()
if err != nil {
@@ -90,7 +95,7 @@ func TestCursorSanity(t *testing.T) {
validateCurrentCursorKeyAndValue(t, "TestCursorSanity", cursor, expectedKey, expectedValue)
// Seek to a non-existent key
err := cursor.Seek(database.MakeBucket().Key([]byte("doesn't exist")))
err = cursor.Seek(database.MakeBucket().Key([]byte("doesn't exist")))
if err == nil {
t.Fatalf("TestCursorSanity: Seek " +
"unexpectedly succeeded")
@@ -145,31 +150,31 @@ func TestCursorCloseErrors(t *testing.T) {
// function is the LevelDBCursor function that we're
// verifying returns an error after the cursor had
// been closed.
function func(dbTx *LevelDBCursor) error
function func(dbTx database.Cursor) error
}{
{
name: "Seek",
function: func(cursor *LevelDBCursor) error {
function: func(cursor database.Cursor) error {
return cursor.Seek(database.MakeBucket().Key([]byte{}))
},
},
{
name: "Key",
function: func(cursor *LevelDBCursor) error {
function: func(cursor database.Cursor) error {
_, err := cursor.Key()
return err
},
},
{
name: "Value",
function: func(cursor *LevelDBCursor) error {
function: func(cursor database.Cursor) error {
_, err := cursor.Value()
return err
},
},
{
name: "Close",
function: func(cursor *LevelDBCursor) error {
function: func(cursor database.Cursor) error {
return cursor.Close()
},
},
@@ -181,10 +186,14 @@ func TestCursorCloseErrors(t *testing.T) {
defer teardownFunc()
// Open a new cursor
cursor := ldb.Cursor(database.MakeBucket())
cursor, err := ldb.Cursor(database.MakeBucket())
if err != nil {
t.Fatalf("TestCursorCloseErrors: ldb.Cursor "+
"unexpectedly failed: %s", err)
}
// Close the cursor
err := cursor.Close()
err = cursor.Close()
if err != nil {
t.Fatalf("TestCursorCloseErrors: Close "+
"unexpectedly failed: %s", err)
@@ -223,10 +232,14 @@ func TestCursorCloseFirstAndNext(t *testing.T) {
}
// Open a new cursor
cursor := ldb.Cursor(database.MakeBucket([]byte("bucket")))
cursor, err := ldb.Cursor(database.MakeBucket([]byte("bucket")))
if err != nil {
t.Fatalf("TestCursorCloseFirstAndNext: ldb.Cursor "+
"unexpectedly failed: %s", err)
}
// Close the cursor
err := cursor.Close()
err = cursor.Close()
if err != nil {
t.Fatalf("TestCursorCloseFirstAndNext: Close "+
"unexpectedly failed: %s", err)

View File

@@ -1,10 +1,11 @@
package ldb
import (
"github.com/kaspanet/kaspad/infrastructure/db/database"
"io/ioutil"
"reflect"
"testing"
"github.com/kaspanet/kaspad/infrastructure/db/database"
)
func prepareDatabaseForTest(t *testing.T, testName string) (ldb *LevelDB, teardownFunc func()) {

View File

@@ -28,7 +28,7 @@ type LevelDBTransaction struct {
}
// Begin begins a new transaction.
func (db *LevelDB) Begin() (*LevelDBTransaction, error) {
func (db *LevelDB) Begin() (database.Transaction, error) {
snapshot, err := db.ldb.GetSnapshot()
if err != nil {
return nil, errors.WithStack(err)
@@ -131,10 +131,10 @@ func (tx *LevelDBTransaction) Delete(key *database.Key) error {
}
// Cursor begins a new cursor over the given bucket.
func (tx *LevelDBTransaction) Cursor(bucket *database.Bucket) (*LevelDBCursor, error) {
func (tx *LevelDBTransaction) Cursor(bucket *database.Bucket) (database.Cursor, error) {
if tx.isClosed {
return nil, errors.New("cannot open a cursor from a closed transaction")
}
return tx.db.Cursor(bucket), nil
return tx.db.Cursor(bucket)
}

View File

@@ -1,9 +1,10 @@
package ldb
import (
"github.com/kaspanet/kaspad/infrastructure/db/database"
"strings"
"testing"
"github.com/kaspanet/kaspad/infrastructure/db/database"
)
func TestTransactionCloseErrors(t *testing.T) {
@@ -122,8 +123,8 @@ func TestTransactionCloseErrors(t *testing.T) {
// Make sure that the test function returns a "closed transaction" error
// for both the commitTx and the rollbackTx
for _, closedTx := range []*LevelDBTransaction{commitTx, rollbackTx} {
err = test.function(closedTx)
for _, closedTx := range []database.Transaction{commitTx, rollbackTx} {
err = test.function(closedTx.(*LevelDBTransaction))
if test.shouldReturnError {
if err == nil {
t.Fatalf("TestTransactionCloseErrors: %s "+


@@ -7,9 +7,10 @@ package database_test
import (
"bytes"
"github.com/kaspanet/kaspad/infrastructure/db/database"
"strings"
"testing"
"github.com/kaspanet/kaspad/infrastructure/db/database"
)
func TestTransactionPut(t *testing.T) {
@@ -307,59 +308,6 @@ func testTransactionDelete(t *testing.T, db database.Database, testName string)
}
}
func TestTransactionAppendToStoreAndRetrieveFromStore(t *testing.T) {
testForAllDatabaseTypes(t, "TestTransactionAppendToStoreAndRetrieveFromStore", testTransactionAppendToStoreAndRetrieveFromStore)
}
func testTransactionAppendToStoreAndRetrieveFromStore(t *testing.T, db database.Database, testName string) {
// Begin a new transaction
dbTx, err := db.Begin()
if err != nil {
t.Fatalf("%s: Begin "+
"unexpectedly failed: %s", testName, err)
}
defer func() {
err := dbTx.RollbackUnlessClosed()
if err != nil {
t.Fatalf("%s: RollbackUnlessClosed "+
"unexpectedly failed: %s", testName, err)
}
}()
// Append some data into the store
storeName := "store"
data := []byte("data")
location, err := dbTx.AppendToStore(storeName, data)
if err != nil {
t.Fatalf("%s: AppendToStore "+
"unexpectedly failed: %s", testName, err)
}
// Retrieve the data and make sure it's equal to what was appended
retrievedData, err := dbTx.RetrieveFromStore(storeName, location)
if err != nil {
t.Fatalf("%s: RetrieveFromStore "+
"unexpectedly failed: %s", testName, err)
}
if !bytes.Equal(retrievedData, data) {
t.Fatalf("%s: RetrieveFromStore "+
"returned unexpected data. Want: %s, got: %s",
testName, string(data), string(retrievedData))
}
// Make sure that an invalid location returns ErrNotFound
fakeLocation := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}
_, err = dbTx.RetrieveFromStore(storeName, fakeLocation)
if err == nil {
t.Fatalf("%s: RetrieveFromStore "+
"unexpectedly succeeded", testName)
}
if !database.IsNotFoundError(err) {
t.Fatalf("%s: RetrieveFromStore "+
"returned wrong error: %s", testName, err)
}
}
func TestTransactionCommit(t *testing.T) {
testForAllDatabaseTypes(t, "TestTransactionCommit", testTransactionCommit)
}


@@ -6,16 +6,12 @@ import (
"github.com/pkg/errors"
)
const (
blockStoreName = "blocks"
)
var (
blockLocationsBucket = database.MakeBucket([]byte("block-locations"))
blocksBucket = database.MakeBucket([]byte("blocks"))
)
func blockLocationKey(hash *daghash.Hash) *database.Key {
return blockLocationsBucket.Key(hash[:])
func blockKey(hash *daghash.Hash) *database.Key {
return blocksBucket.Key(hash[:])
}
// StoreBlock stores the given block in the database.
@@ -35,14 +31,7 @@ func StoreBlock(context *TxContext, hash *daghash.Hash, blockBytes []byte) error
}
// Write the block's bytes to the block store
blockLocation, err := accessor.AppendToStore(blockStoreName, blockBytes)
if err != nil {
return err
}
// Write the block's hash to the blockLocations bucket
blockLocationsKey := blockLocationKey(hash)
err = accessor.Put(blockLocationsKey, blockLocation)
err = accessor.Put(blockKey(hash), blockBytes)
if err != nil {
return err
}
@@ -58,9 +47,7 @@ func HasBlock(context Context, hash *daghash.Hash) (bool, error) {
return false, err
}
blockLocationsKey := blockLocationKey(hash)
return accessor.Has(blockLocationsKey)
return accessor.Has(blockKey(hash))
}
// FetchBlock returns the block of the given hash. Returns
@@ -72,19 +59,5 @@ func FetchBlock(context Context, hash *daghash.Hash) ([]byte, error) {
return nil, err
}
blockLocationsKey := blockLocationKey(hash)
blockLocation, err := accessor.Get(blockLocationsKey)
if err != nil {
if database.IsNotFoundError(err) {
return nil, errors.Wrapf(err,
"block %s not found", hash)
}
return nil, err
}
bytes, err := accessor.RetrieveFromStore(blockStoreName, blockLocation)
if err != nil {
return nil, err
}
return bytes, nil
return accessor.Get(blockKey(hash))
}
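The rewrite above collapses the old two-step scheme (AppendToStore into a flat-file store, plus a block-locations index bucket) into a single Put/Get keyed by blockKey(hash). A hedged round-trip sketch using only the signatures visible in this hunk; the daghash import path and the assumption that *TxContext satisfies the Context interface are inferred, not shown here:

	package example

	import (
		"bytes"

		"github.com/kaspanet/kaspad/infrastructure/db/dbaccess"
		"github.com/kaspanet/kaspad/util/daghash" // assumed path
		"github.com/pkg/errors"
	)

	// storeAndFetchBlock round-trips a block through the simplified store.
	func storeAndFetchBlock(dbTx *dbaccess.TxContext, hash *daghash.Hash, blockBytes []byte) error {
		// A single Put keyed by blockKey(hash); no flat-file location index.
		err := dbaccess.StoreBlock(dbTx, hash, blockBytes)
		if err != nil {
			return err
		}
		// Assumption: *TxContext implements Context, so the uncommitted write
		// is visible to FetchBlock within the same transaction.
		fetched, err := dbaccess.FetchBlock(dbTx, hash)
		if err != nil {
			return err
		}
		if !bytes.Equal(fetched, blockBytes) {
			return errors.New("block round-trip mismatch")
		}
		return nil
	}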


@@ -2,7 +2,7 @@ package dbaccess
import (
"github.com/kaspanet/kaspad/infrastructure/db/database"
"github.com/kaspanet/kaspad/infrastructure/db/database/ffldb"
"github.com/kaspanet/kaspad/infrastructure/db/database/ldb"
)
// DatabaseContext represents a context in which all database queries run
@@ -13,7 +13,7 @@ type DatabaseContext struct {
// New creates a new DatabaseContext whose database is in the specified `path`
func New(path string) (*DatabaseContext, error) {
db, err := ffldb.Open(path)
db, err := ldb.NewLevelDB(path)
if err != nil {
return nil, err
}


@@ -36,6 +36,18 @@ func RemoveFromUTXOSet(context Context, outpointKey []byte) error {
return accessor.Delete(key)
}
// GetFromUTXOSet returns the entry for the given outpoint key
// from the database's UTXO set.
func GetFromUTXOSet(context Context, outpointKey []byte) ([]byte, error) {
accessor, err := context.accessor()
if err != nil {
return nil, err
}
key := utxoKey(outpointKey)
return accessor.Get(key)
}
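As with the new FetchBlock, a missing entry here surfaces as the accessor's ErrNotFound rather than a nil result. A small sketch of the lookup pattern the tests elsewhere in this diff rely on (the helper name is hypothetical):

	package example

	import (
		"github.com/kaspanet/kaspad/infrastructure/db/database"
		"github.com/kaspanet/kaspad/infrastructure/db/dbaccess"
	)

	// lookupUTXO is a hypothetical helper: a missing outpoint is reported via
	// database.IsNotFoundError rather than treated as a hard failure.
	func lookupUTXO(context dbaccess.Context, outpointKey []byte) (entry []byte, found bool, err error) {
		entry, err = dbaccess.GetFromUTXOSet(context, outpointKey)
		if err != nil {
			if database.IsNotFoundError(err) {
				return nil, false, nil
			}
			return nil, false, err
		}
		return entry, true, nil
	}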
// UTXOSetCursor opens a cursor over all the UTXO entries
// that have been previously added to the database.
func UTXOSetCursor(context Context) (database.Cursor, error) {


@@ -53,6 +53,7 @@ var (
dnssLog = BackendLog.Logger("DNSS")
snvrLog = BackendLog.Logger("SNVR")
ibdsLog = BackendLog.Logger("IBDS")
wsvcLog = BackendLog.Logger("WSVC")
)
// SubsystemTags is an enum of all sub system tags
@@ -83,7 +84,8 @@ var SubsystemTags = struct {
NTAR,
DNSS,
SNVR,
IBDS string
IBDS,
WSVC string
}{
ADXR: "ADXR",
AMGR: "AMGR",
@@ -112,6 +114,7 @@ var SubsystemTags = struct {
DNSS: "DNSS",
SNVR: "SNVR",
IBDS: "IBDS",
WSVC: "WSVC",
}
// subsystemLoggers maps each subsystem identifier to its associated logger.
@@ -143,6 +146,7 @@ var subsystemLoggers = map[string]*Logger{
SubsystemTags.DNSS: dnssLog,
SubsystemTags.SNVR: snvrLog,
SubsystemTags.IBDS: ibdsLog,
SubsystemTags.WSVC: wsvcLog,
}
// InitLog attaches log file and error log file to the backend log.

File diff suppressed because it is too large


@@ -5,518 +5,27 @@
package addressmanager
import (
"fmt"
"github.com/kaspanet/kaspad/app/appmessage"
"io/ioutil"
"net"
"reflect"
"testing"
"time"
"github.com/kaspanet/kaspad/app/appmessage"
"github.com/kaspanet/kaspad/infrastructure/config"
"github.com/kaspanet/kaspad/infrastructure/db/dbaccess"
"github.com/kaspanet/kaspad/util/mstime"
"github.com/kaspanet/kaspad/util/subnetworkid"
"github.com/pkg/errors"
)
// naTest is used to describe a test to be performed against the NetAddressKey
// method.
type naTest struct {
in appmessage.NetAddress
want AddressKey
}
// naTests houses all of the tests to be performed against the NetAddressKey
// method.
var naTests = make([]naTest, 0)
// Put some IP in here for convenience. Points to google.
var someIP = "173.194.115.66"
// addNaTests
func addNaTests() {
// IPv4
// Localhost
addNaTest("127.0.0.1", 16111, "127.0.0.1:16111")
addNaTest("127.0.0.1", 16110, "127.0.0.1:16110")
// Class A
addNaTest("1.0.0.1", 16111, "1.0.0.1:16111")
addNaTest("2.2.2.2", 16110, "2.2.2.2:16110")
addNaTest("27.253.252.251", 8335, "27.253.252.251:8335")
addNaTest("123.3.2.1", 8336, "123.3.2.1:8336")
// Private Class A
addNaTest("10.0.0.1", 16111, "10.0.0.1:16111")
addNaTest("10.1.1.1", 16110, "10.1.1.1:16110")
addNaTest("10.2.2.2", 8335, "10.2.2.2:8335")
addNaTest("10.10.10.10", 8336, "10.10.10.10:8336")
// Class B
addNaTest("128.0.0.1", 16111, "128.0.0.1:16111")
addNaTest("129.1.1.1", 16110, "129.1.1.1:16110")
addNaTest("180.2.2.2", 8335, "180.2.2.2:8335")
addNaTest("191.10.10.10", 8336, "191.10.10.10:8336")
// Private Class B
addNaTest("172.16.0.1", 16111, "172.16.0.1:16111")
addNaTest("172.16.1.1", 16110, "172.16.1.1:16110")
addNaTest("172.16.2.2", 8335, "172.16.2.2:8335")
addNaTest("172.16.172.172", 8336, "172.16.172.172:8336")
// Class C
addNaTest("193.0.0.1", 16111, "193.0.0.1:16111")
addNaTest("200.1.1.1", 16110, "200.1.1.1:16110")
addNaTest("205.2.2.2", 8335, "205.2.2.2:8335")
addNaTest("223.10.10.10", 8336, "223.10.10.10:8336")
// Private Class C
addNaTest("192.168.0.1", 16111, "192.168.0.1:16111")
addNaTest("192.168.1.1", 16110, "192.168.1.1:16110")
addNaTest("192.168.2.2", 8335, "192.168.2.2:8335")
addNaTest("192.168.192.192", 8336, "192.168.192.192:8336")
// IPv6
// Localhost
addNaTest("::1", 16111, "[::1]:16111")
addNaTest("fe80::1", 16110, "[fe80::1]:16110")
// Link-local
addNaTest("fe80::1:1", 16111, "[fe80::1:1]:16111")
addNaTest("fe91::2:2", 16110, "[fe91::2:2]:16110")
addNaTest("fea2::3:3", 8335, "[fea2::3:3]:8335")
addNaTest("feb3::4:4", 8336, "[feb3::4:4]:8336")
// Site-local
addNaTest("fec0::1:1", 16111, "[fec0::1:1]:16111")
addNaTest("fed1::2:2", 16110, "[fed1::2:2]:16110")
addNaTest("fee2::3:3", 8335, "[fee2::3:3]:8335")
addNaTest("fef3::4:4", 8336, "[fef3::4:4]:8336")
}
func addNaTest(ip string, port uint16, want AddressKey) {
nip := net.ParseIP(ip)
na := *appmessage.NewNetAddressIPPort(nip, port, appmessage.SFNodeNetwork)
test := naTest{na, want}
naTests = append(naTests, test)
}
func lookupFuncForTest(host string) ([]net.IP, error) {
return nil, errors.New("not implemented")
}
func newAddrManagerForTest(t *testing.T, testName string,
localSubnetworkID *subnetworkid.SubnetworkID) (addressManager *AddressManager, teardown func()) {
func newAddrManagerForTest(t *testing.T, testName string) (addressManager *AddressManager, teardown func()) {
cfg := config.DefaultConfig()
cfg.SubnetworkID = localSubnetworkID
dbPath, err := ioutil.TempDir("", testName)
if err != nil {
t.Fatalf("Error creating temporary directory: %s", err)
}
databaseContext, err := dbaccess.New(dbPath)
if err != nil {
t.Fatalf("error creating db: %s", err)
}
addressManager, err = New(cfg, databaseContext)
addressManager, err := New(NewConfig(cfg))
if err != nil {
t.Fatalf("error creating address manager: %s", err)
}
return addressManager, func() {
err := databaseContext.Close()
if err != nil {
t.Fatalf("error closing the database: %s", err)
}
}
}
func TestStartStop(t *testing.T) {
amgr, teardown := newAddrManagerForTest(t, "TestStartStop", nil)
defer teardown()
err := amgr.Start()
if err != nil {
t.Fatalf("Address Manager failed to start: %v", err)
}
err = amgr.Stop()
if err != nil {
t.Fatalf("Address Manager failed to stop: %v", err)
}
}
func TestAddAddressByIP(t *testing.T) {
fmtErr := errors.Errorf("")
addrErr := &net.AddrError{}
var tests = []struct {
addrIP string
err error
}{
{
someIP + ":16111",
nil,
},
{
someIP,
addrErr,
},
{
someIP[:12] + ":8333",
fmtErr,
},
{
someIP + ":abcd",
fmtErr,
},
}
amgr, teardown := newAddrManagerForTest(t, "TestAddAddressByIP", nil)
defer teardown()
for i, test := range tests {
err := AddAddressByIP(amgr, test.addrIP, nil)
if test.err != nil && err == nil {
t.Errorf("TestAddAddressByIP test %d failed expected an error and got none", i)
continue
}
if test.err == nil && err != nil {
t.Errorf("TestAddAddressByIP test %d failed expected no error and got one", i)
continue
}
if reflect.TypeOf(err) != reflect.TypeOf(test.err) {
t.Errorf("TestAddAddressByIP test %d failed got %v, want %v", i,
reflect.TypeOf(err), reflect.TypeOf(test.err))
continue
}
}
}
func TestAddLocalAddress(t *testing.T) {
var tests = []struct {
address appmessage.NetAddress
priority AddressPriority
valid bool
}{
{
appmessage.NetAddress{IP: net.ParseIP("192.168.0.100")},
InterfacePrio,
false,
},
{
appmessage.NetAddress{IP: net.ParseIP("204.124.1.1")},
InterfacePrio,
true,
},
{
appmessage.NetAddress{IP: net.ParseIP("204.124.1.1")},
BoundPrio,
true,
},
{
appmessage.NetAddress{IP: net.ParseIP("::1")},
InterfacePrio,
false,
},
{
appmessage.NetAddress{IP: net.ParseIP("fe80::1")},
InterfacePrio,
false,
},
{
appmessage.NetAddress{IP: net.ParseIP("2620:100::1")},
InterfacePrio,
true,
},
}
amgr, teardown := newAddrManagerForTest(t, "TestAddLocalAddress", nil)
defer teardown()
for x, test := range tests {
result := amgr.AddLocalAddress(&test.address, test.priority)
if result == nil && !test.valid {
t.Errorf("TestAddLocalAddress test #%d failed: %s should have "+
"been accepted", x, test.address.IP)
continue
}
if result != nil && test.valid {
t.Errorf("TestAddLocalAddress test #%d failed: %s should not have "+
"been accepted", x, test.address.IP)
continue
}
}
}
func TestAttempt(t *testing.T) {
amgr, teardown := newAddrManagerForTest(t, "TestAttempt", nil)
defer teardown()
// Add a new address and get it
err := AddAddressByIP(amgr, someIP+":8333", nil)
if err != nil {
t.Fatalf("Adding address failed: %v", err)
}
ka := amgr.GetAddress()
if !ka.LastAttempt().IsZero() {
t.Errorf("Address should not have attempts, but does")
}
na := ka.NetAddress()
amgr.Attempt(na)
if ka.LastAttempt().IsZero() {
t.Errorf("Address should have an attempt, but does not")
}
}
func TestConnected(t *testing.T) {
amgr, teardown := newAddrManagerForTest(t, "TestConnected", nil)
defer teardown()
// Add a new address and get it
err := AddAddressByIP(amgr, someIP+":8333", nil)
if err != nil {
t.Fatalf("Adding address failed: %v", err)
}
ka := amgr.GetAddress()
na := ka.NetAddress()
// make it an hour ago
na.Timestamp = mstime.Now().Add(time.Hour * -1)
amgr.Connected(na)
if !ka.NetAddress().Timestamp.After(na.Timestamp) {
t.Errorf("Address should have a new timestamp, but does not")
}
}
func TestNeedMoreAddresses(t *testing.T) {
amgr, teardown := newAddrManagerForTest(t, "TestNeedMoreAddresses", nil)
defer teardown()
addrsToAdd := 1500
b := amgr.NeedMoreAddresses()
if !b {
t.Errorf("Expected that we need more addresses")
}
addrs := make([]*appmessage.NetAddress, addrsToAdd)
var err error
for i := 0; i < addrsToAdd; i++ {
s := AddressKey(fmt.Sprintf("%d.%d.173.147:8333", i/128+60, i%128+60))
addrs[i], err = amgr.DeserializeNetAddress(s)
if err != nil {
t.Errorf("Failed to turn %s into an address: %v", s, err)
}
}
srcAddr := appmessage.NewNetAddressIPPort(net.IPv4(173, 144, 173, 111), 8333, 0)
amgr.AddAddresses(addrs, srcAddr, nil)
numAddrs := amgr.TotalNumAddresses()
if numAddrs > addrsToAdd {
t.Errorf("Number of addresses is too many %d vs %d", numAddrs, addrsToAdd)
}
b = amgr.NeedMoreAddresses()
if b {
t.Errorf("Expected that we don't need more addresses")
}
}
func TestGood(t *testing.T) {
amgr, teardown := newAddrManagerForTest(t, "TestGood", nil)
defer teardown()
addrsToAdd := 64 * 64
addrs := make([]*appmessage.NetAddress, addrsToAdd)
subnetworkCount := 32
subnetworkIDs := make([]*subnetworkid.SubnetworkID, subnetworkCount)
var err error
for i := 0; i < addrsToAdd; i++ {
s := AddressKey(fmt.Sprintf("%d.173.147.%d:8333", i/64+60, i%64+60))
addrs[i], err = amgr.DeserializeNetAddress(s)
if err != nil {
t.Errorf("Failed to turn %s into an address: %v", s, err)
}
}
for i := 0; i < subnetworkCount; i++ {
subnetworkIDs[i] = &subnetworkid.SubnetworkID{0xff - byte(i)}
}
srcAddr := appmessage.NewNetAddressIPPort(net.IPv4(173, 144, 173, 111), 8333, 0)
amgr.AddAddresses(addrs, srcAddr, nil)
for i, addr := range addrs {
amgr.Good(addr, subnetworkIDs[i%subnetworkCount])
}
numAddrs := amgr.TotalNumAddresses()
if numAddrs >= addrsToAdd {
t.Errorf("Number of addresses is too many: %d vs %d", numAddrs, addrsToAdd)
}
numCache := len(amgr.AddressCache(true, nil))
if numCache == 0 || numCache >= numAddrs/4 {
t.Errorf("Number of addresses in cache: got %d, want positive and less than %d",
numCache, numAddrs/4)
}
for i := 0; i < subnetworkCount; i++ {
numCache = len(amgr.AddressCache(false, subnetworkIDs[i]))
if numCache == 0 || numCache >= numAddrs/subnetworkCount {
t.Errorf("Number of addresses in subnetwork cache: got %d, want positive and less than %d",
numCache, numAddrs/4/subnetworkCount)
}
}
}
func TestGoodChangeSubnetworkID(t *testing.T) {
amgr, teardown := newAddrManagerForTest(t, "TestGoodChangeSubnetworkID", nil)
defer teardown()
addr := appmessage.NewNetAddressIPPort(net.IPv4(173, 144, 173, 111), 8333, 0)
addrKey := NetAddressKey(addr)
srcAddr := appmessage.NewNetAddressIPPort(net.IPv4(173, 144, 173, 111), 8333, 0)
oldSubnetwork := subnetworkid.SubnetworkIDNative
amgr.AddAddress(addr, srcAddr, oldSubnetwork)
amgr.Good(addr, oldSubnetwork)
// make sure address was saved to addressIndex under oldSubnetwork
ka := amgr.knownAddress(addr)
if ka == nil {
t.Fatalf("Address was not found after first time .Good called")
}
if !ka.SubnetworkID().IsEqual(oldSubnetwork) {
t.Fatalf("Address index did not point to oldSubnetwork")
}
// make sure address was added to correct bucket under oldSubnetwork
bucket := amgr.subnetworkTriedAddresBucketArrays[*oldSubnetwork][amgr.triedAddressBucketIndex(addr)]
wasFound := false
for _, ka := range bucket {
if NetAddressKey(ka.NetAddress()) == addrKey {
wasFound = true
}
}
if !wasFound {
t.Fatalf("Address was not found in the correct bucket in oldSubnetwork")
}
// now call .Good again with a different subnetwork
newSubnetwork := subnetworkid.SubnetworkIDRegistry
amgr.Good(addr, newSubnetwork)
// make sure address was updated in addressIndex under newSubnetwork
ka = amgr.knownAddress(addr)
if ka == nil {
t.Fatalf("Address was not found after second time .Good called")
}
if !ka.SubnetworkID().IsEqual(newSubnetwork) {
t.Fatalf("Address index did not point to newSubnetwork")
}
// make sure address was removed from bucket under oldSubnetwork
bucket = amgr.subnetworkTriedAddresBucketArrays[*oldSubnetwork][amgr.triedAddressBucketIndex(addr)]
wasFound = false
for _, ka := range bucket {
if NetAddressKey(ka.NetAddress()) == addrKey {
wasFound = true
}
}
if wasFound {
t.Fatalf("Address was not removed from bucket in oldSubnetwork")
}
// make sure address was added to correct bucket under newSubnetwork
bucket = amgr.subnetworkTriedAddresBucketArrays[*newSubnetwork][amgr.triedAddressBucketIndex(addr)]
wasFound = false
for _, ka := range bucket {
if NetAddressKey(ka.NetAddress()) == addrKey {
wasFound = true
}
}
if !wasFound {
t.Fatalf("Address was not found in the correct bucket in newSubnetwork")
}
}
func TestGetAddress(t *testing.T) {
localSubnetworkID := &subnetworkid.SubnetworkID{0xff}
amgr, teardown := newAddrManagerForTest(t, "TestGetAddress", localSubnetworkID)
defer teardown()
// Get an address from an empty set (should error)
if rv := amgr.GetAddress(); rv != nil {
t.Errorf("GetAddress failed: got: %v want: %v\n", rv, nil)
}
// Add a new address and get it
err := AddAddressByIP(amgr, someIP+":8332", localSubnetworkID)
if err != nil {
t.Fatalf("Adding address failed: %v", err)
}
ka := amgr.GetAddress()
if ka == nil {
t.Fatalf("Did not get an address where there is one in the pool")
}
amgr.Attempt(ka.NetAddress())
// Checks that we don't get it if it has a different subnetwork ID than expected.
actualSubnetworkID := &subnetworkid.SubnetworkID{0xfe}
amgr.Good(ka.NetAddress(), actualSubnetworkID)
ka = amgr.GetAddress()
if ka != nil {
t.Errorf("Didn't expect to get an address because there shouldn't be any address from subnetwork ID %s or nil", localSubnetworkID)
}
// Checks that the total number of addresses was incremented, even though the new address is not a full node or a partial node of the same subnetwork as the local node.
numAddrs := amgr.TotalNumAddresses()
if numAddrs != 1 {
t.Errorf("Wrong number of addresses: got %d, want %d", numAddrs, 1)
}
// Now we repeat the same process, but now the address has the expected subnetwork ID.
// Add a new address and get it
err = AddAddressByIP(amgr, someIP+":8333", localSubnetworkID)
if err != nil {
t.Fatalf("Adding address failed: %v", err)
}
ka = amgr.GetAddress()
if ka == nil {
t.Fatalf("Did not get an address where there is one in the pool")
}
if ka.NetAddress().IP.String() != someIP {
t.Errorf("Wrong IP: got %v, want %v", ka.NetAddress().IP.String(), someIP)
}
if !ka.SubnetworkID().IsEqual(localSubnetworkID) {
t.Errorf("Wrong Subnetwork ID: got %v, want %v", *ka.SubnetworkID(), localSubnetworkID)
}
amgr.Attempt(ka.NetAddress())
// Mark this as a good address and get it
amgr.Good(ka.NetAddress(), localSubnetworkID)
ka = amgr.GetAddress()
if ka == nil {
t.Fatalf("Did not get an address where there is one in the pool")
}
if ka.NetAddress().IP.String() != someIP {
t.Errorf("Wrong IP: got %v, want %v", ka.NetAddress().IP.String(), someIP)
}
if *ka.SubnetworkID() != *localSubnetworkID {
t.Errorf("Wrong Subnetwork ID: got %v, want %v", ka.SubnetworkID(), localSubnetworkID)
}
numAddrs = amgr.TotalNumAddresses()
if numAddrs != 2 {
t.Errorf("Wrong number of addresses: got %d, want %d", numAddrs, 1)
}
}
func TestGetBestLocalAddress(t *testing.T) {
func TestBestLocalAddress(t *testing.T) {
localAddrs := []appmessage.NetAddress{
{IP: net.ParseIP("192.168.0.100")},
{IP: net.ParseIP("::1")},
@@ -557,12 +66,12 @@ func TestGetBestLocalAddress(t *testing.T) {
},
}
amgr, teardown := newAddrManagerForTest(t, "TestGetBestLocalAddress", nil)
amgr, teardown := newAddrManagerForTest(t, "TestGetBestLocalAddress")
defer teardown()
// Test against default when there's no address
for x, test := range tests {
got := amgr.GetBestLocalAddress(&test.remoteAddr)
got := amgr.BestLocalAddress(&test.remoteAddr)
if !test.want0.IP.Equal(got.IP) {
t.Errorf("TestGetBestLocalAddress test1 #%d failed for remote address %s: want %s got %s",
x, test.remoteAddr.IP, test.want1.IP, got.IP)
@@ -571,12 +80,12 @@ func TestGetBestLocalAddress(t *testing.T) {
}
for _, localAddr := range localAddrs {
amgr.AddLocalAddress(&localAddr, InterfacePrio)
amgr.localAddresses.addLocalNetAddress(&localAddr, InterfacePrio)
}
// Test against want1
for x, test := range tests {
got := amgr.GetBestLocalAddress(&test.remoteAddr)
got := amgr.BestLocalAddress(&test.remoteAddr)
if !test.want1.IP.Equal(got.IP) {
t.Errorf("TestGetBestLocalAddress test1 #%d failed for remote address %s: want %s got %s",
x, test.remoteAddr.IP, test.want1.IP, got.IP)
@@ -586,42 +95,15 @@ func TestGetBestLocalAddress(t *testing.T) {
// Add a public IP to the list of local addresses.
localAddr := appmessage.NetAddress{IP: net.ParseIP("204.124.8.100")}
amgr.AddLocalAddress(&localAddr, InterfacePrio)
amgr.localAddresses.addLocalNetAddress(&localAddr, InterfacePrio)
// Test against want2
for x, test := range tests {
got := amgr.GetBestLocalAddress(&test.remoteAddr)
got := amgr.BestLocalAddress(&test.remoteAddr)
if !test.want2.IP.Equal(got.IP) {
t.Errorf("TestGetBestLocalAddress test2 #%d failed for remote address %s: want %s got %s",
x, test.remoteAddr.IP, test.want2.IP, got.IP)
continue
}
}
/*
// Add a Tor generated IP address
localAddr = appmessage.NetAddress{IP: net.ParseIP("fd87:d87e:eb43:25::1")}
amgr.AddLocalAddress(&localAddr, ManualPrio)
// Test against want3
for x, test := range tests {
got := amgr.GetBestLocalAddress(&test.remoteAddr)
if !test.want3.IP.Equal(got.IP) {
t.Errorf("TestGetBestLocalAddress test3 #%d failed for remote address %s: want %s got %s",
x, test.remoteAddr.IP, test.want3.IP, got.IP)
continue
}
}
*/
}
func TestNetAddressKey(t *testing.T) {
addNaTests()
t.Logf("Running %d tests", len(naTests))
for i, test := range naTests {
key := NetAddressKey(&test.in)
if key != test.want {
t.Errorf("NetAddressKey #%d\n got: %s want: %s", i, key, test.want)
continue
}
}
}


@@ -0,0 +1,43 @@
package addressmanager
import (
"math/rand"
"time"
"github.com/kaspanet/kaspad/app/appmessage"
)
// AddressRandomize implements the AddressRandomizer interface.
type AddressRandomize struct {
random *rand.Rand
}
// NewAddressRandomize returns a new AddressRandomize.
func NewAddressRandomize() *AddressRandomize {
return &AddressRandomize{
random: rand.New(rand.NewSource(time.Now().UnixNano())),
}
}
// RandomAddress returns a random address from the input list,
// or nil if the list is empty.
func (amc *AddressRandomize) RandomAddress(addresses []*appmessage.NetAddress) *appmessage.NetAddress {
	if len(addresses) > 0 {
		// Use the seeded generator held by this struct; calling the
		// package-level rand would leave the random field unused.
		randomIndex := amc.random.Intn(len(addresses))
		return addresses[randomIndex]
	}
	return nil
}
// RandomAddresses returns up to count addresses, sampled at random and
// without repetition from the input list.
func (amc *AddressRandomize) RandomAddresses(addresses []*appmessage.NetAddress, count int) []*appmessage.NetAddress {
	// Clamp count so indexing into the permutation cannot go out of
	// range when fewer addresses are available than requested.
	if count > len(addresses) {
		count = len(addresses)
	}
	result := make([]*appmessage.NetAddress, 0, count)
	if len(addresses) > 0 {
		randomIndexes := amc.random.Perm(len(addresses))
		for i := 0; i < count; i++ {
			result = append(result, addresses[randomIndexes[i]])
		}
	}
	return result
}
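Further down, the connmanager diff calls c.addressManager.RandomAddresses(connectionsNeededCount, connectedAddresses); presumably the AddressManager filters out already-connected peers before delegating to this randomizer, so dial candidates are drawn without repeats.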


@@ -0,0 +1,27 @@
package addressmanager
import (
"net"
"github.com/kaspanet/kaspad/infrastructure/config"
)
// Config is a descriptor which specifies the AddressManager instance configuration.
type Config struct {
AcceptUnroutable bool
DefaultPort string
ExternalIPs []string
Listeners []string
Lookup func(string) ([]net.IP, error)
}
// NewConfig returns a new address manager Config.
func NewConfig(cfg *config.Config) *Config {
return &Config{
AcceptUnroutable: cfg.NetParams().AcceptUnroutable,
DefaultPort: cfg.NetParams().DefaultPort,
ExternalIPs: cfg.ExternalIPs,
Listeners: cfg.Listeners,
Lookup: cfg.Lookup,
}
}
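NewConfig copies out only the fields the address manager actually needs, which is what lets the refactored test helper above construct one without a database context. A minimal wiring sketch, mirroring that helper:

	package example

	import (
		"github.com/kaspanet/kaspad/infrastructure/config"
		"github.com/kaspanet/kaspad/infrastructure/network/addressmanager"
	)

	// newAddressManager wires the slimmed-down Config into New, exactly as the
	// refactored test helper above does.
	func newAddressManager() (*addressmanager.AddressManager, error) {
		cfg := config.DefaultConfig()
		return addressmanager.New(addressmanager.NewConfig(cfg))
	}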


@@ -1,24 +0,0 @@
// Copyright (c) 2013-2015 The btcsuite developers
// Use of this source code is governed by an ISC
// license that can be found in the LICENSE file.
package addressmanager
import (
"github.com/kaspanet/kaspad/app/appmessage"
"github.com/kaspanet/kaspad/util/mstime"
)
func TstKnownAddressIsBad(ka *KnownAddress) bool {
return ka.isBad()
}
func TstKnownAddressChance(ka *KnownAddress) float64 {
return ka.chance()
}
func TstNewKnownAddress(na *appmessage.NetAddress, attempts int,
lastattempt, lastsuccess mstime.Time, tried bool, refs int) *KnownAddress {
return &KnownAddress{netAddress: na, attempts: attempts, lastAttempt: lastattempt,
lastSuccess: lastsuccess, tried: tried, referenceCount: refs}
}


@@ -1,107 +0,0 @@
// Copyright (c) 2013-2014 The btcsuite developers
// Use of this source code is governed by an ISC
// license that can be found in the LICENSE file.
package addressmanager
import (
"github.com/kaspanet/kaspad/app/appmessage"
"github.com/kaspanet/kaspad/util/mstime"
"time"
"github.com/kaspanet/kaspad/util/subnetworkid"
)
// KnownAddress tracks information about a known network address that is used
// to determine how viable an address is.
type KnownAddress struct {
netAddress *appmessage.NetAddress
sourceAddress *appmessage.NetAddress
attempts int
lastAttempt mstime.Time
lastSuccess mstime.Time
tried bool
referenceCount int // reference count of new buckets
subnetworkID *subnetworkid.SubnetworkID
isBanned bool
bannedTime mstime.Time
}
// NetAddress returns the underlying appmessage.NetAddress associated with the
// known address.
func (ka *KnownAddress) NetAddress() *appmessage.NetAddress {
return ka.netAddress
}
// SubnetworkID returns the subnetwork ID of the known address.
func (ka *KnownAddress) SubnetworkID() *subnetworkid.SubnetworkID {
return ka.subnetworkID
}
// LastAttempt returns the last time the known address was attempted.
func (ka *KnownAddress) LastAttempt() mstime.Time {
return ka.lastAttempt
}
// chance returns the selection probability for a known address. The priority
// depends upon how recently the address has been seen, how recently it was last
// attempted and how often attempts to connect to it have failed.
func (ka *KnownAddress) chance() float64 {
now := mstime.Now()
lastAttempt := now.Sub(ka.lastAttempt)
if lastAttempt < 0 {
lastAttempt = 0
}
c := 1.0
// Very recent attempts are less likely to be retried.
if lastAttempt < 10*time.Minute {
c *= 0.01
}
// Failed attempts deprioritise.
for i := ka.attempts; i > 0; i-- {
c /= 1.5
}
return c
}
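Worked example: an address with two failed attempts whose last attempt was 30 minutes ago skips the recent-attempt penalty, so its chance is 1/1.5/1.5 ≈ 0.444, matching the expectation encoded in the TestChance cases further down.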
// isBad returns true if the address in question has not been tried in the last
// minute and meets one of the following criteria:
// 1) It claims to be from the future
// 2) It hasn't been seen in over a month
// 3) It has failed at least three times and never succeeded
// 4) It has failed ten times in the last week
// All addresses that meet these criteria are assumed to be worthless and not
// worth keeping hold of.
func (ka *KnownAddress) isBad() bool {
if ka.lastAttempt.After(mstime.Now().Add(-1 * time.Minute)) {
return false
}
// From the future?
if ka.netAddress.Timestamp.After(mstime.Now().Add(10 * time.Minute)) {
return true
}
// Over a month old?
if ka.netAddress.Timestamp.Before(mstime.Now().Add(-1 * numMissingDays * time.Hour * 24)) {
return true
}
// Never succeeded?
if ka.lastSuccess.IsZero() && ka.attempts >= numRetries {
return true
}
// Hasn't succeeded in too long?
if !ka.lastSuccess.After(mstime.Now().Add(-1*minBadDays*time.Hour*24)) &&
ka.attempts >= maxFailures {
return true
}
return false
}


@@ -1,115 +0,0 @@
// Copyright (c) 2013-2015 The btcsuite developers
// Use of this source code is governed by an ISC
// license that can be found in the LICENSE file.
package addressmanager_test
import (
"github.com/kaspanet/kaspad/app/appmessage"
"github.com/kaspanet/kaspad/util/mstime"
"math"
"testing"
"time"
"github.com/kaspanet/kaspad/infrastructure/network/addressmanager"
)
func TestChance(t *testing.T) {
now := mstime.Now()
var tests = []struct {
addr *addressmanager.KnownAddress
expected float64
}{
{
//Test normal case
addressmanager.TstNewKnownAddress(&appmessage.NetAddress{Timestamp: now.Add(-35 * time.Second)},
0, mstime.Now().Add(-30*time.Minute), mstime.Now(), false, 0),
1.0,
}, {
//Test case in which lastseen < 0
addressmanager.TstNewKnownAddress(&appmessage.NetAddress{Timestamp: now.Add(20 * time.Second)},
0, mstime.Now().Add(-30*time.Minute), mstime.Now(), false, 0),
1.0,
}, {
//Test case in which lastAttempt < 0
addressmanager.TstNewKnownAddress(&appmessage.NetAddress{Timestamp: now.Add(-35 * time.Second)},
0, mstime.Now().Add(30*time.Minute), mstime.Now(), false, 0),
1.0 * .01,
}, {
//Test case in which lastAttempt < ten minutes
addressmanager.TstNewKnownAddress(&appmessage.NetAddress{Timestamp: now.Add(-35 * time.Second)},
0, mstime.Now().Add(-5*time.Minute), mstime.Now(), false, 0),
1.0 * .01,
}, {
//Test case with several failed attempts.
addressmanager.TstNewKnownAddress(&appmessage.NetAddress{Timestamp: now.Add(-35 * time.Second)},
2, mstime.Now().Add(-30*time.Minute), mstime.Now(), false, 0),
1 / 1.5 / 1.5,
},
}
err := .0001
for i, test := range tests {
chance := addressmanager.TstKnownAddressChance(test.addr)
if math.Abs(test.expected-chance) >= err {
t.Errorf("case %d: got %f, expected %f", i, chance, test.expected)
}
}
}
func TestIsBad(t *testing.T) {
now := mstime.Now()
future := now.Add(35 * time.Minute)
monthOld := now.Add(-43 * time.Hour * 24)
secondsOld := now.Add(-2 * time.Second)
minutesOld := now.Add(-27 * time.Minute)
hoursOld := now.Add(-5 * time.Hour)
zeroTime := mstime.Time{}
futureNa := &appmessage.NetAddress{Timestamp: future}
minutesOldNa := &appmessage.NetAddress{Timestamp: minutesOld}
monthOldNa := &appmessage.NetAddress{Timestamp: monthOld}
currentNa := &appmessage.NetAddress{Timestamp: secondsOld}
//Test addresses that have been tried in the last minute.
if addressmanager.TstKnownAddressIsBad(addressmanager.TstNewKnownAddress(futureNa, 3, secondsOld, zeroTime, false, 0)) {
t.Errorf("test case 1: addresses that have been tried in the last minute are not bad.")
}
if addressmanager.TstKnownAddressIsBad(addressmanager.TstNewKnownAddress(monthOldNa, 3, secondsOld, zeroTime, false, 0)) {
t.Errorf("test case 2: addresses that have been tried in the last minute are not bad.")
}
if addressmanager.TstKnownAddressIsBad(addressmanager.TstNewKnownAddress(currentNa, 3, secondsOld, zeroTime, false, 0)) {
t.Errorf("test case 3: addresses that have been tried in the last minute are not bad.")
}
if addressmanager.TstKnownAddressIsBad(addressmanager.TstNewKnownAddress(currentNa, 3, secondsOld, monthOld, true, 0)) {
t.Errorf("test case 4: addresses that have been tried in the last minute are not bad.")
}
if addressmanager.TstKnownAddressIsBad(addressmanager.TstNewKnownAddress(currentNa, 2, secondsOld, secondsOld, true, 0)) {
t.Errorf("test case 5: addresses that have been tried in the last minute are not bad.")
}
//Test address that claims to be from the future.
if !addressmanager.TstKnownAddressIsBad(addressmanager.TstNewKnownAddress(futureNa, 0, minutesOld, hoursOld, true, 0)) {
t.Errorf("test case 6: addresses that claim to be from the future are bad.")
}
//Test address that has not been seen in over a month.
if !addressmanager.TstKnownAddressIsBad(addressmanager.TstNewKnownAddress(monthOldNa, 0, minutesOld, hoursOld, true, 0)) {
t.Errorf("test case 7: addresses more than a month old are bad.")
}
//It has failed at least three times and never succeeded.
if !addressmanager.TstKnownAddressIsBad(addressmanager.TstNewKnownAddress(minutesOldNa, 3, minutesOld, zeroTime, true, 0)) {
t.Errorf("test case 8: addresses that have never succeeded are bad.")
}
//It has failed ten times in the last week
if !addressmanager.TstKnownAddressIsBad(addressmanager.TstNewKnownAddress(minutesOldNa, 10, minutesOld, monthOld, true, 0)) {
t.Errorf("test case 9: addresses that have not succeeded in too long are bad.")
}
//Test an address that should work.
if addressmanager.TstKnownAddressIsBad(addressmanager.TstNewKnownAddress(minutesOldNa, 2, minutesOld, hoursOld, true, 0)) {
t.Errorf("test case 10: This should be a valid address.")
}
}


@@ -0,0 +1,400 @@
package addressmanager
import (
"net"
"runtime"
"strconv"
"strings"
"sync"
"github.com/kaspanet/kaspad/app/appmessage"
"github.com/pkg/errors"
)
// AddressPriority type is used to describe the hierarchy of local address
// discovery methods.
type AddressPriority int
const (
// InterfacePrio signifies the address is on a local interface
InterfacePrio AddressPriority = iota
// BoundPrio signifies the address has been explicitly bound to.
BoundPrio
// UpnpPrio signifies the address was obtained from UPnP.
UpnpPrio
// HTTPPrio signifies the address was obtained from an external HTTP service.
HTTPPrio
// ManualPrio signifies the address was provided by --externalip.
ManualPrio
)
type localAddress struct {
netAddress *appmessage.NetAddress
score AddressPriority
}
type localAddressManager struct {
localAddresses map[AddressKey]*localAddress
lookupFunc func(string) ([]net.IP, error)
cfg *Config
mutex sync.Mutex
}
func newLocalAddressManager(cfg *Config) (*localAddressManager, error) {
localAddressManager := localAddressManager{
localAddresses: map[AddressKey]*localAddress{},
cfg: cfg,
lookupFunc: cfg.Lookup,
}
err := localAddressManager.initListeners()
if err != nil {
return nil, err
}
return &localAddressManager, nil
}
// addLocalNetAddress adds netAddress to the list of known local addresses to advertise
// with the given priority.
func (lam *localAddressManager) addLocalNetAddress(netAddress *appmessage.NetAddress, priority AddressPriority) error {
if !IsRoutable(netAddress, lam.cfg.AcceptUnroutable) {
return errors.Errorf("address %s is not routable", netAddress.IP)
}
lam.mutex.Lock()
defer lam.mutex.Unlock()
addressKey := netAddressKey(netAddress)
address, ok := lam.localAddresses[addressKey]
if !ok || address.score < priority {
if ok {
address.score = priority + 1
} else {
lam.localAddresses[addressKey] = &localAddress{
netAddress: netAddress,
score: priority,
}
}
}
return nil
}
// bestLocalAddress returns the most appropriate local address to use
// for the given remote address.
func (lam *localAddressManager) bestLocalAddress(remoteAddress *appmessage.NetAddress) *appmessage.NetAddress {
lam.mutex.Lock()
defer lam.mutex.Unlock()
bestReach := 0
var bestScore AddressPriority
var bestAddress *appmessage.NetAddress
for _, localAddress := range lam.localAddresses {
reach := reachabilityFrom(localAddress.netAddress, remoteAddress, lam.cfg.AcceptUnroutable)
if reach > bestReach ||
(reach == bestReach && localAddress.score > bestScore) {
bestReach = reach
bestScore = localAddress.score
bestAddress = localAddress.netAddress
}
}
if bestAddress == nil {
// Send something unroutable if nothing suitable.
var ip net.IP
if !IsIPv4(remoteAddress) {
ip = net.IPv6zero
} else {
ip = net.IPv4zero
}
services := appmessage.SFNodeNetwork | appmessage.SFNodeBloom
bestAddress = appmessage.NewNetAddressIPPort(ip, 0, services)
}
return bestAddress
}
// addLocalAddress adds an address that this node is listening on to the
// address manager so that it may be relayed to peers.
func (lam *localAddressManager) addLocalAddress(addr string) error {
host, portStr, err := net.SplitHostPort(addr)
if err != nil {
return err
}
port, err := strconv.ParseUint(portStr, 10, 16)
if err != nil {
return err
}
if ip := net.ParseIP(host); ip != nil && ip.IsUnspecified() {
// If bound to unspecified address, advertise all local interfaces
addrs, err := net.InterfaceAddrs()
if err != nil {
return err
}
for _, addr := range addrs {
ifaceIP, _, err := net.ParseCIDR(addr.String())
if err != nil {
continue
}
// If bound to 0.0.0.0, do not add IPv6 interfaces and if bound to
// ::, do not add IPv4 interfaces.
if (ip.To4() == nil) != (ifaceIP.To4() == nil) {
continue
}
netAddr := appmessage.NewNetAddressIPPort(ifaceIP, uint16(port), appmessage.DefaultServices)
lam.addLocalNetAddress(netAddr, BoundPrio)
}
} else {
netAddr, err := lam.hostToNetAddress(host, uint16(port), appmessage.DefaultServices)
if err != nil {
return err
}
lam.addLocalNetAddress(netAddr, BoundPrio)
}
return nil
}
// initListeners initializes the configured net listeners and adds any bound
// addresses to the address manager
func (lam *localAddressManager) initListeners() error {
if len(lam.cfg.ExternalIPs) != 0 {
defaultPort, err := strconv.ParseUint(lam.cfg.DefaultPort, 10, 16)
if err != nil {
log.Errorf("Can not parse default port %s for active DAG: %s",
lam.cfg.DefaultPort, err)
return err
}
for _, sip := range lam.cfg.ExternalIPs {
eport := uint16(defaultPort)
host, portstr, err := net.SplitHostPort(sip)
if err != nil {
// no port, use default.
host = sip
} else {
port, err := strconv.ParseUint(portstr, 10, 16)
if err != nil {
log.Warnf("Can not parse port from %s for "+
"externalip: %s", sip, err)
continue
}
eport = uint16(port)
}
na, err := lam.hostToNetAddress(host, eport, appmessage.DefaultServices)
if err != nil {
log.Warnf("Not adding %s as externalip: %s", sip, err)
continue
}
err = lam.addLocalNetAddress(na, ManualPrio)
if err != nil {
log.Warnf("Skipping specified external IP: %s", err)
}
}
} else {
// Listen for TCP connections at the configured addresses
netAddrs, err := parseListeners(lam.cfg.Listeners)
if err != nil {
return err
}
// Add bound addresses to address manager to be advertised to peers.
for _, addr := range netAddrs {
listener, err := net.Listen(addr.Network(), addr.String())
if err != nil {
log.Warnf("Can't listen on %s: %s", addr, err)
continue
}
addr := listener.Addr().String()
err = listener.Close()
if err != nil {
return err
}
err = lam.addLocalAddress(addr)
if err != nil {
log.Warnf("Skipping bound address %s: %s", addr, err)
}
}
}
return nil
}
// hostToNetAddress returns a netaddress given a host address. If
// the host is not an IP address it will be resolved.
func (lam *localAddressManager) hostToNetAddress(host string, port uint16, services appmessage.ServiceFlag) (*appmessage.NetAddress, error) {
ip := net.ParseIP(host)
if ip == nil {
ips, err := lam.lookupFunc(host)
if err != nil {
return nil, err
}
if len(ips) == 0 {
return nil, errors.Errorf("no addresses found for %s", host)
}
ip = ips[0]
}
return appmessage.NewNetAddressIPPort(ip, port, services), nil
}
// parseListeners determines whether each listen address is IPv4 and IPv6 and
// returns a slice of appropriate net.Addrs to listen on with TCP. It also
// properly detects addresses which apply to "all interfaces" and adds the
// address as both IPv4 and IPv6.
func parseListeners(addrs []string) ([]net.Addr, error) {
netAddrs := make([]net.Addr, 0, len(addrs)*2)
for _, addr := range addrs {
host, _, err := net.SplitHostPort(addr)
if err != nil {
// Shouldn't happen due to already being normalized.
return nil, err
}
// Empty host or host of * on plan9 is both IPv4 and IPv6.
if host == "" || (host == "*" && runtime.GOOS == "plan9") {
netAddrs = append(netAddrs, simpleAddr{net: "tcp4", addr: addr})
netAddrs = append(netAddrs, simpleAddr{net: "tcp6", addr: addr})
continue
}
// Strip IPv6 zone id if present since net.ParseIP does not
// handle it.
zoneIndex := strings.LastIndex(host, "%")
if zoneIndex > 0 {
host = host[:zoneIndex]
}
// Parse the IP.
ip := net.ParseIP(host)
if ip == nil {
hostAddrs, err := net.LookupHost(host)
if err != nil {
return nil, err
}
ip = net.ParseIP(hostAddrs[0])
if ip == nil {
return nil, errors.Errorf("Cannot resolve IP address for host '%s'", host)
}
}
// To4 returns nil when the IP is not an IPv4 address, so use
// this to determine the address type.
if ip.To4() == nil {
netAddrs = append(netAddrs, simpleAddr{net: "tcp6", addr: addr})
} else {
netAddrs = append(netAddrs, simpleAddr{net: "tcp4", addr: addr})
}
}
return netAddrs, nil
}
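To make the branching concrete, a small illustration (the helper is hypothetical, placed in the same package so it can reach the unexported function):

	package addressmanager

	import "net"

	// exampleParseListeners illustrates the mapping implemented above:
	//
	//	":16111"        -> tcp4 ":16111" and tcp6 ":16111" (empty host: both)
	//	"0.0.0.0:16111" -> tcp4 only (To4 succeeds)
	//	"[::1]:16111"   -> tcp6 only (To4 returns nil)
	func exampleParseListeners() ([]net.Addr, error) {
		return parseListeners([]string{":16111", "0.0.0.0:16111", "[::1]:16111"})
	}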
// reachabilityFrom returns the relative reachability of the provided local
// address to the provided remote address.
func reachabilityFrom(localAddress, remoteAddress *appmessage.NetAddress, acceptUnroutable bool) int {
const (
Unreachable = 0
Default = iota
Teredo
Ipv6Weak
Ipv4
Ipv6Strong
Private
)
IsRoutable := func(na *appmessage.NetAddress) bool {
if acceptUnroutable {
return !IsLocal(na)
}
return IsValid(na) && !(IsRFC1918(na) || IsRFC2544(na) ||
IsRFC3927(na) || IsRFC4862(na) || IsRFC3849(na) ||
IsRFC4843(na) || IsRFC5737(na) || IsRFC6598(na) ||
IsLocal(na) || (IsRFC4193(na)))
}
if !IsRoutable(remoteAddress) {
return Unreachable
}
if IsRFC4380(remoteAddress) {
if !IsRoutable(localAddress) {
return Default
}
if IsRFC4380(localAddress) {
return Teredo
}
if IsIPv4(localAddress) {
return Ipv4
}
return Ipv6Weak
}
if IsIPv4(remoteAddress) {
if IsRoutable(localAddress) && IsIPv4(localAddress) {
return Ipv4
}
return Unreachable
}
/* ipv6 */
var tunnelled bool
// Is our v6 tunnelled?
if IsRFC3964(localAddress) || IsRFC6052(localAddress) || IsRFC6145(localAddress) {
tunnelled = true
}
if !IsRoutable(localAddress) {
return Default
}
if IsRFC4380(localAddress) {
return Teredo
}
if IsIPv4(localAddress) {
return Ipv4
}
if tunnelled {
// only prioritise ipv6 if we aren't tunnelling it.
return Ipv6Weak
}
return Ipv6Strong
}
// simpleAddr implements the net.Addr interface with two struct fields
type simpleAddr struct {
net, addr string
}
// String returns the address.
//
// This is part of the net.Addr interface.
func (a simpleAddr) String() string {
return a.addr
}
// Network returns the network.
//
// This is part of the net.Addr interface.
func (a simpleAddr) Network() string {
return a.net
}
// Ensure simpleAddr implements the net.Addr interface.
var _ net.Addr = simpleAddr{}


@@ -5,8 +5,9 @@
package addressmanager
import (
"github.com/kaspanet/kaspad/app/appmessage"
"net"
"github.com/kaspanet/kaspad/app/appmessage"
)
var (
@@ -77,6 +78,13 @@ var (
heNet = ipNet("2001:470::", 32, 128)
)
const (
// GetAddressesMax is the most addresses that we will send in response
// to a getAddress (in practice the most addresses we will return from a
// call to AddressCache()).
GetAddressesMax = 2500
)
// ipNet returns a net.IPNet struct given the passed IP address string, number
// of one bits to include at the start of the mask, and the total number of bits
// for the mask.
@@ -199,8 +207,8 @@ func IsValid(na *appmessage.NetAddress) bool {
// IsRoutable returns whether or not the passed address is routable over
// the public internet. This is true as long as the address is valid and is not
// in any reserved ranges.
func (am *AddressManager) IsRoutable(na *appmessage.NetAddress) bool {
if am.cfg.NetParams().AcceptUnroutable {
func IsRoutable(na *appmessage.NetAddress, acceptUnroutable bool) bool {
if acceptUnroutable {
return !IsLocal(na)
}
@@ -218,7 +226,7 @@ func (am *AddressManager) GroupKey(na *appmessage.NetAddress) string {
if IsLocal(na) {
return "local"
}
if !am.IsRoutable(na) {
if !IsRoutable(na, am.cfg.AcceptUnroutable) {
return "unroutable"
}
if IsIPv4(na) {


@@ -5,15 +5,16 @@
package addressmanager
import (
"github.com/kaspanet/kaspad/app/appmessage"
"net"
"testing"
"github.com/kaspanet/kaspad/app/appmessage"
)
// TestIPTypes ensures the various functions which determine the type of an IP
// address based on RFCs work as intended.
func TestIPTypes(t *testing.T) {
amgr, teardown := newAddrManagerForTest(t, "TestAddAddressByIP", nil)
amgr, teardown := newAddrManagerForTest(t, "TestAddAddressByIP")
defer teardown()
type ipTest struct {
in appmessage.NetAddress
@@ -136,7 +137,7 @@ func TestIPTypes(t *testing.T) {
t.Errorf("IsValid %s\n got: %v want: %v", test.in.IP, rv, test.valid)
}
if rv := amgr.IsRoutable(&test.in); rv != test.routable {
if rv := IsRoutable(&test.in, amgr.cfg.AcceptUnroutable); rv != test.routable {
t.Errorf("IsRoutable %s\n got: %v want: %v", test.in.IP, rv, test.routable)
}
}
@@ -145,7 +146,7 @@ func TestIPTypes(t *testing.T) {
// TestGroupKey tests the GroupKey function to ensure it properly groups various
// IP addresses.
func TestGroupKey(t *testing.T) {
amgr, teardown := newAddrManagerForTest(t, "TestAddAddressByIP", nil)
amgr, teardown := newAddrManagerForTest(t, "TestAddAddressByIP")
defer teardown()
tests := []struct {


@@ -1,11 +1,12 @@
package addressmanager
import (
"net"
"strconv"
"github.com/kaspanet/kaspad/app/appmessage"
"github.com/kaspanet/kaspad/util/subnetworkid"
"github.com/pkg/errors"
"net"
"strconv"
)
// AddAddressByIP adds an address where we are given an ip:port and not a
@@ -26,6 +27,6 @@ func AddAddressByIP(am *AddressManager, addressIP string, subnetworkID *subnetwo
return errors.Errorf("invalid port %s: %s", portString, err)
}
netAddress := appmessage.NewNetAddressIPPort(ip, uint16(port), 0)
am.AddAddress(netAddress, netAddress, subnetworkID)
am.AddAddresses(netAddress)
return nil
}


@@ -1,5 +1,7 @@
package connmanager
import "github.com/kaspanet/kaspad/app/appmessage"
// checkOutgoingConnections goes over all activeOutgoing and makes sure they are still active.
// Then it opens connections so that we have targetOutgoing active connections
func (c *ConnectionManager) checkOutgoingConnections(connSet connectionSet) {
@@ -14,6 +16,12 @@ func (c *ConnectionManager) checkOutgoingConnections(connSet connectionSet) {
delete(c.activeOutgoing, address)
}
connections := c.netAdapter.P2PConnections()
connectedAddresses := make([]*appmessage.NetAddress, len(connections))
for i, connection := range connections {
connectedAddresses[i] = connection.NetAddress()
}
liveConnections := len(c.activeOutgoing)
if c.targetOutgoing == liveConnections {
return
@@ -23,47 +31,21 @@ func (c *ConnectionManager) checkOutgoingConnections(connSet connectionSet) {
liveConnections, c.targetOutgoing, c.targetOutgoing-liveConnections)
connectionsNeededCount := c.targetOutgoing - len(c.activeOutgoing)
connectionAttempts := connectionsNeededCount * 2
for i := 0; i < connectionAttempts; i++ {
// Return in case we've already reached or surpassed our target
if len(c.activeOutgoing) >= c.targetOutgoing {
return
}
netAddresses := c.addressManager.RandomAddresses(connectionsNeededCount, connectedAddresses)
address := c.addressManager.GetAddress()
if address == nil {
log.Warnf("No more addresses available")
return
}
for _, netAddress := range netAddresses {
addressString := netAddress.TCPAddress().String()
netAddress := address.NetAddress()
tcpAddress := netAddress.TCPAddress()
addressString := tcpAddress.String()
if c.connectionExists(addressString) {
log.Debugf("Fetched address %s from address manager but it's already connected. Skipping...", addressString)
continue
}
isBanned, err := c.addressManager.IsBanned(netAddress)
if err != nil {
log.Infof("Couldn't resolve whether %s is banned: %s", addressString, err)
continue
}
if isBanned {
continue
}
c.addressManager.Attempt(netAddress)
log.Debugf("Connecting to %s because we have %d outgoing connections and the target is "+
"%d", addressString, len(c.activeOutgoing), c.targetOutgoing)
err = c.initiateConnection(addressString)
err := c.initiateConnection(addressString)
if err != nil {
log.Infof("Couldn't connect to %s: %s", addressString, err)
c.addressManager.RemoveAddress(netAddress)
continue
}
c.addressManager.Connected(netAddress)
c.activeOutgoing[addressString] = struct{}{}
}
}
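Because removed and added lines are interleaved in the hunk above, here is the new body of checkOutgoingConnections read back in one piece (an excerpt of the added lines, with log lines and the activeOutgoing bookkeeping elided; not standalone code):

	// Collect the addresses of currently connected peers.
	connections := c.netAdapter.P2PConnections()
	connectedAddresses := make([]*appmessage.NetAddress, len(connections))
	for i, connection := range connections {
		connectedAddresses[i] = connection.NetAddress()
	}

	// Draw random dial candidates, excluding peers we are already connected to.
	netAddresses := c.addressManager.RandomAddresses(connectionsNeededCount, connectedAddresses)
	for _, netAddress := range netAddresses {
		addressString := netAddress.TCPAddress().String()
		if c.connectionExists(addressString) {
			continue // already connected, skip
		}
		err := c.initiateConnection(addressString)
		if err != nil {
			// A failed dial now evicts the address from the manager.
			c.addressManager.RemoveAddress(netAddress)
			continue
		}
	}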


@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.25.0-devel
// protoc v3.6.1
// protoc-gen-go v1.25.0
// protoc v3.12.3
// source: peer_service.proto
package pb


@@ -17,12 +17,13 @@ type gRPCServer struct {
server *grpc.Server
}
const maxMessageSize = 1024 * 1024 * 10 // 10MB
// MaxMessageSize is the max size allowed for a message
const MaxMessageSize = 1024 * 1024 * 10 // 10MB
// newGRPCServer creates a gRPC server
func newGRPCServer(listeningAddresses []string) *gRPCServer {
return &gRPCServer{
server: grpc.NewServer(grpc.MaxRecvMsgSize(maxMessageSize), grpc.MaxSendMsgSize(maxMessageSize)),
server: grpc.NewServer(grpc.MaxRecvMsgSize(MaxMessageSize), grpc.MaxSendMsgSize(MaxMessageSize)),
listeningAddresses: listeningAddresses,
}
}
@@ -39,7 +40,7 @@ func (s *gRPCServer) Start() error {
}
}
log.Debugf("Server started with maxMessageSize %d", maxMessageSize)
log.Debugf("Server started with MaxMessageSize %d", MaxMessageSize)
return nil
}
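Exporting MaxMessageSize lets other components keep their limits in sync with the server. A hedged client-side sketch using standard grpc-go call options; the grpcserver package name and import path are assumptions, not shown in this diff:

	package example

	import (
		"google.golang.org/grpc"

		"github.com/kaspanet/kaspad/infrastructure/network/netadapter/server/grpcserver" // assumed path
	)

	// dial mirrors the server's send/receive limits on the client side.
	func dial(address string) (*grpc.ClientConn, error) {
		return grpc.Dial(address,
			grpc.WithInsecure(), // plaintext, matching a typical P2P setup; an assumption
			grpc.WithDefaultCallOptions(
				grpc.MaxCallRecvMsgSize(grpcserver.MaxMessageSize),
				grpc.MaxCallSendMsgSize(grpcserver.MaxMessageSize),
			),
		)
	}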

Some files were not shown because too many files have changed in this diff