Rehabilitate balance api

This commit is contained in:
nolash 2021-03-21 19:39:38 +01:00
parent 333d410b1c
commit 267ce84caa
Signed by untrusted user who does not match committer: lash
GPG Key ID: 21D2E7BB88C2A746
4 changed files with 92 additions and 17 deletions

View File

@ -307,16 +307,22 @@ class Api:
s_balance_incoming.link(s_balance_outgoing) s_balance_incoming.link(s_balance_outgoing)
last_in_chain = s_balance_outgoing last_in_chain = s_balance_outgoing
one = celery.chain(s_tokens, s_balance) one = celery.chain(s_tokens, s_balance)
two = celery.chain(s_tokens, s_balance_incoming) two = celery.chain(s_tokens, s_balance_incoming)
three = celery.chain(s_tokens, s_balance_outgoing) three = celery.chain(s_tokens, s_balance_outgoing)
t = None t = None
if self.callback_param != None: if self.callback_param != None:
s_result.link(self.callback_success).on_error(self.callback_error) s_result.link(self.callback_success).on_error(self.callback_error)
t = celery.chord([one, two, three])(s_result) t = celery.chord([one, two, three])(s_result)
else:
t = celery.chord([one, two, three])(s_result)
else: else:
t = celery.chord([one, two, three])(s_result) # TODO: Chord is inefficient with only one chain, but assemble_balances must be able to handle different structures in order to avoid chord
one = celery.chain(s_tokens, s_balance)
if self.callback_param != None:
s_result.link(self.callback_success).on_error(self.callback_error)
t = celery.chord([one])(s_result)
return t return t

View File

@ -3,10 +3,10 @@ import logging
# third-party imports # third-party imports
import celery import celery
from chainlib.chain import ChainSpec
from hexathon import strip_0x from hexathon import strip_0x
# local imports # local imports
from cic_registry.chain import ChainSpec
from cic_eth.db import SessionBase from cic_eth.db import SessionBase
from cic_eth.db.models.otx import Otx from cic_eth.db.models.otx import Otx
from cic_eth.db.models.tx import TxCache from cic_eth.db.models.tx import TxCache
@ -21,7 +21,7 @@ celery_app = celery.current_app
logg = logging.getLogger() logg = logging.getLogger()
def __balance_outgoing_compatible(token_address, holder_address, chain_str): def __balance_outgoing_compatible(token_address, holder_address):
session = SessionBase.create_session() session = SessionBase.create_session()
q = session.query(TxCache.from_value) q = session.query(TxCache.from_value)
q = q.join(Otx) q = q.join(Otx)
@ -37,7 +37,7 @@ def __balance_outgoing_compatible(token_address, holder_address, chain_str):
@celery_app.task(base=CriticalSQLAlchemyTask) @celery_app.task(base=CriticalSQLAlchemyTask)
def balance_outgoing(tokens, holder_address, chain_str): def balance_outgoing(tokens, holder_address, chain_spec_dict):
"""Retrieve accumulated value of unprocessed transactions sent from the given address. """Retrieve accumulated value of unprocessed transactions sent from the given address.
:param tokens: list of token spec dicts with addresses to retrieve balances for :param tokens: list of token spec dicts with addresses to retrieve balances for
@ -49,15 +49,15 @@ def balance_outgoing(tokens, holder_address, chain_str):
:returns: Tokens dicts with outgoing balance added :returns: Tokens dicts with outgoing balance added
:rtype: dict :rtype: dict
""" """
chain_spec = ChainSpec.from_chain_str(chain_str) chain_spec = ChainSpec.from_dict(chain_spec_dict)
for t in tokens: for t in tokens:
b = __balance_outgoing_compatible(t['address'], holder_address, chain_str) b = __balance_outgoing_compatible(t['address'], holder_address)
t['balance_outgoing'] = b t['balance_outgoing'] = b
return tokens return tokens
def __balance_incoming_compatible(token_address, receiver_address, chain_str): def __balance_incoming_compatible(token_address, receiver_address):
session = SessionBase.create_session() session = SessionBase.create_session()
q = session.query(TxCache.to_value) q = session.query(TxCache.to_value)
q = q.join(Otx) q = q.join(Otx)
@ -75,7 +75,7 @@ def __balance_incoming_compatible(token_address, receiver_address, chain_str):
@celery_app.task(base=CriticalSQLAlchemyTask) @celery_app.task(base=CriticalSQLAlchemyTask)
def balance_incoming(tokens, receipient_address, chain_str): def balance_incoming(tokens, receipient_address, chain_spec_dict):
"""Retrieve accumulated value of unprocessed transactions to be received by the given address. """Retrieve accumulated value of unprocessed transactions to be received by the given address.
:param tokens: list of token spec dicts with addresses to retrieve balances for :param tokens: list of token spec dicts with addresses to retrieve balances for
@ -87,9 +87,9 @@ def balance_incoming(tokens, receipient_address, chain_str):
:returns: Tokens dicts with outgoing balance added :returns: Tokens dicts with outgoing balance added
:rtype: dict :rtype: dict
""" """
chain_spec = ChainSpec.from_chain_str(chain_str) chain_spec = ChainSpec.from_dict(chain_spec_dict)
for t in tokens: for t in tokens:
b = __balance_incoming_compatible(t['address'], receipient_address, chain_str) b = __balance_incoming_compatible(t['address'], receipient_address)
t['balance_incoming'] = b t['balance_incoming'] = b
return tokens return tokens
@ -107,6 +107,7 @@ def assemble_balances(balances_collection):
:rtype: list of dicts :rtype: list of dicts
""" """
tokens = {} tokens = {}
logg.debug('received collection {}'.format(balances_collection))
for c in balances_collection: for c in balances_collection:
for b in c: for b in c:
address = b['address'] address = b['address']

View File

@ -3,10 +3,14 @@ import os
import sys import sys
import logging import logging
# local imports
from cic_eth.api import Api
script_dir = os.path.dirname(os.path.realpath(__file__)) script_dir = os.path.dirname(os.path.realpath(__file__))
root_dir = os.path.dirname(script_dir) root_dir = os.path.dirname(script_dir)
sys.path.insert(0, root_dir) sys.path.insert(0, root_dir)
# assemble fixtures
from tests.fixtures_config import * from tests.fixtures_config import *
from tests.fixtures_database import * from tests.fixtures_database import *
from tests.fixtures_celery import * from tests.fixtures_celery import *
@ -15,3 +19,12 @@ from chainlib.eth.pytest import *
from contract_registry.pytest import * from contract_registry.pytest import *
from cic_eth_registry.pytest.fixtures_contracts import * from cic_eth_registry.pytest.fixtures_contracts import *
from cic_eth_registry.pytest.fixtures_tokens import * from cic_eth_registry.pytest.fixtures_tokens import *
@pytest.fixture(scope='function')
def api(
default_chain_spec,
custodial_roles,
):
chain_str = str(default_chain_spec)
return Api(chain_str, queue=None, callback_param='foo')

View File

@ -0,0 +1,55 @@
# standard imports
import os
import logging
# external imports
from chainlib.eth.address import to_checksum_address
# local imports
from cic_eth.api.api_task import Api
logg = logging.getLogger()
def test_balance_simple_api(
default_chain_spec,
init_database,
cic_registry,
foo_token,
register_tokens,
api,
celery_session_worker,
):
chain_str = str(default_chain_spec)
a = to_checksum_address('0x' + os.urandom(20).hex())
t = api.balance(a, 'FOO', include_pending=False)
r = t.get_leaf()
assert t.successful()
logg.debug(r)
assert r[0].get('balance_network') != None
def test_balance_complex_api(
default_chain_spec,
init_database,
cic_registry,
foo_token,
register_tokens,
api,
celery_session_worker,
):
chain_str = str(default_chain_spec)
a = to_checksum_address('0x' + os.urandom(20).hex())
t = api.balance(a, 'FOO', include_pending=True)
r = t.get_leaf()
assert t.successful()
logg.debug(r)
assert r[0].get('balance_incoming') != None
assert r[0].get('balance_outgoing') != None
assert r[0].get('balance_network') != None