Fix type detection in RemoveInvalidTransactions

This commit is contained in:
Ori Newman 2023-12-12 17:07:14 +02:00
parent 8e71f79f98
commit 92e9d7b8f3
4 changed files with 12 additions and 18 deletions

View File

@ -131,11 +131,12 @@ func (bb *blockBuilder) validateTransactions(stagingArea *model.StagingArea,
for _, transaction := range transactions { for _, transaction := range transactions {
err := bb.validateTransaction(stagingArea, transaction) err := bb.validateTransaction(stagingArea, transaction)
if err != nil { if err != nil {
if !errors.As(err, &ruleerrors.RuleError{}) { ruleError := ruleerrors.RuleError{}
if !errors.As(err, &ruleError) {
return err return err
} }
invalidTransactions = append(invalidTransactions, invalidTransactions = append(invalidTransactions,
ruleerrors.InvalidTransaction{Transaction: transaction, Error: err}) ruleerrors.InvalidTransaction{Transaction: transaction, Error: &ruleError})
} }
} }

View File

@ -315,7 +315,7 @@ func NewErrMissingParents(missingParentHashes []*externalapi.DomainHash) error {
// InvalidTransaction is a struct containing an invalid transaction, and the error explaining why it's invalid. // InvalidTransaction is a struct containing an invalid transaction, and the error explaining why it's invalid.
type InvalidTransaction struct { type InvalidTransaction struct {
Transaction *externalapi.DomainTransaction Transaction *externalapi.DomainTransaction
Error error Error *RuleError
} }
func (invalid InvalidTransaction) String() string { func (invalid InvalidTransaction) String() string {

View File

@ -3,9 +3,10 @@ package ruleerrors
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/kaspanet/kaspad/domain/consensus/utils/consensushashing"
"testing" "testing"
"github.com/kaspanet/kaspad/domain/consensus/utils/consensushashing"
"github.com/kaspanet/kaspad/domain/consensus/model/externalapi" "github.com/kaspanet/kaspad/domain/consensus/model/externalapi"
) )
@ -49,7 +50,7 @@ func TestNewErrMissingTxOut(t *testing.T) {
func TestNewErrInvalidTransactionsInNewBlock(t *testing.T) { func TestNewErrInvalidTransactionsInNewBlock(t *testing.T) {
tx := &externalapi.DomainTransaction{Fee: 1337} tx := &externalapi.DomainTransaction{Fee: 1337}
txID := consensushashing.TransactionID(tx) txID := consensushashing.TransactionID(tx)
outer := NewErrInvalidTransactionsInNewBlock([]InvalidTransaction{{tx, ErrNoTxInputs}}) outer := NewErrInvalidTransactionsInNewBlock([]InvalidTransaction{{tx, &ErrNoTxInputs}})
//TODO: Implement Stringer for `DomainTransaction` //TODO: Implement Stringer for `DomainTransaction`
expectedOuterErr := fmt.Sprintf("ErrInvalidTransactionsInNewBlock: [(%s: ErrNoTxInputs)]", txID) expectedOuterErr := fmt.Sprintf("ErrInvalidTransactionsInNewBlock: [(%s: ErrNoTxInputs)]", txID)
inner := &ErrInvalidTransactionsInNewBlock{} inner := &ErrInvalidTransactionsInNewBlock{}
@ -60,7 +61,7 @@ func TestNewErrInvalidTransactionsInNewBlock(t *testing.T) {
if len(inner.InvalidTransactions) != 1 { if len(inner.InvalidTransactions) != 1 {
t.Fatalf("TestNewErrInvalidTransactionsInNewBlock: Expected len(inner.MissingOutpoints) 1, found: %d", len(inner.InvalidTransactions)) t.Fatalf("TestNewErrInvalidTransactionsInNewBlock: Expected len(inner.MissingOutpoints) 1, found: %d", len(inner.InvalidTransactions))
} }
if inner.InvalidTransactions[0].Error != ErrNoTxInputs { if *inner.InvalidTransactions[0].Error != ErrNoTxInputs {
t.Fatalf("TestNewErrInvalidTransactionsInNewBlock: Expected ErrNoTxInputs. found: %v", inner.InvalidTransactions[0].Error) t.Fatalf("TestNewErrInvalidTransactionsInNewBlock: Expected ErrNoTxInputs. found: %v", inner.InvalidTransactions[0].Error)
} }
if inner.InvalidTransactions[0].Transaction.Fee != 1337 { if inner.InvalidTransactions[0].Transaction.Fee != 1337 {

View File

@ -1,10 +1,12 @@
package mempool package mempool
import ( import (
"sync"
"github.com/kaspanet/kaspad/domain/consensus/ruleerrors" "github.com/kaspanet/kaspad/domain/consensus/ruleerrors"
"github.com/kaspanet/kaspad/domain/consensus/utils/consensushashing" "github.com/kaspanet/kaspad/domain/consensus/utils/consensushashing"
"github.com/kaspanet/kaspad/domain/consensus/utils/constants" "github.com/kaspanet/kaspad/domain/consensus/utils/constants"
"sync" "github.com/pkg/errors"
"github.com/kaspanet/kaspad/domain/consensusreference" "github.com/kaspanet/kaspad/domain/consensusreference"
@ -209,17 +211,7 @@ func (mp *mempool) RemoveInvalidTransactions(err *ruleerrors.ErrInvalidTransacti
defer mp.mtx.Unlock() defer mp.mtx.Unlock()
for _, tx := range err.InvalidTransactions { for _, tx := range err.InvalidTransactions {
ruleErr, success := tx.Error.(ruleerrors.RuleError) removeRedeemers := !errors.As(tx.Error, &ruleerrors.ErrMissingTxOut{})
if !success {
continue
}
inner := ruleErr.Unwrap()
removeRedeemers := true
if _, ok := inner.(ruleerrors.ErrMissingTxOut); ok {
removeRedeemers = false
}
err := mp.removeTransaction(consensushashing.TransactionID(tx.Transaction), removeRedeemers) err := mp.removeTransaction(consensushashing.TransactionID(tx.Transaction), removeRedeemers)
if err != nil { if err != nil {
return err return err