Factor out repeated code for most cli apps, along with classes specific to traffic script

This commit is contained in:
nolash 2021-02-21 10:52:48 +01:00
parent 93ae16b578
commit 2b535e2f31
Signed by untrusted user who does not match committer: lash
GPG Key ID: 21D2E7BB88C2A746
11 changed files with 619 additions and 417 deletions

View File

@ -0,0 +1,266 @@
# standard imports
import logging
import json
import uuid
import importlib
import random
import copy
# external imports
import redis
from cic_eth.api.api_task import Api
logg = logging.getLogger(__name__)
class TrafficItem:
def __init__(self, item):
self.method = item.do
self.uuid = uuid.uuid4()
self.ext = None
self.result = None
self.sender = None
self.recipient = None
self.source_token = None
self.destination_token = None
self.source_value = 0
def __str__(self):
return 'traffic item method {} uuid {}'.format(self.method, self.uuid)
class TrafficRouter:
def __init__(self, batch_size=1):
self.items = []
self.weights = []
self.total_weights = 0
self.batch_size = batch_size
self.reserved = {}
self.reserved_count = 0
self.traffic = {}
def add(self, item, weight):
self.weights.append(self.total_weights)
self.total_weights += weight
m = importlib.import_module(item)
self.items.append(m)
def reserve(self):
if len(self.reserved) == self.batch_size:
return None
n = random.randint(0, self.total_weights)
item = self.items[0]
for i in range(len(self.weights)):
if n <= self.weights[i]:
item = self.items[i]
break
ti = TrafficItem(item)
self.reserved[ti.uuid] = ti
return ti
def release(self, traffic_item):
del self.reserved[traffic_item.uuid]
def apply_import_dict(self, keys, dct):
# parse traffic items
for k in keys:
if len(k) > 8 and k[:8] == 'TRAFFIC_':
v = int(dct.get(k))
try:
self.add(k[8:].lower(), v)
except ModuleNotFoundError as e:
raise AttributeError('requested traffic item module not found: {}'.format(e))
logg.debug('found traffic item {} weight {}'.format(k, v))
class TrafficProvisioner:
oracles = {
'account': None,
'token': None,
}
default_aux = {
}
def __init__(self):
self.tokens = self.oracles['token'].get_tokens()
self.accounts = self.oracles['account'].get_accounts()
self.aux = copy.copy(self.default_aux)
self.__balances = {}
def load_balances(self):
pass
def __cache_balance(self, holder_address, token, value):
if self.__balances.get(holder_address) == None:
self.__balances[holder_address] = {}
self.__balances[holder_address][token] = value
logg.debug('setting cached balance of {} token {} to {}'.format(holder_address, token, value))
def add_aux(self, k, v):
logg.debug('added {} = {} to traffictasker'.format(k, v))
self.aux[k] = v
def balances(self, accounts=None, refresh=False):
if refresh:
if accounts == None:
accounts = self.accounts
for account in accounts:
for token in self.tokens:
value = self.balance(account, token)
self.__cache_balance(account, token.symbol(), value)
logg.debug('balance sender {} token {} = {}'.format(account, token, value))
else:
logg.debug('returning cached balances')
return self.__balances
def balance(self, account, token):
# TODO: use proper redis callback
api = Api(
str(self.aux['chain_spec']),
queue=self.aux['api_queue'],
#callback_param='{}:{}:{}:{}'.format(aux['redis_host_callback'], aux['redis_port_callback'], aux['redis_db'], aux['redis_channel']),
#callback_task='cic_eth.callbacks.redis.redis',
#callback_queue=queue,
)
t = api.balance(account, token.symbol())
r = t.get()
for c in t.collect():
r = c[1]
assert t.successful()
return r[0]
def update_balance(self, account, token, value):
self.__cache_balance(account, token.symbol(), value)
class TrafficSyncHandler:
def __init__(self, config, traffic_router):
self.traffic_router = traffic_router
self.redis_channel = str(uuid.uuid4())
self.pubsub = self.__connect_redis(self.redis_channel, config)
self.traffic_items = {}
self.config = config
self.init = False
def __connect_redis(self, redis_channel, config):
r = redis.Redis(config.get('REDIS_HOST'), config.get('REDIS_PORT'), config.get('REDIS_DB'))
redis_pubsub = r.pubsub()
redis_pubsub.subscribe(redis_channel)
logg.debug('redis connected on channel {}'.format(redis_channel))
return redis_pubsub
def refresh(self, block_number, tx_index):
traffic_provisioner = TrafficProvisioner()
traffic_provisioner.add_aux('redis_channel', self.redis_channel)
refresh_balance = not self.init
balances = traffic_provisioner.balances(refresh=refresh_balance)
self.init = True
if len(traffic_provisioner.tokens) == 0:
logg.error('patiently waiting for at least one registered token...')
return
logg.debug('executing handler refresh with accouts {}'.format(traffic_provisioner.accounts))
logg.debug('executing handler refresh with tokens {}'.format(traffic_provisioner.tokens))
sender_indices = [*range(0, len(traffic_provisioner.accounts))]
# TODO: only get balances for the selection that we will be generating for
while True:
traffic_item = self.traffic_router.reserve()
if traffic_item == None:
logg.debug('no traffic_items left to reserve {}'.format(traffic_item))
break
# TODO: temporary selection
token_pair = [
traffic_provisioner.tokens[0],
traffic_provisioner.tokens[0],
]
sender_index_index = random.randint(0, len(sender_indices)-1)
sender_index = sender_indices[sender_index_index]
sender = traffic_provisioner.accounts[sender_index]
#balance_full = balances[sender][token_pair[0].symbol()]
if len(sender_indices) == 1:
sender_indices[m] = sender_sender_indices[len(senders)-1]
sender_indices = sender_indices[:len(sender_indices)-1]
balance_full = traffic_provisioner.balance(sender, token_pair[0])
balance = balance_full['balance_network'] - balance_full['balance_outgoing']
recipient_index = random.randint(0, len(traffic_provisioner.accounts)-1)
recipient = traffic_provisioner.accounts[recipient_index]
(e, t, balance_result,) = traffic_item.method(
token_pair,
sender,
recipient,
balance,
traffic_provisioner.aux,
block_number,
tx_index,
)
traffic_provisioner.update_balance(sender, token_pair[0], balance_result)
sender_indices.append(recipient_index)
if e != None:
logg.info('failed {}: {}'.format(str(traffic_item), e))
self.traffic_router.release(traffic_item)
continue
if t == None:
logg.info('traffic method {} completed immediately')
self.traffic_router.release(traffic_item)
traffic_item.ext = t
self.traffic_items[traffic_item.ext] = traffic_item
while True:
m = self.pubsub.get_message(timeout=0.1)
if m == None:
break
logg.debug('redis message {}'.format(m))
if m['type'] == 'message':
message_data = json.loads(m['data'])
uu = message_data['root_id']
match_item = self.traffic_items[uu]
self.traffic_router.release(match_item)
if message_data['status'] == 0:
logg.error('task item {} failed with error code {}'.format(match_item, message_data['status']))
else:
match_item['result'] = message_data['result']
logg.debug('got callback result: {}'.format(match_item))
def name(self):
return 'traffic_item_handler'
def filter(self, conn, block, tx, session):
logg.debug('handler get {}'.format(tx))

View File

@ -0,0 +1,8 @@
from . import (
log,
argparse,
config,
signer,
rpc,
registry,
)

View File

@ -0,0 +1,73 @@
# standard imports
import logging
import argparse
import os
import sys
default_config_dir = os.environ.get('CONFINI_DIR')
full_template = {
# (long arg and key name, short var, type, default, help,)
'provider': ('p', str, None, 'RPC provider url',),
'registry_address': ('r', str, None, 'CIC registry address',),
'keystore_file': ('y', str, None, 'Keystore file',),
'config_dir': ('c', str, default_config_dir, 'Configuration directory',),
'queue': ('q', str, 'cic-eth', 'Celery task queue',),
'chain_spec': ('i', str, None, 'Chain spec string',),
'abi_dir': (None, str, None, 'Smart contract ABI search path',),
'env_prefix': (None, str, os.environ.get('CONFINI_ENV_PREFIX'), 'Environment prefix for variables to overwrite configuration',),
}
default_include_args = [
'config_dir',
'provider',
'env_prefix',
]
sub = None
def create(caller_dir, include_args=default_include_args):
argparser = argparse.ArgumentParser()
for k in include_args:
a = full_template[k]
long_flag = '--' + k.replace('_', '-')
short_flag = None
dest = None
if a[0] != None:
short_flag = '-' + a[0]
dest = a[0]
else:
dest = k
default = a[2]
if default == None and k == 'config_dir':
default = os.path.join(caller_dir, 'config')
if short_flag == None:
argparser.add_argument(long_flag, dest=dest, type=a[1], default=default, help=a[3])
else:
argparser.add_argument(short_flag, long_flag, dest=dest, type=a[1], default=default, help=a[3])
argparser.add_argument('-v', action='store_true', help='Be verbose')
argparser.add_argument('-vv', action='store_true', help='Be more verbose')
return argparser
def add(argparser, processor, name, description=None):
processor(argparser)
return argparser
def parse(argparser, logger=None):
args = argparser.parse_args(sys.argv[1:])
# handle logging input
if logger != None:
if args.vv:
logger.setLevel(logging.DEBUG)
elif args.v:
logger.setLevel(logging.INFO)
return args

View File

@ -0,0 +1,39 @@
# external imports
import logging
import confini
logg = logging.getLogger(__name__)
default_arg_overrides = {
'abi_dir': 'ETH_ABI_DIR',
'p': 'ETH_PROVIDER',
'i': 'CIC_CHAIN_SPEC',
'r': 'CIC_REGISTRY_ADDRESS',
}
def override(config, override_dict, label):
config.dict_override(override_dict, label)
config.validate()
return config
def create(config_dir, args, env_prefix=None, arg_overrides=default_arg_overrides):
# handle config input
config = confini.Config(config_dir, env_prefix)
config.process()
if arg_overrides != None and args != None:
override_dict = {}
for k in arg_overrides:
v = getattr(args, k)
if v != None:
override_dict[arg_overrides[k]] = v
config = override(config, override_dict, 'args')
else:
config.validate()
return config
def log(config):
logg.debug('config loaded:\n{}'.format(config))

View File

@ -0,0 +1,18 @@
# standard imports
import logging
logging.basicConfig(level=logging.WARNING)
default_mutelist = [
'urllib3',
'websockets.protocol',
'web3.RequestManager',
'web3.providers.WebsocketProvider',
'web3.providers.HTTPProvider',
]
def create(name=None, mutelist=default_mutelist):
logg = logging.getLogger(name)
for m in mutelist:
logging.getLogger(m).setLevel(logging.CRITICAL)
return logg

View File

@ -0,0 +1,86 @@
# standard imports
import logging
import copy
# external imports
from cic_registry import CICRegistry
from eth_token_index import TokenUniqueSymbolIndex
from eth_accounts_index import AccountRegistry
from chainlib.chain import ChainSpec
from cic_registry.chain import ChainRegistry
from cic_registry.helper.declarator import DeclaratorOracleAdapter
logg = logging.getLogger(__name__)
class TokenOracle:
def __init__(self, conn, chain_spec, registry):
self.tokens = []
self.chain_spec = chain_spec
self.registry = registry
token_registry_contract = CICRegistry.get_contract(chain_spec, 'TokenRegistry', 'Registry')
self.getter = TokenUniqueSymbolIndex(conn, token_registry_contract.address())
def get_tokens(self):
token_count = self.getter.count()
if token_count == len(self.tokens):
return self.tokens
for i in range(len(self.tokens), token_count):
token_address = self.getter.get_index(i)
t = self.registry.get_address(self.chain_spec, token_address)
token_symbol = t.symbol()
self.tokens.append(t)
logg.debug('adding token idx {} symbol {} address {}'.format(i, token_symbol, token_address))
return copy.copy(self.tokens)
class AccountsOracle:
def __init__(self, conn, chain_spec, registry):
self.accounts = []
self.chain_spec = chain_spec
self.registry = registry
accounts_registry_contract = CICRegistry.get_contract(chain_spec, 'AccountRegistry', 'Registry')
self.getter = AccountRegistry(conn, accounts_registry_contract.address())
def get_accounts(self):
accounts_count = self.getter.count()
if accounts_count == len(self.accounts):
return self.accounts
for i in range(len(self.accounts), accounts_count):
account = self.getter.get_index(i)
self.accounts.append(account)
logg.debug('adding account {}'.format(account))
return copy.copy(self.accounts)
def init_legacy(config, w3):
chain_spec = ChainSpec.from_chain_str(config.get('CIC_CHAIN_SPEC'))
CICRegistry.init(w3, config.get('CIC_REGISTRY_ADDRESS'), chain_spec)
CICRegistry.add_path(config.get('ETH_ABI_DIR'))
chain_registry = ChainRegistry(chain_spec)
CICRegistry.add_chain_registry(chain_registry, True)
declarator = CICRegistry.get_contract(chain_spec, 'AddressDeclarator', interface='Declarator')
trusted_addresses_src = config.get('CIC_TRUST_ADDRESS')
if trusted_addresses_src == None:
raise ValueError('At least one trusted address must be declared in CIC_TRUST_ADDRESS')
trusted_addresses = trusted_addresses_src.split(',')
for address in trusted_addresses:
logg.info('using trusted address {}'.format(address))
oracle = DeclaratorOracleAdapter(declarator.contract, trusted_addresses)
chain_registry.add_oracle(oracle, 'naive_erc20_oracle')
return CICRegistry

View File

@ -0,0 +1,18 @@
# standard imports
import re
# external imports
import web3
def create(url):
# web3 input
# TODO: Replace with chainlib
re_websocket = r'^wss?:'
re_http = r'^https?:'
blockchain_provider = None
if re.match(re_websocket, url):
blockchain_provider = web3.Web3.WebsocketProvider(url)
elif re.match(re_http, url):
blockchain_provider = web3.Web3.HTTPProvider(url)
w3 = web3.Web3(blockchain_provider)
return w3

View File

@ -0,0 +1,23 @@
# standard imports
import logging
# external imports
from crypto_dev_signer.eth.signer import ReferenceSigner as EIP155Signer
from crypto_dev_signer.keystore import DictKeystore
logg = logging.getLogger(__name__)
keystore = DictKeystore()
def from_keystore(keyfile):
global keystore
# signer
if keyfile == None:
raise ValueError('please specify signer keystore file')
logg.debug('loading keystore file {}'.format(keyfile))
address = keystore.import_keystore_file(keyfile)
signer = EIP155Signer(keystore)
return (address, signer,)

View File

@ -0,0 +1,39 @@
# standard imports
import logging
import random
# external imports
from cic_eth.api.api_task import Api
logging.basicConfig(level=logging.WARNING)
logg = logging.getLogger()
queue = 'cic-eth'
name = 'erc20_transfer'
def do(token_pair, sender, recipient, sender_balance, aux, block_number, tx_index):
logg.debug('running {} {} {} {}'.format(__name__, token_pair, sender, recipient))
decimals = token_pair[0].decimals()
balance_units = int(sender_balance / decimals)
if balance_units == 0:
return (AttributeError('sender {} has zero balance'), None, 0,)
spend_units = random.randint(1, balance_units)
spend_value = spend_units * decimals
api = Api(
str(aux['chain_spec']),
queue=queue,
callback_param='{}:{}:{}:{}'.format(aux['redis_host_callback'], aux['redis_port_callback'], aux['redis_db'], aux['redis_channel']),
callback_task='cic_eth.callbacks.redis.redis',
callback_queue=queue,
)
t = api.transfer(sender, recipient, spend_value, token_pair[0].symbol())
changed_sender_balance = sender_balance - spend_value
return (None, t, changed_sender_balance,)

View File

@ -1,84 +1,51 @@
# standard imports # standard imports
import os import os
import logging import logging
import argparse
import re import re
import sys import sys
import uuid
import importlib
import copy
import random
import json import json
from argparse import RawTextHelpFormatter
# external imports # external imports
import redis import redis
import confini
import web3
import celery import celery
from cic_registry import CICRegistry
from cic_registry.chain import ChainRegistry
from chainlib.chain import ChainSpec
from eth_token_index import TokenUniqueSymbolIndex
from eth_accounts_index import AccountRegistry
from cic_registry.helper.declarator import DeclaratorOracleAdapter
from chainsyncer.backend import MemBackend from chainsyncer.backend import MemBackend
from chainsyncer.driver import HeadSyncer from chainsyncer.driver import HeadSyncer
from chainlib.eth.connection import HTTPConnection from chainlib.eth.connection import HTTPConnection
from chainlib.eth.gas import DefaultGasOracle from chainlib.eth.gas import DefaultGasOracle
from chainlib.eth.nonce import DefaultNonceOracle from chainlib.eth.nonce import DefaultNonceOracle
from chainlib.eth.block import block_latest from chainlib.eth.block import block_latest
from crypto_dev_signer.eth.signer import ReferenceSigner as EIP155Signer
from crypto_dev_signer.keystore import DictKeystore
from hexathon import strip_0x from hexathon import strip_0x
from cic_eth.api.api_task import Api
logging.basicConfig(level=logging.WARNING) # local imports
logg = logging.getLogger() import common
logging.getLogger('urllib3').setLevel(logging.CRITICAL) from cmd.traffic import (
logging.getLogger('websockets.protocol').setLevel(logging.CRITICAL) TrafficItem,
logging.getLogger('web3.RequestManager').setLevel(logging.CRITICAL) TrafficRouter,
logging.getLogger('web3.providers.WebsocketProvider').setLevel(logging.CRITICAL) TrafficProvisioner,
logging.getLogger('web3.providers.HTTPProvider').setLevel(logging.CRITICAL) TrafficSyncHandler,
)
script_dir = os.path.realpath(os.path.dirname(__file__)) script_dir = os.path.realpath(os.path.dirname(__file__))
default_data_dir = '/usr/local/share/cic/solidity/abi'
argparser = argparse.ArgumentParser() logg = common.log.create()
argparser.add_argument('-p', type=str, help='Ethereum provider url') argparser = common.argparse.create(script_dir, common.argparse.full_template)
argparser.add_argument('-r', type=str, help='cic-registry address')
argparser.add_argument('-y', '--key-file', dest='y', type=str, help='Ethereum keystore file to use for signing')
argparser.add_argument('-c', type=str, default='./config', help='config file')
argparser.add_argument('-q', type=str, default='cic-eth', help='celery queue to submit to')
argparser.add_argument('-i', '--chain-spec', dest='i', type=str, help='chain spec')
argparser.add_argument('-v', action='store_true', help='be verbose')
argparser.add_argument('-vv', action='store_true', help='be more verbose')
argparser.add_argument('--abi-dir', dest='abi_dir', type=str, help='Directory containing bytecode and abi')
argparser.add_argument('--env-prefix', default=os.environ.get('CONFINI_ENV_PREFIX'), dest='env_prefix', type=str, help='environment prefix for variables to overwrite configuration')
argparser.add_argument('--redis-host-callback', dest='redis_host_callback', default='localhost', type=str, help='redis host to use for callback')
argparser.add_argument('--redis-port-callback', dest='redis_port_callback', default=6379, type=int, help='redis port to use for callback')
argparser.add_argument('--batch-size', dest='batch_size', default=10, type=int, help='number of events to process simultaneously')
args = argparser.parse_args()
def subprocessor(subparser):
subparser.formatter_class = formatter_class=RawTextHelpFormatter
subparser.add_argument('--redis-host-callback', dest='redis_host_callback', default='localhost', type=str, help='redis host to use for callback')
subparser.add_argument('--redis-port-callback', dest='redis_port_callback', default=6379, type=int, help='redis port to use for callback')
subparser.add_argument('--batch-size', dest='batch_size', default=10, type=int, help='number of events to process simultaneously')
subparser.description = """Generates traffic on the cic network using dynamically loaded modules as event sources
# handle logging input """
if args.vv:
logging.getLogger().setLevel(logging.DEBUG)
elif args.v:
logging.getLogger().setLevel(logging.INFO)
# handle config input argparser = common.argparse.add(argparser, subprocessor, 'traffic')
config = confini.Config(args.c, args.env_prefix) args = common.argparse.parse(argparser, logg)
config.process() config = common.config.create(args.c, args, args.env_prefix)
args_override = {
'ETH_ABI_DIR': getattr(args, 'abi_dir'),
'ETH_PROVIDER': getattr(args, 'p'),
'CIC_CHAIN_SPEC': getattr(args, 'i'),
'CIC_REGISTRY_ADDRESS': getattr(args, 'r'),
}
config.dict_override(args_override, 'cli flag')
config.validate()
# handle batch size input # map custom args to local config entries
batchsize = args.batch_size batchsize = args.batch_size
if batchsize < 1: if batchsize < 1:
batchsize = 1 batchsize = 1
@ -89,385 +56,49 @@ config.add(batchsize, '_BATCH_SIZE', True)
config.add(args.redis_host_callback, '_REDIS_HOST_CALLBACK', True) config.add(args.redis_host_callback, '_REDIS_HOST_CALLBACK', True)
config.add(args.redis_port_callback, '_REDIS_PORT_CALLBACK', True) config.add(args.redis_port_callback, '_REDIS_PORT_CALLBACK', True)
# keystore
config.add(args.y, '_KEYSTORE_FILE', True)
# queue # queue
config.add(args.q, '_CELERY_QUEUE', True) config.add(args.q, '_CELERY_QUEUE', True)
# signer common.config.log(config)
keystore = DictKeystore()
if args.y == None:
logg.critical('please specify signer keystore file')
sys.exit(1)
logg.debug('loading keystore file {}'.format(args.y))
__signer_address = keystore.import_keystore_file(args.y)
config.add(__signer_address, '_SIGNER_ADDRESS')
logg.debug('now have key for signer address {}'.format(config.get('_SIGNER_ADDRESS')))
signer = EIP155Signer(keystore)
logg.debug('config:\n{}'.format(config)) def main():
# create signer (not currently in use, but needs to be accessible for custom traffic item generators)
(signer_address, signer) = common.signer.from_keystore(config.get('_KEYSTORE_FILE'))
# connect to celery
celery.Celery(broker=config.get('CELERY_BROKER_URL'), backend=config.get('CELERY_RESULT_URL'))
# web3 input # set up registry
# TODO: Replace with chainlib w3 = common.rpc.create(config.get('ETH_PROVIDER')) # replace with HTTPConnection when registry has been so refactored
re_websocket = r'^wss?:' registry = common.registry.init_legacy(config, w3)
re_http = r'^https?:'
blockchain_provider = None
if re.match(re_websocket, config.get('ETH_PROVIDER')):
blockchain_provider = web3.Web3.WebsocketProvider(config.get('ETH_PROVIDER'))
elif re.match(re_http, config.get('ETH_PROVIDER')):
blockchain_provider = web3.Web3.HTTPProvider(config.get('ETH_PROVIDER'))
w3 = web3.Web3(blockchain_provider)
# connect celery # Connect to blockchain with chainlib
celery_app = celery.Celery(broker=config.get('CELERY_BROKER_URL'), backend=config.get('CELERY_RESULT_URL'))
class TrafficItem:
def __init__(self, item):
self.method = item.do
self.uuid = uuid.uuid4()
self.ext = None
self.result = None
self.sender = None
self.recipient = None
self.source_token = None
self.destination_token = None
self.source_value = 0
def __str__(self):
return 'traffic item method {} uuid {}'.format(self.method, self.uuid)
class TrafficRouter:
def __init__(self, batch_size=1):
self.items = []
self.weights = []
self.total_weights = 0
self.batch_size = batch_size
self.reserved = {}
self.reserved_count = 0
self.traffic = {}
def add(self, item, weight):
self.weights.append(self.total_weights)
self.total_weights += weight
m = importlib.import_module(item)
self.items.append(m)
logg.debug('found traffic item {} weight {}'.format(k, v))
def reserve(self):
if len(self.reserved) == self.batch_size:
return None
n = random.randint(0, self.total_weights)
item = self.items[0]
for i in range(len(self.weights)):
if n <= self.weights[i]:
item = self.items[i]
break
ti = TrafficItem(item)
self.reserved[ti.uuid] = ti
return ti
def release(self, traffic_item):
del self.reserved[traffic_item.uuid]
# parse traffic items
traffic_router = TrafficRouter()
for k in config.all():
if len(k) > 8 and k[:8] == 'TRAFFIC_':
v = int(config.get(k))
try:
traffic_router.add(k[8:].lower(), v)
except ModuleNotFoundError as e:
logg.critical('requested traffic item module not found: {}'.format(e))
sys.exit(1)
class TrafficTasker:
oracles = {
'account': None,
'token': None,
}
default_aux = {
}
def __init__(self):
self.tokens = self.oracles['token'].get_tokens()
self.accounts = self.oracles['account'].get_accounts()
self.aux = copy.copy(self.default_aux)
self.__balances = {}
def load_balances(self):
pass
def __cache_balance(self, holder_address, token, value):
if self.__balances.get(holder_address) == None:
self.__balances[holder_address] = {}
self.__balances[holder_address][token] = value
logg.debug('setting cached balance of {} token {} to {}'.format(holder_address, token, value))
def add_aux(self, k, v):
logg.debug('added {} = {} to traffictasker'.format(k, v))
self.aux[k] = v
def balances(self, accounts=None, refresh=False):
if refresh:
if accounts == None:
accounts = self.accounts
for account in accounts:
for token in self.tokens:
value = self.balance(account, token)
self.__cache_balance(account, token.symbol(), value)
logg.debug('balance sender {} token {} = {}'.format(account, token, value))
else:
logg.debug('returning cached balances')
return self.__balances
def balance(self, account, token):
# TODO: use proper redis callback
api = Api(
str(self.aux['chain_spec']),
queue=self.aux['api_queue'],
#callback_param='{}:{}:{}:{}'.format(aux['redis_host_callback'], aux['redis_port_callback'], aux['redis_db'], aux['redis_channel']),
#callback_task='cic_eth.callbacks.redis.redis',
#callback_queue=queue,
)
t = api.balance(account, token.symbol())
r = t.get()
for c in t.collect():
r = c[1]
assert t.successful()
return r[0]
def update_balance(self, account, token, value):
self.__cache_balance(account, token.symbol(), value)
class Handler:
def __init__(self, config, traffic_router):
self.traffic_router = traffic_router
self.redis_channel = str(uuid.uuid4())
self.pubsub = self.__connect_redis(self.redis_channel, config)
self.traffic_items = {}
self.config = config
self.init = False
def __connect_redis(self, redis_channel, config):
r = redis.Redis(config.get('REDIS_HOST'), config.get('REDIS_PORT'), config.get('REDIS_DB'))
redis_pubsub = r.pubsub()
redis_pubsub.subscribe(redis_channel)
logg.debug('redis connected on channel {}'.format(redis_channel))
return redis_pubsub
def refresh(self, block_number, tx_index):
traffic_tasker = TrafficTasker()
traffic_tasker.add_aux('redis_channel', self.redis_channel)
refresh_balance = not self.init
balances = traffic_tasker.balances(refresh=refresh_balance)
self.init = True
if len(traffic_tasker.tokens) == 0:
logg.error('patiently waiting for at least one registered token...')
return
logg.debug('executing handler refresh with accouts {}'.format(traffic_tasker.accounts))
logg.debug('executing handler refresh with tokens {}'.format(traffic_tasker.tokens))
sender_indices = [*range(0, len(traffic_tasker.accounts))]
# TODO: only get balances for the selection that we will be generating for
while True:
traffic_item = traffic_router.reserve()
if traffic_item == None:
logg.debug('no traffic_items left to reserve {}'.format(traffic_item))
break
# TODO: temporary selection
token_pair = [
traffic_tasker.tokens[0],
traffic_tasker.tokens[0],
]
sender_index_index = random.randint(0, len(sender_indices)-1)
sender_index = sender_indices[sender_index_index]
sender = traffic_tasker.accounts[sender_index]
#balance_full = balances[sender][token_pair[0].symbol()]
if len(sender_indices) == 1:
sender_indices[m] = sender_sender_indices[len(senders)-1]
sender_indices = sender_indices[:len(sender_indices)-1]
balance_full = traffic_tasker.balance(sender, token_pair[0])
balance = balance_full['balance_network'] - balance_full['balance_outgoing']
recipient_index = random.randint(0, len(traffic_tasker.accounts)-1)
recipient = traffic_tasker.accounts[recipient_index]
(e, t, balance_result,) = traffic_item.method(
token_pair,
sender,
recipient,
balance,
traffic_tasker.aux,
block_number,
tx_index,
)
traffic_tasker.update_balance(sender, token_pair[0], balance_result)
sender_indices.append(recipient_index)
if e != None:
logg.info('failed {}: {}'.format(str(traffic_item), e))
self.traffic_router.release(traffic_item)
continue
if t == None:
logg.info('traffic method {} completed immediately')
self.traffic_router.release(traffic_item)
traffic_item.ext = t
self.traffic_items[traffic_item.ext] = traffic_item
while True:
m = self.pubsub.get_message(timeout=0.1)
if m == None:
break
logg.debug('redis message {}'.format(m))
if m['type'] == 'message':
message_data = json.loads(m['data'])
uu = message_data['root_id']
match_item = self.traffic_items[uu]
self.traffic_router.release(match_item)
if message_data['status'] == 0:
logg.error('task item {} failed with error code {}'.format(match_item, message_data['status']))
else:
match_item['result'] = message_data['result']
logg.debug('got callback result: {}'.format(match_item))
def name(self):
return 'traffic_item_handler'
def filter(self, conn, block, tx, session):
logg.debug('handler get {}'.format(tx))
class TokenOracle:
def __init__(self, chain_spec, registry):
self.tokens = []
self.chain_spec = chain_spec
self.registry = registry
token_registry_contract = CICRegistry.get_contract(chain_spec, 'TokenRegistry', 'Registry')
self.getter = TokenUniqueSymbolIndex(w3, token_registry_contract.address())
def get_tokens(self):
token_count = self.getter.count()
if token_count == len(self.tokens):
return self.tokens
for i in range(len(self.tokens), token_count):
token_address = self.getter.get_index(i)
t = self.registry.get_address(self.chain_spec, token_address)
token_symbol = t.symbol()
self.tokens.append(t)
logg.debug('adding token idx {} symbol {} address {}'.format(i, token_symbol, token_address))
return copy.copy(self.tokens)
class AccountsOracle:
def __init__(self, chain_spec, registry):
self.accounts = []
self.chain_spec = chain_spec
self.registry = registry
accounts_registry_contract = CICRegistry.get_contract(chain_spec, 'AccountRegistry', 'Registry')
self.getter = AccountRegistry(w3, accounts_registry_contract.address())
def get_accounts(self):
accounts_count = self.getter.count()
if accounts_count == len(self.accounts):
return self.accounts
for i in range(len(self.accounts), accounts_count):
account = self.getter.get_index(i)
self.accounts.append(account)
logg.debug('adding account {}'.format(account))
return copy.copy(self.accounts)
def main(local_config=None):
if local_config != None:
config = local_config
chain_spec = ChainSpec.from_chain_str(config.get('CIC_CHAIN_SPEC'))
CICRegistry.init(w3, config.get('CIC_REGISTRY_ADDRESS'), chain_spec)
CICRegistry.add_path(config.get('ETH_ABI_DIR'))
chain_registry = ChainRegistry(chain_spec)
CICRegistry.add_chain_registry(chain_registry, True)
declarator = CICRegistry.get_contract(chain_spec, 'AddressDeclarator', interface='Declarator')
trusted_addresses_src = config.get('CIC_TRUST_ADDRESS')
if trusted_addresses_src == None:
logg.critical('At least one trusted address must be declared in CIC_TRUST_ADDRESS')
sys.exit(1)
trusted_addresses = trusted_addresses_src.split(',')
for address in trusted_addresses:
logg.info('using trusted address {}'.format(address))
oracle = DeclaratorOracleAdapter(declarator.contract, trusted_addresses)
chain_registry.add_oracle(oracle, 'naive_erc20_oracle')
# Connect to blockchain
conn = HTTPConnection(config.get('ETH_PROVIDER')) conn = HTTPConnection(config.get('ETH_PROVIDER'))
gas_oracle = DefaultGasOracle(conn) gas_oracle = DefaultGasOracle(conn)
nonce_oracle = DefaultNonceOracle(config.get('_SIGNER_ADDRESS'), conn) nonce_oracle = DefaultNonceOracle(signer_address, conn)
# Set up magic traffic handler # Set up magic traffic handler
handler = Handler(config, traffic_router) traffic_router = TrafficRouter()
traffic_router.apply_import_dict(config.all(), config)
handler = TrafficSyncHandler(config, traffic_router)
# Set up syncer # Set up syncer
syncer_backend = MemBackend(str(chain_spec), 0) syncer_backend = MemBackend(config.get('CIC_CHAIN_SPEC'), 0)
o = block_latest() o = block_latest()
r = conn.do(o) r = conn.do(o)
block_offset = int(strip_0x(r), 16) + 1 block_offset = int(strip_0x(r), 16) + 1
syncer_backend.set(block_offset, 0) syncer_backend.set(block_offset, 0)
TrafficTasker.oracles['token']= TokenOracle(chain_spec, CICRegistry) # Set up provisioner for common task input data
TrafficTasker.oracles['account'] = AccountsOracle(chain_spec, CICRegistry) TrafficProvisioner.oracles['token']= common.registry.TokenOracle(w3, config.get('CIC_CHAIN_SPEC'), registry)
TrafficTasker.default_aux = { TrafficProvisioner.oracles['account'] = common.registry.AccountsOracle(w3, config.get('CIC_CHAIN_SPEC'), registry)
'chain_spec': chain_spec, TrafficProvisioner.default_aux = {
'registry': CICRegistry, 'chain_spec': config.get('CIC_CHAIN_SPEC'),
'registry': registry,
'redis_host_callback': config.get('_REDIS_HOST_CALLBACK'), 'redis_host_callback': config.get('_REDIS_HOST_CALLBACK'),
'redis_port_callback': config.get('_REDIS_PORT_CALLBACK'), 'redis_port_callback': config.get('_REDIS_PORT_CALLBACK'),
'redis_db': config.get('REDIS_DB'), 'redis_db': config.get('REDIS_DB'),
@ -478,5 +109,6 @@ def main(local_config=None):
syncer.add_filter(handler) syncer.add_filter(handler)
syncer.loop(1, conn) syncer.loop(1, conn)
if __name__ == '__main__': if __name__ == '__main__':
main(config) main()