[NOD-1538] Fix mempool not wrapping consensus errors and bad invalid message handling (#1082)

* [NOD-1538] Correct messages.proto.

* [NOD-1538] Fix invalid message handling.

* [NOD-1538] Fix mempool not wrapping consensus errors.

* [NOD-1538] Extract wrapping logic to a separate function.

* [NOD-1538] Extract wrapping logic to an even better separate function.
This commit is contained in:
stasatdaglabs 2020-11-16 15:29:27 +02:00 committed by Svarog
parent 83a88d9989
commit 08749deaeb
14 changed files with 457 additions and 459 deletions

View File

@ -50,15 +50,9 @@ func (m *Manager) routerInitializer(router *routerpkg.Router, netConnection *net
return
}
spawn("Manager.routerInitializer-netConnection.DequeueInvalidMessage", func() {
for {
isOpen, err := netConnection.DequeueInvalidMessage()
if !isOpen {
return
}
if atomic.AddUint32(&isStopping, 1) == 1 {
errChan <- protocolerrors.Wrap(true, err, "received bad message")
}
netConnection.SetOnInvalidMessageHandler(func(err error) {
if atomic.AddUint32(&isStopping, 1) == 1 {
errChan <- protocolerrors.Wrap(true, err, "received bad message")
}
})

View File

@ -30,6 +30,11 @@ func (e RuleError) Error() string {
return e.Err.Error()
}
// Unwrap unwraps the wrapped error
func (e RuleError) Unwrap() error {
return e.Err
}
// RejectCode represents a numeric value by which a remote peer indicates
// why a message was rejected.
type RejectCode uint8
@ -94,6 +99,12 @@ func txRuleError(c RejectCode, desc string) RuleError {
}
}
func newRuleError(err error) RuleError {
return RuleError{
Err: err,
}
}
// extractRejectCode attempts to return a relevant reject code for a given error
// by examining the error for known types. It will return true if a code
// was successfully extracted.

View File

@ -620,6 +620,9 @@ func (mp *mempool) maybeAcceptTransaction(tx *consensusexternalapi.DomainTransac
if errors.As(err, &missingOutpoints) {
return missingOutpoints.MissingOutpoints, nil, nil
}
if errors.As(err, &ruleerrors.RuleError{}) {
return nil, nil, newRuleError(err)
}
return nil, nil, err
}

View File

@ -16,7 +16,6 @@ type NetConnection struct {
connection server.Connection
id *id.ID
router *routerpkg.Router
invalidMessageChan chan error
onDisconnectedHandler server.OnDisconnectedHandler
isRouterClosed uint32
}
@ -25,9 +24,8 @@ func newNetConnection(connection server.Connection, routerInitializer RouterInit
router := routerpkg.NewRouter()
netConnection := &NetConnection{
connection: connection,
router: router,
invalidMessageChan: make(chan error),
connection: connection,
router: router,
}
netConnection.connection.SetOnDisconnectedHandler(func() {
@ -36,15 +34,9 @@ func newNetConnection(connection server.Connection, routerInitializer RouterInit
if atomic.AddUint32(&netConnection.isRouterClosed, 1) == 1 {
netConnection.router.Close()
}
close(netConnection.invalidMessageChan)
netConnection.onDisconnectedHandler()
})
netConnection.connection.SetOnInvalidMessageHandler(func(err error) {
netConnection.invalidMessageChan <- err
})
routerInitializer(router, netConnection)
return netConnection
@ -98,8 +90,7 @@ func (c *NetConnection) Disconnect() {
}
}
// DequeueInvalidMessage dequeues the next invalid message
func (c *NetConnection) DequeueInvalidMessage() (isOpen bool, err error) {
err, isOpen = <-c.invalidMessageChan
return isOpen, err
// SetOnInvalidMessageHandler sets the invalid message handler for this connection
func (c *NetConnection) SetOnInvalidMessageHandler(onInvalidMessageHandler server.OnInvalidMessageHandler) {
c.connection.SetOnInvalidMessageHandler(onInvalidMessageHandler)
}

View File

@ -67,7 +67,9 @@ func (c *gRPCConnection) receiveLoop() error {
}
message, err := protoMessage.ToAppMessage()
if err != nil {
c.onInvalidMessageHandler(err)
if c.onInvalidMessageHandler != nil {
c.onInvalidMessageHandler(err)
}
return err
}
@ -88,7 +90,9 @@ func (c *gRPCConnection) receiveLoop() error {
if errors.Is(err, routerpkg.ErrRouteClosed) {
return nil
}
c.onInvalidMessageHandler(err)
if c.onInvalidMessageHandler != nil {
c.onInvalidMessageHandler(err)
}
return err
}
}

View File

@ -58,10 +58,6 @@ func (c *gRPCConnection) Start(router *router.Router) {
panic(errors.New("onDisconnectedHandler is nil"))
}
if c.onInvalidMessageHandler == nil {
panic(errors.New("onInvalidMessageHandler is nil"))
}
c.router = router
spawn("gRPCConnection.Start-connectionLoops", func() {

View File

@ -42,11 +42,11 @@ func domainHashesToProto(hashes []*externalapi.DomainHash) []*Hash {
return protoHashes
}
func (x *TransactionID) toDomain() (*externalapi.DomainTransactionID, error) {
func (x *TransactionId) toDomain() (*externalapi.DomainTransactionID, error) {
return transactionid.FromBytes(x.Bytes)
}
func protoTransactionIDsToDomain(protoIDs []*TransactionID) ([]*externalapi.DomainTransactionID, error) {
func protoTransactionIDsToDomain(protoIDs []*TransactionId) ([]*externalapi.DomainTransactionID, error) {
txIDs := make([]*externalapi.DomainTransactionID, len(protoIDs))
for i, protoID := range protoIDs {
var err error
@ -58,32 +58,32 @@ func protoTransactionIDsToDomain(protoIDs []*TransactionID) ([]*externalapi.Doma
return txIDs, nil
}
func domainTransactionIDToProto(id *externalapi.DomainTransactionID) *TransactionID {
return &TransactionID{
func domainTransactionIDToProto(id *externalapi.DomainTransactionID) *TransactionId {
return &TransactionId{
Bytes: id[:],
}
}
func wireTransactionIDsToProto(ids []*externalapi.DomainTransactionID) []*TransactionID {
protoIDs := make([]*TransactionID, len(ids))
func wireTransactionIDsToProto(ids []*externalapi.DomainTransactionID) []*TransactionId {
protoIDs := make([]*TransactionId, len(ids))
for i, hash := range ids {
protoIDs[i] = domainTransactionIDToProto(hash)
}
return protoIDs
}
func (x *SubnetworkID) toDomain() (*externalapi.DomainSubnetworkID, error) {
func (x *SubnetworkId) toDomain() (*externalapi.DomainSubnetworkID, error) {
if x == nil {
return nil, nil
}
return subnetworks.FromBytes(x.Bytes)
}
func domainSubnetworkIDToProto(id *externalapi.DomainSubnetworkID) *SubnetworkID {
func domainSubnetworkIDToProto(id *externalapi.DomainSubnetworkID) *SubnetworkId {
if id == nil {
return nil
}
return &SubnetworkID{
return &SubnetworkId{
Bytes: id[:],
}
}

View File

@ -91,7 +91,7 @@ message KaspadMessage {
// RequestAddressesMessage start
message RequestAddressesMessage{
bool includeAllSubnetworks = 1;
SubnetworkID subnetworkID = 2;
SubnetworkId subnetworkId = 2;
}
// RequestAddressesMessage end
@ -101,13 +101,13 @@ message AddressesMessage{
}
message NetAddress{
int64 timestamp = 1;
int64 timestamp = 1;
uint64 services = 2;
bytes ip = 3;
uint32 port = 4;
}
message SubnetworkID{
message SubnetworkId{
bytes bytes = 1;
}
// AddressesMessage end
@ -118,30 +118,30 @@ message TransactionMessage{
repeated TransactionInput inputs = 2;
repeated TransactionOutput outputs = 3;
uint64 lockTime = 4;
SubnetworkID subnetworkID = 5;
SubnetworkId subnetworkId = 5;
uint64 gas = 6;
Hash payloadHash = 7;
bytes Payload = 8;
bytes payload = 8;
}
message TransactionInput{
Outpoint PreviousOutpoint = 1;
bytes SignatureScript = 2;
uint64 Sequence = 3;
Outpoint previousOutpoint = 1;
bytes signatureScript = 2;
uint64 sequence = 3;
}
message Outpoint{
TransactionID transactionID = 1;
TransactionId transactionId = 1;
uint32 index = 2;
}
message TransactionID{
message TransactionId{
bytes bytes = 1;
}
message TransactionOutput{
uint64 value = 1;
bytes ScriptPubKey = 2;
bytes scriptPubKey = 2;
}
// TransactionMessage end
@ -155,7 +155,7 @@ message BlockHeaderMessage{
int32 version = 1;
repeated Hash parentHashes = 2;
Hash hashMerkleRoot = 3;
Hash acceptedIDMerkleRoot = 4;
Hash acceptedIdMerkleRoot = 4;
Hash utxoCommitment = 5;
int64 timestamp = 6;
uint32 bits = 7;
@ -210,13 +210,13 @@ message RequestSelectedTipMessage{
// RequestTransactionsMessage start
message RequestTransactionsMessage {
repeated TransactionID ids = 1;
repeated TransactionId ids = 1;
}
// GetTransactionsMessage end
// TransactionNotFoundMessage start
message TransactionNotFoundMessage{
TransactionID id = 1;
TransactionId id = 1;
}
// TransactionsNotFoundMessage end
@ -228,7 +228,7 @@ message InvRelayBlockMessage{
// InvTransactionMessage start
message InvTransactionsMessage{
repeated TransactionID ids = 1;
repeated TransactionId ids = 1;
}
// InvTransactionMessage end
@ -265,7 +265,7 @@ message VersionMessage{
string userAgent = 6;
Hash selectedTipHash = 7;
bool disableRelayTx = 8;
SubnetworkID subnetworkID = 9;
SubnetworkId subnetworkId = 9;
string network = 10;
}
// VersionMessage end
@ -427,7 +427,7 @@ message AddPeerResponseMessage{
}
message SubmitTransactionRequestMessage{
TransactionMessage transactionMessage = 1;
TransactionMessage transaction = 1;
}
message SubmitTransactionResponseMessage{

View File

@ -31,7 +31,7 @@ func (x *BlockHeaderMessage) toAppMessage() (*appmessage.MsgBlockHeader, error)
return nil, err
}
acceptedIDMerkleRoot, err := x.AcceptedIDMerkleRoot.toDomain()
acceptedIDMerkleRoot, err := x.AcceptedIdMerkleRoot.toDomain()
if err != nil {
return nil, err
}
@ -63,7 +63,7 @@ func (x *BlockHeaderMessage) fromAppMessage(msgBlockHeader *appmessage.MsgBlockH
Version: msgBlockHeader.Version,
ParentHashes: domainHashesToProto(msgBlockHeader.ParentHashes),
HashMerkleRoot: domainHashToProto(msgBlockHeader.HashMerkleRoot),
AcceptedIDMerkleRoot: domainHashToProto(msgBlockHeader.AcceptedIDMerkleRoot),
AcceptedIdMerkleRoot: domainHashToProto(msgBlockHeader.AcceptedIDMerkleRoot),
UtxoCommitment: domainHashToProto(msgBlockHeader.UTXOCommitment),
Timestamp: msgBlockHeader.Timestamp.UnixMilliseconds(),
Bits: msgBlockHeader.Bits,

View File

@ -6,7 +6,7 @@ import (
func (x *KaspadMessage_RequestAddresses) toAppMessage() (appmessage.Message, error) {
protoGetAddresses := x.RequestAddresses
subnetworkID, err := protoGetAddresses.SubnetworkID.toDomain()
subnetworkID, err := protoGetAddresses.SubnetworkId.toDomain()
if err != nil {
return nil, err
}
@ -20,7 +20,7 @@ func (x *KaspadMessage_RequestAddresses) toAppMessage() (appmessage.Message, err
func (x *KaspadMessage_RequestAddresses) fromAppMessage(msgGetAddresses *appmessage.MsgRequestAddresses) error {
x.RequestAddresses = &RequestAddressesMessage{
IncludeAllSubnetworks: msgGetAddresses.IncludeAllSubnetworks,
SubnetworkID: domainSubnetworkIDToProto(msgGetAddresses.SubnetworkID),
SubnetworkId: domainSubnetworkIDToProto(msgGetAddresses.SubnetworkID),
}
return nil
}

View File

@ -19,7 +19,7 @@ func (x *KaspadMessage_Transaction) fromAppMessage(msgTx *appmessage.MsgTx) erro
func (x *TransactionMessage) toAppMessage() (appmessage.Message, error) {
inputs := make([]*appmessage.TxIn, len(x.Inputs))
for i, protoInput := range x.Inputs {
prevTxID, err := protoInput.PreviousOutpoint.TransactionID.toDomain()
prevTxID, err := protoInput.PreviousOutpoint.TransactionId.toDomain()
if err != nil {
return nil, err
}
@ -36,11 +36,11 @@ func (x *TransactionMessage) toAppMessage() (appmessage.Message, error) {
}
}
if x.SubnetworkID == nil {
if x.SubnetworkId == nil {
return nil, errors.New("transaction subnetwork field cannot be nil")
}
subnetworkID, err := x.SubnetworkID.toDomain()
subnetworkID, err := x.SubnetworkId.toDomain()
if err != nil {
return nil, err
}
@ -70,7 +70,7 @@ func (x *TransactionMessage) fromAppMessage(msgTx *appmessage.MsgTx) {
for i, input := range msgTx.TxIn {
protoInputs[i] = &TransactionInput{
PreviousOutpoint: &Outpoint{
TransactionID: domainTransactionIDToProto(&input.PreviousOutpoint.TxID),
TransactionId: domainTransactionIDToProto(&input.PreviousOutpoint.TxID),
Index: input.PreviousOutpoint.Index,
},
SignatureScript: input.SignatureScript,
@ -91,7 +91,7 @@ func (x *TransactionMessage) fromAppMessage(msgTx *appmessage.MsgTx) {
Inputs: protoInputs,
Outputs: protoOutputs,
LockTime: msgTx.LockTime,
SubnetworkID: domainSubnetworkIDToProto(&msgTx.SubnetworkID),
SubnetworkId: domainSubnetworkIDToProto(&msgTx.SubnetworkID),
Gas: msgTx.Gas,
PayloadHash: domainHashToProto(&msgTx.PayloadHash),
Payload: msgTx.Payload,

View File

@ -22,7 +22,7 @@ func (x *KaspadMessage_Version) toAppMessage() (appmessage.Message, error) {
return nil, err
}
subnetworkID, err := x.Version.SubnetworkID.toDomain()
subnetworkID, err := x.Version.SubnetworkId.toDomain()
if err != nil {
return nil, err
}
@ -73,7 +73,7 @@ func (x *KaspadMessage_Version) fromAppMessage(msgVersion *appmessage.MsgVersion
UserAgent: msgVersion.UserAgent,
SelectedTipHash: domainHashToProto(msgVersion.SelectedTipHash),
DisableRelayTx: msgVersion.DisableRelayTx,
SubnetworkID: domainSubnetworkIDToProto(msgVersion.SubnetworkID),
SubnetworkId: domainSubnetworkIDToProto(msgVersion.SubnetworkID),
}
return nil
}

View File

@ -3,7 +3,7 @@ package protowire
import "github.com/kaspanet/kaspad/app/appmessage"
func (x *KaspadMessage_SubmitTransactionRequest) toAppMessage() (appmessage.Message, error) {
msgTx, err := x.SubmitTransactionRequest.TransactionMessage.toAppMessage()
msgTx, err := x.SubmitTransactionRequest.Transaction.toAppMessage()
if err != nil {
return nil, err
}
@ -14,9 +14,9 @@ func (x *KaspadMessage_SubmitTransactionRequest) toAppMessage() (appmessage.Mess
func (x *KaspadMessage_SubmitTransactionRequest) fromAppMessage(message *appmessage.SubmitTransactionRequestMessage) error {
x.SubmitTransactionRequest = &SubmitTransactionRequestMessage{
TransactionMessage: &TransactionMessage{},
Transaction: &TransactionMessage{},
}
x.SubmitTransactionRequest.TransactionMessage.fromAppMessage(message.Transaction)
x.SubmitTransactionRequest.Transaction.fromAppMessage(message.Transaction)
return nil
}