Accept external db sessions in methods

This commit is contained in:
nolash 2021-04-03 00:17:05 +02:00
parent 73f097f2be
commit 6271cd5d3d
Signed by: lash
GPG Key ID: 21D2E7BB88C2A746
4 changed files with 64 additions and 65 deletions

View File

@ -91,20 +91,19 @@ class TxCache(SessionBase):
q = session.query(TxCache) q = session.query(TxCache)
q = q.join(Otx) q = q.join(Otx)
q = q.filter(Otx.tx_hash==tx_hash_original) q = q.filter(Otx.tx_hash==strip_0x(tx_hash_original))
txc = q.first() txc = q.first()
if txc == None: if txc == None:
SessionBase.release_session(session) SessionBase.release_session(session)
raise NotLocalTxError('original {}'.format(tx_hash_original)) raise NotLocalTxError('original {}'.format(tx_hash_original))
if txc.block_number != None: if txc.tx_index != None:
SessionBase.release_session(session) SessionBase.release_session(session)
raise TxStateChangeError('cannot clone tx cache of confirmed tx {}'.format(tx_hash_original)) raise TxStateChangeError('cannot clone tx cache of confirmed tx {}'.format(tx_hash_original))
session.flush() session.flush()
q = session.query(Otx)
q = q.filter(Otx.tx_hash==tx_hash_new) otx = Otx.load(tx_hash_new, session=session)
otx = q.first()
if otx == None: if otx == None:
SessionBase.release_session(session) SessionBase.release_session(session)

View File

@ -28,7 +28,7 @@ from chainqueue.error import (
logg = logging.getLogger().getChild(__name__) logg = logging.getLogger().getChild(__name__)
def get_tx_cache(chain_spec, tx_hash): def get_tx_cache(chain_spec, tx_hash, session=None):
"""Returns an aggregate dictionary of outgoing transaction data and metadata """Returns an aggregate dictionary of outgoing transaction data and metadata
:param tx_hash: Transaction hash of record to modify :param tx_hash: Transaction hash of record to modify
@ -37,11 +37,11 @@ def get_tx_cache(chain_spec, tx_hash):
:returns: Transaction data :returns: Transaction data
:rtype: dict :rtype: dict
""" """
session = SessionBase.create_session() session = SessionBase.bind_session(session)
otx = Otx.load(tx_hash, session=session) otx = Otx.load(tx_hash, session=session)
if otx == None: if otx == None:
session.close() SessionBase.release_session(session)
raise NotLocalTxError(tx_hash) raise NotLocalTxError(tx_hash)
session.flush() session.flush()
@ -50,7 +50,7 @@ def get_tx_cache(chain_spec, tx_hash):
q = q.filter(TxCache.otx_id==otx.id) q = q.filter(TxCache.otx_id==otx.id)
txc = q.first() txc = q.first()
session.close() SessionBase.release_session(session)
# TODO: DRY, get_tx_cache / get_tx # TODO: DRY, get_tx_cache / get_tx
tx = { tx = {
@ -75,7 +75,7 @@ def get_tx_cache(chain_spec, tx_hash):
return tx return tx
def get_tx(chain_spec, tx_hash): def get_tx(chain_spec, tx_hash, session=None):
"""Retrieve a transaction queue record by transaction hash """Retrieve a transaction queue record by transaction hash
:param tx_hash: Transaction hash of record to modify :param tx_hash: Transaction hash of record to modify
@ -84,10 +84,10 @@ def get_tx(chain_spec, tx_hash):
:returns: nonce, address and signed_tx (raw signed transaction) :returns: nonce, address and signed_tx (raw signed transaction)
:rtype: dict :rtype: dict
""" """
session = SessionBase.create_session() session = SessionBase.bind_session(session)
otx = Otx.load(tx_hash, session=session) otx = Otx.load(tx_hash, session=session)
if otx == None: if otx == None:
session.close() SessionBase.release_session(session)
raise NotLocalTxError('queue does not contain tx hash {}'.format(tx_hash)) raise NotLocalTxError('queue does not contain tx hash {}'.format(tx_hash))
o = { o = {
@ -97,11 +97,11 @@ def get_tx(chain_spec, tx_hash):
'status': otx.status, 'status': otx.status,
} }
logg.debug('get tx {}'.format(o)) logg.debug('get tx {}'.format(o))
session.close() SessionBase.release_session(session)
return o return o
def get_nonce_tx_cache(chain_spec, nonce, sender, decoder=None): def get_nonce_tx_cache(chain_spec, nonce, sender, decoder=None, session=None):
"""Retrieve all transactions for address with specified nonce """Retrieve all transactions for address with specified nonce
:param nonce: Nonce :param nonce: Nonce
@ -113,7 +113,7 @@ def get_nonce_tx_cache(chain_spec, nonce, sender, decoder=None):
""" """
chain_id = chain_spec.chain_id() chain_id = chain_spec.chain_id()
session = SessionBase.create_session() session = SessionBase.bind_session(session)
q = session.query(Otx) q = session.query(Otx)
q = q.join(TxCache) q = q.join(TxCache)
q = q.filter(TxCache.sender==sender) q = q.filter(TxCache.sender==sender)
@ -128,7 +128,7 @@ def get_nonce_tx_cache(chain_spec, nonce, sender, decoder=None):
raise IntegrityError('Cache sender {} does not match sender in tx {} using decoder {}'.format(sender, r.tx_hash, str(decoder))) raise IntegrityError('Cache sender {} does not match sender in tx {} using decoder {}'.format(sender, r.tx_hash, str(decoder)))
txs[r.tx_hash] = r.signed_tx txs[r.tx_hash] = r.signed_tx
session.close() SessionBase.release_session(session)
return txs return txs
@ -297,7 +297,7 @@ def get_upcoming_tx(chain_spec, status=StatusEnum.READYSEND, not_status=None, re
return txs return txs
def get_account_tx(chain_spec, address, as_sender=True, as_recipient=True, counterpart=None): def get_account_tx(chain_spec, address, as_sender=True, as_recipient=True, counterpart=None, session=None):
"""Returns all local queue transactions for a given Ethereum address """Returns all local queue transactions for a given Ethereum address
:param address: Ethereum address :param address: Ethereum address
@ -317,7 +317,7 @@ def get_account_tx(chain_spec, address, as_sender=True, as_recipient=True, count
txs = {} txs = {}
session = SessionBase.create_session() session = SessionBase.bind_session(session)
q = session.query(Otx) q = session.query(Otx)
q = q.join(TxCache) q = q.join(TxCache)
if as_sender and as_recipient: if as_sender and as_recipient:
@ -334,6 +334,7 @@ def get_account_tx(chain_spec, address, as_sender=True, as_recipient=True, count
logg.debug('tx {} already recorded'.format(r.tx_hash)) logg.debug('tx {} already recorded'.format(r.tx_hash))
continue continue
txs[r.tx_hash] = r.signed_tx txs[r.tx_hash] = r.signed_tx
session.close()
SessionBase.release_session(session)
return txs return txs

View File

@ -22,7 +22,7 @@ from chainqueue.error import (
logg = logging.getLogger().getChild(__name__) logg = logging.getLogger().getChild(__name__)
def set_sent(tx_hash, fail=False): def set_sent(chain_spec, tx_hash, fail=False, session=None):
"""Used to set the status after a send attempt """Used to set the status after a send attempt
:param tx_hash: Transaction hash of record to modify :param tx_hash: Transaction hash of record to modify
@ -33,11 +33,11 @@ def set_sent(tx_hash, fail=False):
:returns: True if tx is known, False otherwise :returns: True if tx is known, False otherwise
:rtype: boolean :rtype: boolean
""" """
session = SessionBase.create_session() session = SessionBase.bind_session(session)
o = Otx.load(tx_hash, session=session) o = Otx.load(tx_hash, session=session)
if o == None: if o == None:
logg.warning('not local tx, skipping {}'.format(tx_hash)) logg.warning('not local tx, skipping {}'.format(tx_hash))
session.close() SessionBase.release_session(session)
return False return False
try: try:
@ -47,21 +47,21 @@ def set_sent(tx_hash, fail=False):
o.sent(session=session) o.sent(session=session)
except TxStateChangeError as e: except TxStateChangeError as e:
logg.exception('set sent fail: {}'.format(e)) logg.exception('set sent fail: {}'.format(e))
session.close() SessionBase.release_session(session)
raise(e) raise(e)
except Exception as e: except Exception as e:
logg.exception('set sent UNEXPECED fail: {}'.format(e)) logg.exception('set sent UNEXPECED fail: {}'.format(e))
session.close() SessionBase.release_session(session)
raise(e) raise(e)
session.commit() session.commit()
session.close() SessionBase.release_session(session)
return tx_hash return tx_hash
def set_final(tx_hash, block=None, fail=False): def set_final(chain_spec, tx_hash, block=None, fail=False, session=None):
"""Used to set the status of an incoming transaction result. """Used to set the status of an incoming transaction result.
:param tx_hash: Transaction hash of record to modify :param tx_hash: Transaction hash of record to modify
@ -72,11 +72,11 @@ def set_final(tx_hash, block=None, fail=False):
:type fail: boolean :type fail: boolean
:raises NotLocalTxError: If transaction not found in queue. :raises NotLocalTxError: If transaction not found in queue.
""" """
session = SessionBase.create_session() session = SessionBase.bind_session(session)
o = Otx.load(tx_hash, session=session) o = Otx.load(tx_hash, session=session)
if o == None: if o == None:
session.close() SessionBase.release_session(session)
raise NotLocalTxError('queue does not contain tx hash {}'.format(tx_hash)) raise NotLocalTxError('queue does not contain tx hash {}'.format(tx_hash))
session.flush() session.flush()
@ -89,11 +89,11 @@ def set_final(tx_hash, block=None, fail=False):
session.commit() session.commit()
except TxStateChangeError as e: except TxStateChangeError as e:
logg.exception('set final fail: {}'.format(e)) logg.exception('set final fail: {}'.format(e))
session.close() SessionBase.release_session(session)
raise(e) raise(e)
except Exception as e: except Exception as e:
logg.exception('set final UNEXPECTED fail: {}'.format(e)) logg.exception('set final UNEXPECTED fail: {}'.format(e))
session.close() SessionBase.release_session(session)
raise(e) raise(e)
q = session.query(TxCache) q = session.query(TxCache)
@ -101,12 +101,12 @@ def set_final(tx_hash, block=None, fail=False):
q = q.filter(Otx.tx_hash==strip_0x(tx_hash)) q = q.filter(Otx.tx_hash==strip_0x(tx_hash))
o = q.first() o = q.first()
session.close() SessionBase.release_session(session)
return tx_hash return tx_hash
def set_cancel(tx_hash, manual=False): def set_cancel(chain_spec, tx_hash, manual=False, session=None):
"""Used to set the status when a transaction is cancelled. """Used to set the status when a transaction is cancelled.
Will set the state to CANCELLED or OVERRIDDEN Will set the state to CANCELLED or OVERRIDDEN
@ -118,10 +118,10 @@ def set_cancel(tx_hash, manual=False):
:raises NotLocalTxError: If transaction not found in queue. :raises NotLocalTxError: If transaction not found in queue.
""" """
session = SessionBase.create_session() session = SessionBase.bind_session(session)
o = Otx.load(tx_hash, session=session) o = Otx.load(tx_hash, session=session)
if o == None: if o == None:
session.close() SessionBase.release_session(session)
raise NotLocalTxError('queue does not contain tx hash {}'.format(tx_hash)) raise NotLocalTxError('queue does not contain tx hash {}'.format(tx_hash))
session.flush() session.flush()
@ -136,12 +136,12 @@ def set_cancel(tx_hash, manual=False):
logg.exception('set cancel fail: {}'.format(e)) logg.exception('set cancel fail: {}'.format(e))
except Exception as e: except Exception as e:
logg.exception('set cancel UNEXPECTED fail: {}'.format(e)) logg.exception('set cancel UNEXPECTED fail: {}'.format(e))
session.close() SessionBase.release_session(session)
return tx_hash return tx_hash
def set_rejected(tx_hash): def set_rejected(chain_spec, tx_hash, session=None):
"""Used to set the status when the node rejects sending a transaction to network """Used to set the status when the node rejects sending a transaction to network
Will set the state to REJECTED Will set the state to REJECTED
@ -151,22 +151,22 @@ def set_rejected(tx_hash):
:raises NotLocalTxError: If transaction not found in queue. :raises NotLocalTxError: If transaction not found in queue.
""" """
session = SessionBase.create_session() session = SessionBase.bind_session(session)
o = Otx.load(tx_hash, session=session) o = Otx.load(tx_hash, session=session)
if o == None: if o == None:
session.close() SessionBase.release_session(session)
raise NotLocalTxError('queue does not contain tx hash {}'.format(tx_hash)) raise NotLocalTxError('queue does not contain tx hash {}'.format(tx_hash))
session.flush() session.flush()
o.reject(session=session) o.reject(session=session)
session.commit() session.commit()
session.close() SessionBase.release_session(session)
return tx_hash return tx_hash
def set_fubar(tx_hash): def set_fubar(chain_spec, tx_hash, session=None):
"""Used to set the status when an unexpected error occurs. """Used to set the status when an unexpected error occurs.
Will set the state to FUBAR Will set the state to FUBAR
@ -176,22 +176,22 @@ def set_fubar(tx_hash):
:raises NotLocalTxError: If transaction not found in queue. :raises NotLocalTxError: If transaction not found in queue.
""" """
session = SessionBase.create_session() session = SessionBase.bind_session(session)
o = Otx.load(tx_hash, session=session) o = Otx.load(tx_hash, session=session)
if o == None: if o == None:
session.close() SessionBase.release_session(session)
raise NotLocalTxError('queue does not contain tx hash {}'.format(tx_hash)) raise NotLocalTxError('queue does not contain tx hash {}'.format(tx_hash))
session.flush() session.flush()
o.fubar(session=session) o.fubar(session=session)
session.commit() session.commit()
session.close() SessionBase.release_session(session)
return tx_hash return tx_hash
def set_manual(tx_hash): def set_manual(chain_spec, tx_hash, session=None):
"""Used to set the status when queue is manually changed """Used to set the status when queue is manually changed
Will set the state to MANUAL Will set the state to MANUAL
@ -201,32 +201,32 @@ def set_manual(tx_hash):
:raises NotLocalTxError: If transaction not found in queue. :raises NotLocalTxError: If transaction not found in queue.
""" """
session = SessionBase.create_session() session = SessionBase.bind_session(session)
o = Otx.load(tx_hash, session=session) o = Otx.load(tx_hash, session=session)
if o == None: if o == None:
session.close() SessionBase.release_session(session)
raise NotLocalTxError('queue does not contain tx hash {}'.format(tx_hash)) raise NotLocalTxError('queue does not contain tx hash {}'.format(tx_hash))
session.flush() session.flush()
o.manual(session=session) o.manual(session=session)
session.commit() session.commit()
session.close() SessionBase.release_session(session)
return tx_hash return tx_hash
def set_ready(tx_hash): def set_ready(chain_spec, tx_hash, session=None):
"""Used to mark a transaction as ready to be sent to network """Used to mark a transaction as ready to be sent to network
:param tx_hash: Transaction hash of record to modify :param tx_hash: Transaction hash of record to modify
:type tx_hash: str, 0x-hex :type tx_hash: str, 0x-hex
:raises NotLocalTxError: If transaction not found in queue. :raises NotLocalTxError: If transaction not found in queue.
""" """
session = SessionBase.create_session() session = SessionBase.bind_session(session)
o = Otx.load(tx_hash, session=session) o = Otx.load(tx_hash, session=session)
if o == None: if o == None:
session.close() SessionBase.release_session(session)
raise NotLocalTxError('queue does not contain tx hash {}'.format(tx_hash)) raise NotLocalTxError('queue does not contain tx hash {}'.format(tx_hash))
session.flush() session.flush()
@ -236,29 +236,29 @@ def set_ready(tx_hash):
o.retry(session=session) o.retry(session=session)
session.commit() session.commit()
session.close() SessionBase.release_session(session)
return tx_hash return tx_hash
def set_reserved(tx_hash): def set_reserved(chain_spec, tx_hash, session=None):
session = SessionBase.create_session() session = SessionBase.bind_session(session)
o = Otx.load(tx_hash, session=session) o = Otx.load(tx_hash, session=session)
if o == None: if o == None:
session.close() SessionBase.release_session(session)
raise NotLocalTxError('queue does not contain tx hash {}'.format(tx_hash)) raise NotLocalTxError('queue does not contain tx hash {}'.format(tx_hash))
session.flush() session.flush()
o.reserve(session=session) o.reserve(session=session)
session.commit() session.commit()
session.close() SessionBase.release_session(session)
return tx_hash return tx_hash
def set_waitforgas(tx_hash): def set_waitforgas(chain_spec, tx_hash, session=None):
"""Used to set the status when a transaction must be deferred due to gas refill """Used to set the status when a transaction must be deferred due to gas refill
Will set the state to WAITFORGAS Will set the state to WAITFORGAS
@ -267,27 +267,26 @@ def set_waitforgas(tx_hash):
:type tx_hash: str, 0x-hex :type tx_hash: str, 0x-hex
:raises NotLocalTxError: If transaction not found in queue. :raises NotLocalTxError: If transaction not found in queue.
""" """
session = SessionBase.bind_session(session)
session = SessionBase.create_session()
o = Otx.load(tx_hash, session=session) o = Otx.load(tx_hash, session=session)
if o == None: if o == None:
session.close() SessionBase.release_session(session)
raise NotLocalTxError('queue does not contain tx hash {}'.format(tx_hash)) raise NotLocalTxError('queue does not contain tx hash {}'.format(tx_hash))
session.flush() session.flush()
o.waitforgas(session=session) o.waitforgas(session=session)
session.commit() session.commit()
session.close() SessionBase.release_session(session)
return tx_hash return tx_hash
def get_state_log(tx_hash): def get_state_log(chain_spec, tx_hash, session=None):
logs = [] logs = []
session = SessionBase.create_session() session = SessionBase.bind_session(session)
q = session.query(OtxStateLog) q = session.query(OtxStateLog)
q = q.join(Otx) q = q.join(Otx)
@ -296,13 +295,13 @@ def get_state_log(tx_hash):
for l in q.all(): for l in q.all():
logs.append((l.date, l.status,)) logs.append((l.date, l.status,))
session.close() SessionBase.release_session(session)
return logs return logs
def cancel_obsoletes_by_cache(tx_hash): def cancel_obsoletes_by_cache(chain_spec, tx_hash):
session = SessionBase.create_session() session = SessionBase.create_session()
q = session.query( q = session.query(
Otx.nonce.label('nonce'), Otx.nonce.label('nonce'),

View File

@ -24,7 +24,7 @@ class TestOtxState(TestOtxBase):
set_sent(self.tx_hash) set_sent(self.tx_hash)
set_final(self.tx_hash, block=1042) set_final(self.tx_hash, block=1042)
state_log = get_state_log(self.tx_hash) state_log = get_state_log(self.chain_spec, self.tx_hash)
self.assertEqual(state_log[0][1], StatusEnum.READYSEND) self.assertEqual(state_log[0][1], StatusEnum.READYSEND)
self.assertEqual(state_log[1][1], StatusEnum.RESERVED) self.assertEqual(state_log[1][1], StatusEnum.RESERVED)
self.assertEqual(state_log[2][1], StatusEnum.SENT) self.assertEqual(state_log[2][1], StatusEnum.SENT)