diff --git a/bigchaindb/backend/mongodb/query.py b/bigchaindb/backend/mongodb/query.py index e3b71315..2e4dff51 100644 --- a/bigchaindb/backend/mongodb/query.py +++ b/bigchaindb/backend/mongodb/query.py @@ -1,7 +1,6 @@ """Query implementation for MongoDB""" from time import time -from itertools import chain from pymongo import ReturnDocument from pymongo import errors @@ -86,39 +85,30 @@ def get_blocks_status_from_transaction(conn, transaction_id): @register_query(MongoDBConnection) def get_txids_filtered(conn, asset_id, operation=None): - parts = [] + match_create = { + 'block.transactions.operation': 'CREATE', + 'block.transactions.id': asset_id + } + match_transfer = { + 'block.transactions.operation': 'TRANSFER', + 'block.transactions.asset.id': asset_id + } - if operation in (Transaction.CREATE, None): - # get the txid of the create transaction for asset_id - cursor = conn.db['bigchain'].aggregate([ - {'$match': { - 'block.transactions.id': asset_id, - 'block.transactions.operation': 'CREATE' - }}, - {'$unwind': '$block.transactions'}, - {'$match': { - 'block.transactions.id': asset_id, - 'block.transactions.operation': 'CREATE' - }}, - {'$project': {'block.transactions.id': True}} - ]) - parts.append(elem['block']['transactions']['id'] for elem in cursor) + if operation == Transaction.CREATE: + match = match_create + elif operation == Transaction.TRANSFER: + match = match_transfer + else: + match = {'$or': [match_create, match_transfer]} - if operation in (Transaction.TRANSFER, None): - # get txids of transfer transaction with asset_id - cursor = conn.db['bigchain'].aggregate([ - {'$match': { - 'block.transactions.asset.id': asset_id - }}, - {'$unwind': '$block.transactions'}, - {'$match': { - 'block.transactions.asset.id': asset_id - }}, - {'$project': {'block.transactions.id': True}} - ]) - parts.append(elem['block']['transactions']['id'] for elem in cursor) - - return chain(*parts) + pipeline = [ + {'$match': match}, + {'$unwind': '$block.transactions'}, + {'$match': match}, + {'$project': {'block.transactions.id': True}} + ] + cursor = conn.db['bigchain'].aggregate(pipeline) + return (elem['block']['transactions']['id'] for elem in cursor) @register_query(MongoDBConnection) diff --git a/bigchaindb/backend/mongodb/schema.py b/bigchaindb/backend/mongodb/schema.py index 95c2d02a..4c5189ac 100644 --- a/bigchaindb/backend/mongodb/schema.py +++ b/bigchaindb/backend/mongodb/schema.py @@ -60,8 +60,7 @@ def create_bigchain_secondary_index(conn, dbname): # secondary index for asset uuid, this field is unique conn.conn[dbname]['bigchain']\ - .create_index('block.transactions.transaction.asset.id', - name='asset_id') + .create_index('block.transactions.asset.id', name='asset_id') # secondary index on the public keys of outputs conn.conn[dbname]['bigchain']\ diff --git a/tests/backend/mongodb/test_indexes.py b/tests/backend/mongodb/test_indexes.py new file mode 100644 index 00000000..ba6afae1 --- /dev/null +++ b/tests/backend/mongodb/test_indexes.py @@ -0,0 +1,23 @@ +import pytest +from unittest.mock import MagicMock + +pytestmark = pytest.mark.bdb + + +def test_asset_id_index(): + from bigchaindb.backend.mongodb.query import get_txids_filtered + from bigchaindb.backend import connect + + # Passes a mock in place of a connection to get the query params from the + # query function, then gets the explain plan from MongoDB to test that + # it's using certain indexes. + + m = MagicMock() + get_txids_filtered(m, '') + pipeline = m.db['bigchain'].aggregate.call_args[0][0] + run = connect().db.command + res = run('aggregate', 'bigchain', pipeline=pipeline, explain=True) + stages = (res['stages'][0]['$cursor']['queryPlanner']['winningPlan'] + ['inputStage']['inputStages']) + indexes = [s['inputStage']['indexName'] for s in stages] + assert set(indexes) == {'asset_id', 'transaction_id'}