diff --git a/apiserver/controllers/transaction.go b/apiserver/controllers/transaction.go index 0ff775f82..876895fc0 100644 --- a/apiserver/controllers/transaction.go +++ b/apiserver/controllers/transaction.go @@ -279,8 +279,8 @@ func PostTransaction(requestBody []byte) error { return nil } -// GetTransactionsByIdsHandler finds transactions by the given transactionIds. -func GetTransactionsByIdsHandler(db *gorm.DB, transactionIds []string) ([]*apimodels.TransactionResponse, error) { +// GetTransactionsByIDsHandler finds transactions by the given transactionIds. +func GetTransactionsByIDsHandler(db *gorm.DB, transactionIds []string) ([]*apimodels.TransactionResponse, error) { var txs []*dbmodels.Transaction query := joinTxInputsTxOutputsAndAddresses(db). Where("`transactions`.`transaction_id` IN (?)", transactionIds) diff --git a/apiserver/mqtt/transactions.go b/apiserver/mqtt/transactions.go index 3daa15b25..9b2242ef8 100644 --- a/apiserver/mqtt/transactions.go +++ b/apiserver/mqtt/transactions.go @@ -2,10 +2,11 @@ package mqtt import ( "encoding/json" - "fmt" "github.com/daglabs/btcd/apiserver/apimodels" "github.com/daglabs/btcd/apiserver/controllers" + "github.com/daglabs/btcd/apiserver/database" "github.com/daglabs/btcd/btcjson" + "github.com/daglabs/btcd/rpcclient" "github.com/jinzhu/gorm" ) @@ -15,18 +16,18 @@ func PublishTransactionsNotifications(db *gorm.DB, rawTransactions []btcjson.TxR return nil } - transactionIds := make([]string, len(rawTransactions)) + transactionIDs := make([]string, len(rawTransactions)) for i, tx := range rawTransactions { - transactionIds[i] = tx.TxID + transactionIDs[i] = tx.TxID } - transactions, err := controllers.GetTransactionsByIdsHandler(db, transactionIds) + transactions, err := controllers.GetTransactionsByIDsHandler(db, transactionIDs) if err != nil { return err } for _, transaction := range transactions { - err = publishTransactionNotifications(transaction) + err = publishTransactionNotifications(transaction, "transactions/") if err != nil { return err } @@ -34,10 +35,10 @@ func PublishTransactionsNotifications(db *gorm.DB, rawTransactions []btcjson.TxR return nil } -func publishTransactionNotifications(transaction *apimodels.TransactionResponse) error { +func publishTransactionNotifications(transaction *apimodels.TransactionResponse, topic string) error { addresses := uniqueAddressesForTransaction(transaction) for _, address := range addresses { - err := publishTransactionNotificationForAddress(transaction, address) + err := publishTransactionNotificationForAddress(transaction, address, topic) if err != nil { return err } @@ -57,13 +58,13 @@ func uniqueAddressesForTransaction(transaction *apimodels.TransactionResponse) [ return addresses } -func publishTransactionNotificationForAddress(transaction *apimodels.TransactionResponse, address string) error { +func publishTransactionNotificationForAddress(transaction *apimodels.TransactionResponse, address string, topic string) error { payload, err := json.Marshal(transaction) if err != nil { return err } - token := client.Publish(transactionsTopic(address), 2, false, payload) + token := client.Publish(topic+address, 2, false, payload) token.Wait() if token.Error() != nil { return token.Error() @@ -72,6 +73,33 @@ func publishTransactionNotificationForAddress(transaction *apimodels.Transaction return nil } -func transactionsTopic(address string) string { - return fmt.Sprintf("transactions/%s", address) +// PublishAcceptedTransactionsNotifications publishes notification for each accepted transaction of the given chain-block +func PublishAcceptedTransactionsNotifications(addedChainBlocks []*rpcclient.ChainBlock) error { + db, err := database.DB() + if err != nil { + return err + } + + for _, addedChainBlock := range addedChainBlocks { + for _, acceptedBlock := range addedChainBlock.AcceptedBlocks { + transactionIDs := make([]string, len(acceptedBlock.AcceptedTxIDs)) + for i, acceptedTxID := range acceptedBlock.AcceptedTxIDs { + transactionIDs[i] = acceptedTxID.String() + } + + transactions, err := controllers.GetTransactionsByIDsHandler(db, transactionIDs) + if err != nil { + return err + } + + for _, transaction := range transactions { + err = publishTransactionNotifications(transaction, "transactions/accepted/") + if err != nil { + return err + } + } + return nil + } + } + return nil } diff --git a/apiserver/sync.go b/apiserver/sync.go index 28a8b4acc..5ec619578 100644 --- a/apiserver/sync.go +++ b/apiserver/sync.go @@ -1067,6 +1067,11 @@ func processChainChangedMsgs() { } log.Infof("Chain changed: removed %d blocks and added %d block", len(removedHashes), len(addedBlocks)) + + err = mqtt.PublishAcceptedTransactionsNotifications(chainChanged.AddedChainBlocks) + if err != nil { + panic(errors.Errorf("Error while publishing accepted transactions notifications %s", err)) + } } pendingChainChangedMsgs = unprocessedChainChangedMessages }