From 267ce84caaa5f9ea78285d86fb5312af772d7e15 Mon Sep 17 00:00:00 2001 From: nolash Date: Sun, 21 Mar 2021 19:39:38 +0100 Subject: [PATCH] Rehabilitate balance api --- apps/cic-eth/cic_eth/api/api_task.py | 22 ++++++--- apps/cic-eth/cic_eth/queue/balance.py | 19 +++---- apps/cic-eth/tests/conftest.py | 13 +++++ apps/cic-eth/tests/task/api/test_balance.py | 55 +++++++++++++++++++++ 4 files changed, 92 insertions(+), 17 deletions(-) create mode 100644 apps/cic-eth/tests/task/api/test_balance.py diff --git a/apps/cic-eth/cic_eth/api/api_task.py b/apps/cic-eth/cic_eth/api/api_task.py index ad99dae5..2184fdca 100644 --- a/apps/cic-eth/cic_eth/api/api_task.py +++ b/apps/cic-eth/cic_eth/api/api_task.py @@ -307,16 +307,22 @@ class Api: s_balance_incoming.link(s_balance_outgoing) last_in_chain = s_balance_outgoing - one = celery.chain(s_tokens, s_balance) - two = celery.chain(s_tokens, s_balance_incoming) - three = celery.chain(s_tokens, s_balance_outgoing) + one = celery.chain(s_tokens, s_balance) + two = celery.chain(s_tokens, s_balance_incoming) + three = celery.chain(s_tokens, s_balance_outgoing) - t = None - if self.callback_param != None: - s_result.link(self.callback_success).on_error(self.callback_error) - t = celery.chord([one, two, three])(s_result) + t = None + if self.callback_param != None: + s_result.link(self.callback_success).on_error(self.callback_error) + t = celery.chord([one, two, three])(s_result) + else: + t = celery.chord([one, two, three])(s_result) else: - t = celery.chord([one, two, three])(s_result) + # TODO: Chord is inefficient with only one chain, but assemble_balances must be able to handle different structures in order to avoid chord + one = celery.chain(s_tokens, s_balance) + if self.callback_param != None: + s_result.link(self.callback_success).on_error(self.callback_error) + t = celery.chord([one])(s_result) return t diff --git a/apps/cic-eth/cic_eth/queue/balance.py b/apps/cic-eth/cic_eth/queue/balance.py index d9648e95..1a5cf2db 100644 --- a/apps/cic-eth/cic_eth/queue/balance.py +++ b/apps/cic-eth/cic_eth/queue/balance.py @@ -3,10 +3,10 @@ import logging # third-party imports import celery +from chainlib.chain import ChainSpec from hexathon import strip_0x # local imports -from cic_registry.chain import ChainSpec from cic_eth.db import SessionBase from cic_eth.db.models.otx import Otx from cic_eth.db.models.tx import TxCache @@ -21,7 +21,7 @@ celery_app = celery.current_app logg = logging.getLogger() -def __balance_outgoing_compatible(token_address, holder_address, chain_str): +def __balance_outgoing_compatible(token_address, holder_address): session = SessionBase.create_session() q = session.query(TxCache.from_value) q = q.join(Otx) @@ -37,7 +37,7 @@ def __balance_outgoing_compatible(token_address, holder_address, chain_str): @celery_app.task(base=CriticalSQLAlchemyTask) -def balance_outgoing(tokens, holder_address, chain_str): +def balance_outgoing(tokens, holder_address, chain_spec_dict): """Retrieve accumulated value of unprocessed transactions sent from the given address. :param tokens: list of token spec dicts with addresses to retrieve balances for @@ -49,15 +49,15 @@ def balance_outgoing(tokens, holder_address, chain_str): :returns: Tokens dicts with outgoing balance added :rtype: dict """ - chain_spec = ChainSpec.from_chain_str(chain_str) + chain_spec = ChainSpec.from_dict(chain_spec_dict) for t in tokens: - b = __balance_outgoing_compatible(t['address'], holder_address, chain_str) + b = __balance_outgoing_compatible(t['address'], holder_address) t['balance_outgoing'] = b return tokens -def __balance_incoming_compatible(token_address, receiver_address, chain_str): +def __balance_incoming_compatible(token_address, receiver_address): session = SessionBase.create_session() q = session.query(TxCache.to_value) q = q.join(Otx) @@ -75,7 +75,7 @@ def __balance_incoming_compatible(token_address, receiver_address, chain_str): @celery_app.task(base=CriticalSQLAlchemyTask) -def balance_incoming(tokens, receipient_address, chain_str): +def balance_incoming(tokens, receipient_address, chain_spec_dict): """Retrieve accumulated value of unprocessed transactions to be received by the given address. :param tokens: list of token spec dicts with addresses to retrieve balances for @@ -87,9 +87,9 @@ def balance_incoming(tokens, receipient_address, chain_str): :returns: Tokens dicts with outgoing balance added :rtype: dict """ - chain_spec = ChainSpec.from_chain_str(chain_str) + chain_spec = ChainSpec.from_dict(chain_spec_dict) for t in tokens: - b = __balance_incoming_compatible(t['address'], receipient_address, chain_str) + b = __balance_incoming_compatible(t['address'], receipient_address) t['balance_incoming'] = b return tokens @@ -107,6 +107,7 @@ def assemble_balances(balances_collection): :rtype: list of dicts """ tokens = {} + logg.debug('received collection {}'.format(balances_collection)) for c in balances_collection: for b in c: address = b['address'] diff --git a/apps/cic-eth/tests/conftest.py b/apps/cic-eth/tests/conftest.py index 7e764ee9..421f9650 100644 --- a/apps/cic-eth/tests/conftest.py +++ b/apps/cic-eth/tests/conftest.py @@ -3,10 +3,14 @@ import os import sys import logging +# local imports +from cic_eth.api import Api + script_dir = os.path.dirname(os.path.realpath(__file__)) root_dir = os.path.dirname(script_dir) sys.path.insert(0, root_dir) +# assemble fixtures from tests.fixtures_config import * from tests.fixtures_database import * from tests.fixtures_celery import * @@ -15,3 +19,12 @@ from chainlib.eth.pytest import * from contract_registry.pytest import * from cic_eth_registry.pytest.fixtures_contracts import * from cic_eth_registry.pytest.fixtures_tokens import * + + +@pytest.fixture(scope='function') +def api( + default_chain_spec, + custodial_roles, + ): + chain_str = str(default_chain_spec) + return Api(chain_str, queue=None, callback_param='foo') diff --git a/apps/cic-eth/tests/task/api/test_balance.py b/apps/cic-eth/tests/task/api/test_balance.py new file mode 100644 index 00000000..650c2b10 --- /dev/null +++ b/apps/cic-eth/tests/task/api/test_balance.py @@ -0,0 +1,55 @@ +# standard imports +import os +import logging + +# external imports +from chainlib.eth.address import to_checksum_address + +# local imports +from cic_eth.api.api_task import Api + +logg = logging.getLogger() + +def test_balance_simple_api( + default_chain_spec, + init_database, + cic_registry, + foo_token, + register_tokens, + api, + celery_session_worker, + ): + + chain_str = str(default_chain_spec) + + a = to_checksum_address('0x' + os.urandom(20).hex()) + t = api.balance(a, 'FOO', include_pending=False) + r = t.get_leaf() + assert t.successful() + logg.debug(r) + + assert r[0].get('balance_network') != None + + +def test_balance_complex_api( + default_chain_spec, + init_database, + cic_registry, + foo_token, + register_tokens, + api, + celery_session_worker, + ): + + chain_str = str(default_chain_spec) + + a = to_checksum_address('0x' + os.urandom(20).hex()) + t = api.balance(a, 'FOO', include_pending=True) + r = t.get_leaf() + assert t.successful() + logg.debug(r) + + assert r[0].get('balance_incoming') != None + assert r[0].get('balance_outgoing') != None + assert r[0].get('balance_network') != None +