From 92e9d7b8f316f4d4524ba5467ebfed9f2a32d901 Mon Sep 17 00:00:00 2001 From: Ori Newman Date: Tue, 12 Dec 2023 17:07:14 +0200 Subject: [PATCH] Fix type detection in RemoveInvalidTransactions --- .../processes/blockbuilder/block_builder.go | 5 +++-- domain/consensus/ruleerrors/rule_error.go | 2 +- domain/consensus/ruleerrors/rule_error_test.go | 7 ++++--- domain/miningmanager/mempool/mempool.go | 16 ++++------------ 4 files changed, 12 insertions(+), 18 deletions(-) diff --git a/domain/consensus/processes/blockbuilder/block_builder.go b/domain/consensus/processes/blockbuilder/block_builder.go index bef78afd8..88fe446af 100644 --- a/domain/consensus/processes/blockbuilder/block_builder.go +++ b/domain/consensus/processes/blockbuilder/block_builder.go @@ -131,11 +131,12 @@ func (bb *blockBuilder) validateTransactions(stagingArea *model.StagingArea, for _, transaction := range transactions { err := bb.validateTransaction(stagingArea, transaction) if err != nil { - if !errors.As(err, &ruleerrors.RuleError{}) { + ruleError := ruleerrors.RuleError{} + if !errors.As(err, &ruleError) { return err } invalidTransactions = append(invalidTransactions, - ruleerrors.InvalidTransaction{Transaction: transaction, Error: err}) + ruleerrors.InvalidTransaction{Transaction: transaction, Error: &ruleError}) } } diff --git a/domain/consensus/ruleerrors/rule_error.go b/domain/consensus/ruleerrors/rule_error.go index b3321e9ca..a46f8528a 100644 --- a/domain/consensus/ruleerrors/rule_error.go +++ b/domain/consensus/ruleerrors/rule_error.go @@ -315,7 +315,7 @@ func NewErrMissingParents(missingParentHashes []*externalapi.DomainHash) error { // InvalidTransaction is a struct containing an invalid transaction, and the error explaining why it's invalid. type InvalidTransaction struct { Transaction *externalapi.DomainTransaction - Error error + Error *RuleError } func (invalid InvalidTransaction) String() string { diff --git a/domain/consensus/ruleerrors/rule_error_test.go b/domain/consensus/ruleerrors/rule_error_test.go index bdb3dfe04..8ad035ce3 100644 --- a/domain/consensus/ruleerrors/rule_error_test.go +++ b/domain/consensus/ruleerrors/rule_error_test.go @@ -3,9 +3,10 @@ package ruleerrors import ( "errors" "fmt" - "github.com/kaspanet/kaspad/domain/consensus/utils/consensushashing" "testing" + "github.com/kaspanet/kaspad/domain/consensus/utils/consensushashing" + "github.com/kaspanet/kaspad/domain/consensus/model/externalapi" ) @@ -49,7 +50,7 @@ func TestNewErrMissingTxOut(t *testing.T) { func TestNewErrInvalidTransactionsInNewBlock(t *testing.T) { tx := &externalapi.DomainTransaction{Fee: 1337} txID := consensushashing.TransactionID(tx) - outer := NewErrInvalidTransactionsInNewBlock([]InvalidTransaction{{tx, ErrNoTxInputs}}) + outer := NewErrInvalidTransactionsInNewBlock([]InvalidTransaction{{tx, &ErrNoTxInputs}}) //TODO: Implement Stringer for `DomainTransaction` expectedOuterErr := fmt.Sprintf("ErrInvalidTransactionsInNewBlock: [(%s: ErrNoTxInputs)]", txID) inner := &ErrInvalidTransactionsInNewBlock{} @@ -60,7 +61,7 @@ func TestNewErrInvalidTransactionsInNewBlock(t *testing.T) { if len(inner.InvalidTransactions) != 1 { t.Fatalf("TestNewErrInvalidTransactionsInNewBlock: Expected len(inner.MissingOutpoints) 1, found: %d", len(inner.InvalidTransactions)) } - if inner.InvalidTransactions[0].Error != ErrNoTxInputs { + if *inner.InvalidTransactions[0].Error != ErrNoTxInputs { t.Fatalf("TestNewErrInvalidTransactionsInNewBlock: Expected ErrNoTxInputs. found: %v", inner.InvalidTransactions[0].Error) } if inner.InvalidTransactions[0].Transaction.Fee != 1337 { diff --git a/domain/miningmanager/mempool/mempool.go b/domain/miningmanager/mempool/mempool.go index e511eb83b..1261e84b1 100644 --- a/domain/miningmanager/mempool/mempool.go +++ b/domain/miningmanager/mempool/mempool.go @@ -1,10 +1,12 @@ package mempool import ( + "sync" + "github.com/kaspanet/kaspad/domain/consensus/ruleerrors" "github.com/kaspanet/kaspad/domain/consensus/utils/consensushashing" "github.com/kaspanet/kaspad/domain/consensus/utils/constants" - "sync" + "github.com/pkg/errors" "github.com/kaspanet/kaspad/domain/consensusreference" @@ -209,17 +211,7 @@ func (mp *mempool) RemoveInvalidTransactions(err *ruleerrors.ErrInvalidTransacti defer mp.mtx.Unlock() for _, tx := range err.InvalidTransactions { - ruleErr, success := tx.Error.(ruleerrors.RuleError) - if !success { - continue - } - - inner := ruleErr.Unwrap() - removeRedeemers := true - if _, ok := inner.(ruleerrors.ErrMissingTxOut); ok { - removeRedeemers = false - } - + removeRedeemers := !errors.As(tx.Error, &ruleerrors.ErrMissingTxOut{}) err := mp.removeTransaction(consensushashing.TransactionID(tx.Transaction), removeRedeemers) if err != nil { return err