Add toplevel callback filter test

This commit is contained in:
nolash 2021-04-14 13:42:14 +02:00
parent be65da9924
commit 7a958c6f89
Signed by untrusted user who does not match committer: lash
GPG Key ID: 21D2E7BB88C2A746
3 changed files with 17 additions and 13 deletions

View File

@ -1,6 +1,10 @@
# extended imports # external imports
from chainlib.eth.constant import ZERO_ADDRESS from chainlib.eth.constant import ZERO_ADDRESS
from chainlib.status import Status as TxStatus from chainlib.status import Status as TxStatus
from cic_eth_registry.erc20 import ERC20Token
# local imports
from cic_eth.ext.address import translate_address
class ExtendedTx: class ExtendedTx:
@ -27,12 +31,12 @@ class ExtendedTx:
self.status_code = TxStatus.PENDING.value self.status_code = TxStatus.PENDING.value
def set_actors(self, sender, recipient, trusted_declarator_addresses=None): def set_actors(self, sender, recipient, trusted_declarator_addresses=None, caller_address=ZERO_ADDRESS):
self.sender = sender self.sender = sender
self.recipient = recipient self.recipient = recipient
if trusted_declarator_addresses != None: if trusted_declarator_addresses != None:
self.sender_label = translate_address(sender, trusted_declarator_addresses, self.chain_spec) self.sender_label = translate_address(sender, trusted_declarator_addresses, self.chain_spec, sender_address=caller_address)
self.recipient_label = translate_address(recipient, trusted_declarator_addresses, self.chain_spec) self.recipient_label = translate_address(recipient, trusted_declarator_addresses, self.chain_spec, sender_address=caller_address)
def set_tokens(self, source, source_value, destination=None, destination_value=None): def set_tokens(self, source, source_value, destination=None, destination_value=None):
@ -40,8 +44,8 @@ class ExtendedTx:
destination = source destination = source
if destination_value == None: if destination_value == None:
destination_value = source_value destination_value = source_value
st = ERC20Token(self.rpc, source) st = ERC20Token(self.chain_spec, self.rpc, source)
dt = ERC20Token(self.rpc, destination) dt = ERC20Token(self.chain_spec, self.rpc, destination)
self.source_token = source self.source_token = source
self.source_token_symbol = st.symbol self.source_token_symbol = st.symbol
self.source_token_name = st.name self.source_token_name = st.name

View File

@ -1,7 +1,7 @@
# standard imports # standard imports
import logging import logging
# third-party imports # external imports
import celery import celery
from cic_eth_registry.error import UnknownContractError from cic_eth_registry.error import UnknownContractError
from chainlib.status import Status as TxStatus from chainlib.status import Status as TxStatus
@ -59,10 +59,11 @@ class CallbackFilter(SyncFilter):
trusted_addresses = [] trusted_addresses = []
def __init__(self, chain_spec, method, queue): def __init__(self, chain_spec, method, queue, caller_address=ZERO_ADDRESS):
self.queue = queue self.queue = queue
self.method = method self.method = method
self.chain_spec = chain_spec self.chain_spec = chain_spec
self.caller_address = caller_address
def call_back(self, transfer_type, result): def call_back(self, transfer_type, result):
@ -143,7 +144,7 @@ class CallbackFilter(SyncFilter):
result = None result = None
try: try:
tokentx = ExtendedTx(conn, tx.hash, self.chain_spec) tokentx = ExtendedTx(conn, tx.hash, self.chain_spec)
tokentx.set_actors(transfer_data['from'], transfer_data['to'], self.trusted_addresses) tokentx.set_actors(transfer_data['from'], transfer_data['to'], self.trusted_addresses, caller_address=self.caller_address)
tokentx.set_tokens(transfer_data['token_address'], transfer_data['value']) tokentx.set_tokens(transfer_data['token_address'], transfer_data['value'])
if transfer_data['status'] == 0: if transfer_data['status'] == 0:
tokentx.set_status(1) tokentx.set_status(1)

View File

@ -27,7 +27,6 @@ from cic_eth.runnable.daemons.filters.callback import (
logg = logging.getLogger() logg = logging.getLogger()
@pytest.mark.skip()
def test_transfer_tx( def test_transfer_tx(
default_chain_spec, default_chain_spec,
init_database, init_database,
@ -65,7 +64,6 @@ def test_transfer_tx(
assert transfer_type == 'transfer' assert transfer_type == 'transfer'
@pytest.mark.skip()
def test_transfer_from_tx( def test_transfer_from_tx(
default_chain_spec, default_chain_spec,
init_database, init_database,
@ -164,8 +162,10 @@ def test_callback_filter(
eth_rpc, eth_rpc,
eth_signer, eth_signer,
foo_token, foo_token,
token_roles,
agent_roles, agent_roles,
contract_roles, contract_roles,
register_lookups,
): ):
rpc = RPCConnection.connect(default_chain_spec, 'default') rpc = RPCConnection.connect(default_chain_spec, 'default')
@ -189,14 +189,13 @@ def test_callback_filter(
rcpt = snake_and_camel(r) rcpt = snake_and_camel(r)
tx.apply_receipt(rcpt) tx.apply_receipt(rcpt)
fltr = CallbackFilter(default_chain_spec, None, None) fltr = CallbackFilter(default_chain_spec, None, None, caller_address=contract_roles['CONTRACT_DEPLOYER'])
class CallbackMock: class CallbackMock:
def __init__(self): def __init__(self):
self.results = {} self.results = {}
def call_back(self, transfer_type, result): def call_back(self, transfer_type, result):
self.results[transfer_type] = result self.results[transfer_type] = result