diff --git a/planetmint/backend/tarantool/query.py b/planetmint/backend/tarantool/query.py index a9f6fac..2d9a5e1 100644 --- a/planetmint/backend/tarantool/query.py +++ b/planetmint/backend/tarantool/query.py @@ -55,8 +55,8 @@ def get_transaction(connection, tx_id: str) -> DbTransaction: return NotImplemented @register_query(TarantoolDBConnection) -def get_transactions_by_asset(connection, asset: str) -> list[DbTransaction]: - txs = connection.run(connection.space(TARANT_TABLE_TRANSACTION).select(asset, index="transactions_by_asset_cid")) +def get_transactions_by_asset(connection, asset: str, limit: int = 1000) -> list[DbTransaction]: + txs = connection.run(connection.space(TARANT_TABLE_TRANSACTION).select(asset, limit=limit, index="transactions_by_asset_cid")) tx_ids = [tx[0] for tx in txs] return get_complete_transactions_by_ids(connection, tx_ids) diff --git a/planetmint/lib.py b/planetmint/lib.py index e9ea88b..195bc84 100644 --- a/planetmint/lib.py +++ b/planetmint/lib.py @@ -452,8 +452,8 @@ class Planetmint(object): """ return backend.query.get_assets(self.connection, asset_ids) - def get_assets_by_cid(self, asset_cid) -> list[dict]: - asset_txs = backend.query.get_transactions_by_asset(self.connection, asset_cid) + def get_assets_by_cid(self, asset_cid, **kwargs) -> list[dict]: + asset_txs = backend.query.get_transactions_by_asset(self.connection, asset_cid, **kwargs) # flatten and return all found assets return list(chain.from_iterable([Asset.list_to_dict(tx.assets) for tx in asset_txs])) diff --git a/planetmint/web/views/assets.py b/planetmint/web/views/assets.py index a106357..a6f263b 100644 --- a/planetmint/web/views/assets.py +++ b/planetmint/web/views/assets.py @@ -9,7 +9,7 @@ For more information please refer to the documentation: http://planetmint.io/htt """ import logging -from flask_restful import Resource +from flask_restful import Resource, reqparse from flask import current_app from planetmint.backend.exceptions import OperationError from planetmint.web.views.base import make_error @@ -19,10 +19,17 @@ logger = logging.getLogger(__name__) class AssetListApi(Resource): def get(self, cid: str): + parser = reqparse.RequestParser() + parser.add_argument("limit", type=int) + args = parser.parse_args() + + if not args["limit"]: + del args["limit"] + pool = current_app.config["bigchain_pool"] with pool() as planet: - assets = planet.get_assets_by_cid(cid) + assets = planet.get_assets_by_cid(cid, **args) try: # This only works with MongoDB as the backend diff --git a/tests/web/test_assets.py b/tests/web/test_assets.py index d027daa..8cdff72 100644 --- a/tests/web/test_assets.py +++ b/tests/web/test_assets.py @@ -23,3 +23,18 @@ def test_get_assets_tendermint(client, b, alice): assert res.status_code == 200 assert len(res.json) == 1 assert res.json[0] == {"data": assets[0]["data"]} + + +@pytest.mark.bdb +def test_get_assets_tendermint_limit(client, b, alice, bob): + # create assets + assets = [{"data": multihash(marshal({"msg": "abc"}))}] + tx_1 = Create.generate([alice.public_key], [([alice.public_key], 1)], assets=assets).sign([alice.private_key]) + tx_2 = Create.generate([bob.public_key], [([bob.public_key], 1)], assets=assets).sign([bob.private_key]) + + b.store_bulk_transactions([tx_1, tx_2]) + + res = client.get(ASSETS_ENDPOINT + assets[0]["data"] + "?limit=1") + assert res.status_code == 200 + assert len(res.json) == 1 + assert res.json[0] == {"data": assets[0]["data"]}