Improve reuse of db sessions

This commit is contained in:
nolash 2021-02-18 21:44:49 +01:00
parent 725ef54cf5
commit 1a31876e2e
Signed by untrusted user who does not match committer: lash
GPG Key ID: 21D2E7BB88C2A746
9 changed files with 71 additions and 49 deletions

View File

@ -6,6 +6,11 @@ from sqlalchemy import Column, Integer
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import (
StaticPool,
QueuePool,
AssertionPool,
)
logg = logging.getLogger() logg = logging.getLogger()
@ -49,7 +54,7 @@ class SessionBase(Model):
@staticmethod @staticmethod
def connect(dsn, debug=False): def connect(dsn, pool_size=8, debug=False):
"""Create new database connection engine and connect to database backend. """Create new database connection engine and connect to database backend.
:param dsn: DSN string defining connection. :param dsn: DSN string defining connection.
@ -57,14 +62,28 @@ class SessionBase(Model):
""" """
e = None e = None
if SessionBase.poolable: if SessionBase.poolable:
e = create_engine( poolclass = QueuePool
dsn, if pool_size > 1:
max_overflow=50, e = create_engine(
pool_pre_ping=True, dsn,
pool_size=20, max_overflow=pool_size*3,
pool_recycle=10, pool_pre_ping=True,
echo=debug, pool_size=pool_size,
) pool_recycle=60,
poolclass=poolclass,
echo=debug,
)
else:
if debug:
poolclass = AssertionPool
else:
poolclass = StaticPool
e = create_engine(
dsn,
poolclass=poolclass,
echo=debug,
)
else: else:
e = create_engine( e = create_engine(
dsn, dsn,

View File

@ -85,18 +85,18 @@ class TxCache(SessionBase):
:param tx_hash_new: tx hash to associate the copied entry with :param tx_hash_new: tx hash to associate the copied entry with
:type tx_hash_new: str, 0x-hex :type tx_hash_new: str, 0x-hex
""" """
localsession = session localsession = SessionBase.bind_session(session)
if localsession == None:
localsession = SessionBase.create_session()
q = localsession.query(TxCache) q = localsession.query(TxCache)
q = q.join(Otx) q = q.join(Otx)
q = q.filter(Otx.tx_hash==tx_hash_original) q = q.filter(Otx.tx_hash==tx_hash_original)
txc = q.first() txc = q.first()
if txc == None: if txc == None:
SessionBase.release_session(localsession)
raise NotLocalTxError('original {}'.format(tx_hash_original)) raise NotLocalTxError('original {}'.format(tx_hash_original))
if txc.block_number != None: if txc.block_number != None:
SessionBase.release_session(localsession)
raise TxStateChangeError('cannot clone tx cache of confirmed tx {}'.format(tx_hash_original)) raise TxStateChangeError('cannot clone tx cache of confirmed tx {}'.format(tx_hash_original))
q = localsession.query(Otx) q = localsession.query(Otx)
@ -104,6 +104,7 @@ class TxCache(SessionBase):
otx = q.first() otx = q.first()
if otx == None: if otx == None:
SessionBase.release_session(localsession)
raise NotLocalTxError('new {}'.format(tx_hash_new)) raise NotLocalTxError('new {}'.format(tx_hash_new))
txc_new = TxCache( txc_new = TxCache(
@ -118,15 +119,14 @@ class TxCache(SessionBase):
localsession.add(txc_new) localsession.add(txc_new)
localsession.commit() localsession.commit()
if session == None: SessionBase.release_session(localsession)
localsession.close()
def __init__(self, tx_hash, sender, recipient, source_token_address, destination_token_address, from_value, to_value, block_number=None, tx_index=None): def __init__(self, tx_hash, sender, recipient, source_token_address, destination_token_address, from_value, to_value, block_number=None, tx_index=None, session=None):
session = SessionBase.create_session() localsession = SessionBase.bind_session(session)
tx = session.query(Otx).filter(Otx.tx_hash==tx_hash).first() tx = localsession.query(Otx).filter(Otx.tx_hash==tx_hash).first()
if tx == None: if tx == None:
session.close() SessionBase.release_session(localsession)
raise FileNotFoundError('outgoing transaction record unknown {} (add a Tx first)'.format(tx_hash)) raise FileNotFoundError('outgoing transaction record unknown {} (add a Tx first)'.format(tx_hash))
self.otx_id = tx.id self.otx_id = tx.id
@ -143,4 +143,5 @@ class TxCache(SessionBase):
self.date_updated = self.date_created self.date_updated = self.date_created
self.date_checked = self.date_created self.date_checked = self.date_created
SessionBase.release_session(localsession)

View File

@ -304,6 +304,8 @@ def cache_gift_data(
tx = unpack_signed_raw_tx(tx_signed_raw_bytes, chain_spec.chain_id()) tx = unpack_signed_raw_tx(tx_signed_raw_bytes, chain_spec.chain_id())
tx_data = unpack_gift(tx['data']) tx_data = unpack_gift(tx['data'])
session = SessionBase.create_session()
tx_cache = TxCache( tx_cache = TxCache(
tx_hash_hex, tx_hash_hex,
tx['from'], tx['from'],
@ -312,9 +314,9 @@ def cache_gift_data(
zero_address, zero_address,
0, 0,
0, 0,
session=session,
) )
session = SessionBase.create_session()
session.add(tx_cache) session.add(tx_cache)
session.commit() session.commit()
cache_id = tx_cache.id cache_id = tx_cache.id
@ -347,6 +349,7 @@ def cache_account_data(
tx = unpack_signed_raw_tx(tx_signed_raw_bytes, chain_spec.chain_id()) tx = unpack_signed_raw_tx(tx_signed_raw_bytes, chain_spec.chain_id())
tx_data = unpack_register(tx['data']) tx_data = unpack_register(tx['data'])
session = SessionBase.create_session()
tx_cache = TxCache( tx_cache = TxCache(
tx_hash_hex, tx_hash_hex,
tx['from'], tx['from'],
@ -355,9 +358,8 @@ def cache_account_data(
zero_address, zero_address,
0, 0,
0, 0,
session=session,
) )
session = SessionBase.create_session()
session.add(tx_cache) session.add(tx_cache)
session.commit() session.commit()
cache_id = tx_cache.id cache_id = tx_cache.id

View File

@ -381,6 +381,7 @@ def cache_transfer_data(
tx['to'], tx['to'],
tx_data['amount'], tx_data['amount'],
tx_data['amount'], tx_data['amount'],
session=session,
) )
session.add(tx_cache) session.add(tx_cache)
session.commit() session.commit()
@ -440,6 +441,7 @@ def cache_approve_data(
tx['to'], tx['to'],
tx_data['amount'], tx_data['amount'],
tx_data['amount'], tx_data['amount'],
session=session,
) )
session.add(tx_cache) session.add(tx_cache)
session.commit() session.commit()

View File

@ -78,7 +78,6 @@ def check_gas(self, tx_hashes, chain_str, txs=[], address=None, gas_required=Non
# TODO: it should not be necessary to pass address explicitly, if not passed should be derived from the tx # TODO: it should not be necessary to pass address explicitly, if not passed should be derived from the tx
balance = c.w3.eth.getBalance(address) balance = c.w3.eth.getBalance(address)
logg.debug('check gas txs {}'.format(tx_hashes))
logg.debug('address {} has gas {} needs {}'.format(address, balance, gas_required)) logg.debug('address {} has gas {} needs {}'.format(address, balance, gas_required))
if gas_required > balance: if gas_required > balance:
@ -126,7 +125,6 @@ def check_gas(self, tx_hashes, chain_str, txs=[], address=None, gas_required=Non
queue=queue, queue=queue,
) )
ready_tasks.append(s) ready_tasks.append(s)
logg.debug('tasks {}'.format(ready_tasks))
celery.group(ready_tasks)() celery.group(ready_tasks)()
return txs return txs
@ -143,7 +141,6 @@ def hashes_to_txs(self, tx_hashes):
:returns: Signed raw transactions :returns: Signed raw transactions
:rtype: list of str, 0x-hex :rtype: list of str, 0x-hex
""" """
#logg = celery_app.log.get_default_logger()
if len(tx_hashes) == 0: if len(tx_hashes) == 0:
raise ValueError('no transaction to send') raise ValueError('no transaction to send')
@ -351,15 +348,12 @@ def send(self, txs, chain_str):
tx_hash_hex = tx_hash.hex() tx_hash_hex = tx_hash.hex()
queue = self.request.delivery_info.get('routing_key', None) queue = self.request.delivery_info.get('routing_key', None)
if queue == None:
logg.debug('send tx {} has no queue', tx_hash)
c = RpcClient(chain_spec) c = RpcClient(chain_spec)
r = None r = None
try: try:
r = c.w3.eth.send_raw_transaction(tx_hex) r = c.w3.eth.send_raw_transaction(tx_hex)
except Exception as e: except Exception as e:
logg.debug('e {}'.format(e))
raiser = ParityNodeHandler(chain_spec, queue) raiser = ParityNodeHandler(chain_spec, queue)
(t, e, m) = raiser.handle(e, tx_hash_hex, tx_hex) (t, e, m) = raiser.handle(e, tx_hash_hex, tx_hex)
raise e(m) raise e(m)
@ -423,7 +417,7 @@ def refill_gas(self, recipient_address, chain_str):
gas_price = c.gas_price() gas_price = c.gas_price()
gas_limit = c.default_gas_limit gas_limit = c.default_gas_limit
refill_amount = c.refill_amount() refill_amount = c.refill_amount()
logg.debug('gas price {} nonce {}'.format(gas_price, nonce)) logg.debug('tx send gas price {} nonce {}'.format(gas_price, nonce))
# create and sign transaction # create and sign transaction
tx_send_gas = { tx_send_gas = {
@ -436,7 +430,6 @@ def refill_gas(self, recipient_address, chain_str):
'value': refill_amount, 'value': refill_amount,
'data': '', 'data': '',
} }
logg.debug('txsend_gas {}'.format(tx_send_gas))
tx_send_gas_signed = c.w3.eth.sign_transaction(tx_send_gas) tx_send_gas_signed = c.w3.eth.sign_transaction(tx_send_gas)
tx_hash = web3.Web3.keccak(hexstr=tx_send_gas_signed['raw']) tx_hash = web3.Web3.keccak(hexstr=tx_send_gas_signed['raw'])
tx_hash_hex = tx_hash.hex() tx_hash_hex = tx_hash.hex()
@ -487,11 +480,14 @@ def resend_with_higher_gas(self, txold_hash_hex, chain_str, gas=None, default_fa
:rtype: str, 0x-hex :rtype: str, 0x-hex
""" """
session = SessionBase.create_session() session = SessionBase.create_session()
otx = session.query(Otx).filter(Otx.tx_hash==txold_hash_hex).first()
if otx == None:
session.close() q = session.query(Otx)
raise NotLocalTxError(txold_hash_hex) q = q.filter(Otx.tx_hash==txold_hash_hex)
otx = q.first()
session.close() session.close()
if otx == None:
raise NotLocalTxError(txold_hash_hex)
chain_spec = ChainSpec.from_chain_str(chain_str) chain_spec = ChainSpec.from_chain_str(chain_str)
c = RpcClient(chain_spec) c = RpcClient(chain_spec)
@ -508,7 +504,7 @@ def resend_with_higher_gas(self, txold_hash_hex, chain_str, gas=None, default_fa
else: else:
gas_price = c.gas_price() gas_price = c.gas_price()
if tx['gasPrice'] > gas_price: if tx['gasPrice'] > gas_price:
logg.warning('Network gas price {} is lower than overdue tx gas price {}'.format(gas_price, tx['gasPrice'])) logg.info('Network gas price {} is lower than overdue tx gas price {}'.format(gas_price, tx['gasPrice']))
#tx['gasPrice'] = int(tx['gasPrice'] * default_factor) #tx['gasPrice'] = int(tx['gasPrice'] * default_factor)
tx['gasPrice'] += 1 tx['gasPrice'] += 1
else: else:
@ -518,9 +514,6 @@ def resend_with_higher_gas(self, txold_hash_hex, chain_str, gas=None, default_fa
else: else:
tx['gasPrice'] = new_gas_price tx['gasPrice'] = new_gas_price
logg.debug('after {}'.format(tx))
#(tx_hash_hex, tx_signed_raw_hex) = sign_and_register_tx(tx, chain_str, queue)
(tx_hash_hex, tx_signed_raw_hex) = sign_tx(tx, chain_str) (tx_hash_hex, tx_signed_raw_hex) = sign_tx(tx, chain_str)
queue_create( queue_create(
tx['nonce'], tx['nonce'],
@ -540,6 +533,7 @@ def resend_with_higher_gas(self, txold_hash_hex, chain_str, gas=None, default_fa
queue=queue, queue=queue,
) )
s.apply_async() s.apply_async()
return tx_hash_hex return tx_hash_hex
@ -602,7 +596,9 @@ def resume_tx(self, txpending_hash_hex, chain_str):
chain_spec = ChainSpec.from_chain_str(chain_str) chain_spec = ChainSpec.from_chain_str(chain_str)
session = SessionBase.create_session() session = SessionBase.create_session()
r = session.query(Otx.signed_tx).filter(Otx.tx_hash==txpending_hash_hex).first() q = session.query(Otx.signed_tx)
q = q.filter(Otx.tx_hash==txpending_hash_hex)
r = q.first()
session.close() session.close()
if r == None: if r == None:
raise NotLocalTxError(txpending_hash_hex) raise NotLocalTxError(txpending_hash_hex)

View File

@ -35,8 +35,7 @@ celery_app = celery.current_app
logg = logging.getLogger() logg = logging.getLogger()
@celery_app.task() def create(nonce, holder_address, tx_hash, signed_tx, chain_str, obsolete_predecessors=True, session=None):
def create(nonce, holder_address, tx_hash, signed_tx, chain_str, obsolete_predecessors=True):
"""Create a new transaction queue record. """Create a new transaction queue record.
:param nonce: Transaction nonce :param nonce: Transaction nonce
@ -52,10 +51,10 @@ def create(nonce, holder_address, tx_hash, signed_tx, chain_str, obsolete_predec
:returns: transaction hash :returns: transaction hash
:rtype: str, 0x-hash :rtype: str, 0x-hash
""" """
session = SessionBase.create_session() session = SessionBase.bind_session(session)
lock = Lock.check_aggregate(chain_str, LockEnum.QUEUE, holder_address, session=session) lock = Lock.check_aggregate(chain_str, LockEnum.QUEUE, holder_address, session=session)
if lock > 0: if lock > 0:
session.close() SessionBase.release_session(session)
raise LockedError(lock) raise LockedError(lock)
o = Otx.add( o = Otx.add(
@ -81,7 +80,7 @@ def create(nonce, holder_address, tx_hash, signed_tx, chain_str, obsolete_predec
otx.cancel(confirmed=False, session=session) otx.cancel(confirmed=False, session=session)
session.commit() session.commit()
session.close() SessionBase.release_session(session)
logg.debug('queue created nonce {} from {} hash {}'.format(nonce, holder_address, tx_hash)) logg.debug('queue created nonce {} from {} hash {}'.format(nonce, holder_address, tx_hash))
return tx_hash return tx_hash
@ -100,7 +99,9 @@ def set_sent_status(tx_hash, fail=False):
:rtype: boolean :rtype: boolean
""" """
session = SessionBase.create_session() session = SessionBase.create_session()
o = session.query(Otx).filter(Otx.tx_hash==tx_hash).first() q = session.query(Otx)
q = q.filter(Otx.tx_hash==tx_hash)
o = q.first()
if o == None: if o == None:
logg.warning('not local tx, skipping {}'.format(tx_hash)) logg.warning('not local tx, skipping {}'.format(tx_hash))
session.close() session.close()
@ -454,6 +455,7 @@ def get_tx(tx_hash):
session = SessionBase.create_session() session = SessionBase.create_session()
tx = session.query(Otx).filter(Otx.tx_hash==tx_hash).first() tx = session.query(Otx).filter(Otx.tx_hash==tx_hash).first()
if tx == None: if tx == None:
session.close()
raise NotLocalTxError('queue does not contain tx hash {}'.format(tx_hash)) raise NotLocalTxError('queue does not contain tx hash {}'.format(tx_hash))
o = { o = {
@ -602,6 +604,7 @@ def get_upcoming_tx(status=StatusEnum.READYSEND, recipient=None, before=None, ch
q_outer = q_outer.filter(or_(Lock.flags==None, Lock.flags.op('&')(LockEnum.SEND.value)==0)) q_outer = q_outer.filter(or_(Lock.flags==None, Lock.flags.op('&')(LockEnum.SEND.value)==0))
if not is_alive(status): if not is_alive(status):
session.close()
raise ValueError('not a valid non-final tx value: {}'.format(status)) raise ValueError('not a valid non-final tx value: {}'.format(status))
if status == StatusEnum.PENDING: if status == StatusEnum.PENDING:
q_outer = q_outer.filter(Otx.status==status.value) q_outer = q_outer.filter(Otx.status==status.value)

View File

@ -37,7 +37,7 @@ class CallbackFilter(SyncFilter):
transfer_type, transfer_type,
int(rcpt.status == 0), int(rcpt.status == 0),
], ],
queue=tc.queue, queue=self.queue,
) )
# s_translate = celery.signature( # s_translate = celery.signature(
# 'cic_eth.ext.address.translate', # 'cic_eth.ext.address.translate',

View File

@ -118,7 +118,7 @@ declarator = CICRegistry.get_contract(chain_spec, 'AddressDeclarator', interface
dsn = dsn_from_config(config) dsn = dsn_from_config(config)
SessionBase.connect(dsn) SessionBase.connect(dsn, pool_size=1, debug=config.true('DATABASE_DEBUG'))
def main(): def main():

View File

@ -78,7 +78,7 @@ logg.debug('config loaded from {}:\n{}'.format(args.c, config))
# connect to database # connect to database
dsn = dsn_from_config(config) dsn = dsn_from_config(config)
SessionBase.connect(dsn) SessionBase.connect(dsn, pool)
# verify database connection with minimal sanity query # verify database connection with minimal sanity query
session = SessionBase.create_session() session = SessionBase.create_session()
@ -179,7 +179,6 @@ def web3ext_constructor():
return (blockchain_provider, w3) return (blockchain_provider, w3)
RpcClient.set_constructor(web3ext_constructor) RpcClient.set_constructor(web3ext_constructor)
logg.info('ccc {}'.format(config.store['TASKS_TRACE_QUEUE_STATUS']))
Otx.tracing = config.true('TASKS_TRACE_QUEUE_STATUS') Otx.tracing = config.true('TASKS_TRACE_QUEUE_STATUS')