diff --git a/chainqueue/db/enum.py b/chainqueue/db/enum.py index 0223b37..7cbcece 100644 --- a/chainqueue/db/enum.py +++ b/chainqueue/db/enum.py @@ -51,6 +51,7 @@ class StatusEnum(enum.IntEnum): SENDFAIL = StatusBits.DEFERRED | StatusBits.LOCAL_ERROR RETRY = StatusBits.QUEUED | StatusBits.DEFERRED READYSEND = StatusBits.QUEUED + RESERVED = StatusBits.RESERVED OBSOLETED = StatusBits.OBSOLETE | StatusBits.IN_NETWORK @@ -124,6 +125,15 @@ def is_error_status(v): return bool(v & all_errors()) +__ignore_manual_value = ~StatusBits.MANUAL +def ignore_manual(v): + return v & __ignore_manual_value + + +def is_nascent(v): + return ignore_manual(v) == StatusEnum.PENDING + + def dead(): """Bit mask defining whether a transaction is still likely to be processed on the network. diff --git a/chainqueue/db/migrations/default/versions/2215c497248b_transaction_cache.py b/chainqueue/db/migrations/default/versions/2215c497248b_transaction_cache.py index 720f500..7e5a8b2 100644 --- a/chainqueue/db/migrations/default/versions/2215c497248b_transaction_cache.py +++ b/chainqueue/db/migrations/default/versions/2215c497248b_transaction_cache.py @@ -30,7 +30,7 @@ def upgrade(): sa.Column('recipient', sa.String(42), nullable=False), sa.Column('from_value', sa.NUMERIC(), nullable=False), sa.Column('to_value', sa.NUMERIC(), nullable=True), - sa.Column('block_number', sa.BIGINT(), nullable=True), +# sa.Column('block_number', sa.BIGINT(), nullable=True), sa.Column('tx_index', sa.Integer, nullable=True), ) diff --git a/chainqueue/db/models/otx.py b/chainqueue/db/models/otx.py index 7abfaf6..ebac505 100644 --- a/chainqueue/db/models/otx.py +++ b/chainqueue/db/models/otx.py @@ -519,23 +519,7 @@ class Otx(SessionBase): return q.first() - @staticmethod - def account(account_address): - """Retrieves all transaction hashes for which the given Ethereum address is sender or recipient. - - :param account_address: Ethereum address to use in query. - :type account_address: str, 0x-hex - :returns: Outgoing transactions - :rtype: tuple, where first element is transaction hash - """ - session = Otx.create_session() - q = session.query(Otx.tx_hash) - q = q.join(TxCache) - q = q.filter(or_(TxCache.sender==account_address, TxCache.recipient==account_address)) - txs = q.all() - session.close() - return list(txs) - + def __state_log(self, session): l = OtxStateLog(self) diff --git a/chainqueue/db/models/tx.py b/chainqueue/db/models/tx.py index d402307..e64c2d5 100644 --- a/chainqueue/db/models/tx.py +++ b/chainqueue/db/models/tx.py @@ -59,7 +59,7 @@ class TxCache(SessionBase): recipient = Column(String(42)) from_value = Column(NUMERIC()) to_value = Column(NUMERIC()) - block_number = Column(Integer()) + #block_number = Column(Integer()) tx_index = Column(Integer()) date_created = Column(DateTime, default=datetime.datetime.utcnow) date_updated = Column(DateTime, default=datetime.datetime.utcnow) @@ -126,6 +126,65 @@ class TxCache(SessionBase): SessionBase.release_session(session) + # TODO: possible dead code + @staticmethod + def account(account_address, session=None): + """Retrieves all transaction hashes for which the given Ethereum address is sender or recipient. + + :param account_address: Ethereum address to use in query. + :type account_address: str, 0x-hex + :returns: Outgoing transactions + :rtype: tuple, where first element is transaction hash + """ + session = SessionBase.bind_session(session) + + q = session.query(Otx.tx_hash) + q = q.join(TxCache) + q = q.filter(or_(TxCache.sender==account_address, TxCache.recipient==account_address)) + txs = q.all() + + SessionBase.release_session(session) + return list(txs) + + + @staticmethod + def set_final(tx_hash, block_number, tx_index, session=None): + session = SessionBase.bind_session(session) + + q = session.query(TxCache) + q = q.join(Otx) + q = q.filter(Otx.tx_hash==strip_0x(tx_hash)) + q = q.filter(Otx.block==block_number) + o = q.first() + + if o == None: + raise NotLocalTxError(tx_hash, block_number) + + o.tx_index = tx_index + session.add(o) + session.flush() + + SessionBase.release_session(session) + + + @staticmethod + def load(tx_hash, session=None): + """Retrieves the outgoing transaction record by transaction hash. + + :param tx_hash: Transaction hash + :type tx_hash: str, 0x-hex + """ + session = SessionBase.bind_session(session) + + q = session.query(TxCache) + q = q.join(Otx) + q = q.filter(Otx.tx_hash==strip_0x(tx_hash)) + + SessionBase.release_session(session) + + return q.first() + + def __init__(self, tx_hash, sender, recipient, source_token_address, destination_token_address, from_value, to_value, block_number=None, tx_index=None, session=None): session = SessionBase.bind_session(session) q = session.query(Otx) diff --git a/chainqueue/error.py b/chainqueue/error.py index 8710d33..b0b642e 100644 --- a/chainqueue/error.py +++ b/chainqueue/error.py @@ -5,7 +5,9 @@ class ChainQueueException(Exception): class NotLocalTxError(ChainQueueException): """Exception raised when trying to access a tx not originated from a local task """ - pass + + def __init__(self, tx_hash, block=None): + super(NotLocalTxError, self).__init__(tx_hash, block) class TxStateChangeError(ChainQueueException): diff --git a/chainqueue/state.py b/chainqueue/state.py index d0b1fc7..0f2719d 100644 --- a/chainqueue/state.py +++ b/chainqueue/state.py @@ -11,6 +11,7 @@ from chainqueue.db.models.base import SessionBase from chainqueue.db.enum import ( StatusEnum, StatusBits, + is_nascent, ) from chainqueue.db.models.otx import OtxStateLog from chainqueue.error import ( @@ -95,6 +96,14 @@ def set_final(tx_hash, block=None, fail=False): session.close() raise(e) + q = session.query(TxCache) + q = q.join(Otx) + q = q.filter(Otx.tx_hash==strip_0x(tx_hash)) + o = q.first() + + if o != None: + + session.close() return tx_hash @@ -224,7 +233,7 @@ def set_ready(tx_hash): raise NotLocalTxError('queue does not contain tx hash {}'.format(tx_hash)) session.flush() - if o.status & StatusBits.GAS_ISSUES or o.status == StatusEnum.PENDING: + if o.status & StatusBits.GAS_ISSUES or is_nascent(o.status): o.readysend(session=session) else: o.retry(session=session) @@ -285,7 +294,7 @@ def get_state_log(tx_hash): q = session.query(OtxStateLog) q = q.join(Otx) - q = q.filter(Otx.tx_hash==tx_hash) + q = q.filter(Otx.tx_hash==strip_0x(tx_hash)) q = q.order_by(OtxStateLog.date.asc()) for l in q.all(): logs.append((l.date, l.status,)) @@ -332,4 +341,3 @@ def cancel_obsoletes_by_cache(tx_hash): session.close() return tx_hash - diff --git a/chainqueue/tx.py b/chainqueue/tx.py index b37184b..e2926f1 100644 --- a/chainqueue/tx.py +++ b/chainqueue/tx.py @@ -39,6 +39,7 @@ def create(nonce, holder_address, tx_hash, signed_tx, chain_spec, obsolete_prede ) session.flush() + # TODO: No magic, please, should be separate step if obsolete_predecessors: q = session.query(Otx) q = q.join(TxCache) @@ -60,7 +61,6 @@ def create(nonce, holder_address, tx_hash, signed_tx, chain_spec, obsolete_prede session.close() raise(e) - session.commit() SessionBase.release_session(session) logg.debug('queue created nonce {} from {} hash {}'.format(nonce, holder_address, tx_hash)) diff --git a/tests/base.py b/tests/base.py index cc0a557..4e0a5e2 100644 --- a/tests/base.py +++ b/tests/base.py @@ -6,13 +6,17 @@ import os #import pysqlite # external imports +from chainqueue.db.models.otx import Otx +from chainqueue.db.models.tx import TxCache from chainlib.chain import ChainSpec import alembic import alembic.config +from hexathon import add_0x # local imports from chainqueue.db import dsn_from_config from chainqueue.db.models.base import SessionBase +from chainqueue.tx import create script_dir = os.path.realpath(os.path.dirname(__file__)) @@ -57,3 +61,44 @@ class TestBase(unittest.TestCase): def tearDown(self): self.session.commit() self.session.close() + + +class TestOtxBase(TestBase): + + def setUp(self): + super(TestOtxBase, self).setUp() + self.tx_hash = add_0x(os.urandom(32).hex()) + self.tx = add_0x(os.urandom(128).hex()) + self.nonce = 42 + self.alice = add_0x(os.urandom(20).hex()) + + tx_hash = create(self.nonce, self.alice, self.tx_hash, self.tx, self.chain_spec, session=self.session) + self.assertEqual(tx_hash, self.tx_hash) + + +class TestTxBase(TestOtxBase): + + def setUp(self): + super(TestTxBase, self).setUp() + self.bob = add_0x(os.urandom(20).hex()) + self.foo_token = add_0x(os.urandom(20).hex()) + self.bar_token = add_0x(os.urandom(20).hex()) + self.from_value = 42 + self.to_value = 13 + + txc = TxCache( + self.tx_hash, + self.alice, + self.bob, + self.foo_token, + self.bar_token, + self.from_value, + self.to_value, + session=self.session, + ) + self.session.add(txc) + self.session.commit() + + otx = Otx.load(self.tx_hash) + self.assertEqual(txc.otx_id, otx.id) + diff --git a/tests/test_otx.py b/tests/test_otx.py index 9aa6301..09efc46 100644 --- a/tests/test_otx.py +++ b/tests/test_otx.py @@ -4,16 +4,10 @@ import logging import unittest # external imports -from hexathon import ( - strip_0x, - add_0x, - ) from chainlib.chain import ChainSpec # local imports from chainqueue.db.models.otx import Otx -from chainqueue.db.models.tx import TxCache -from chainqueue.tx import create from chainqueue.state import * from chainqueue.db.enum import ( is_alive, @@ -21,24 +15,13 @@ from chainqueue.db.enum import ( ) # test imports -from tests.base import TestBase +from tests.base import TestOtxBase logging.basicConfig(level=logging.DEBUG) logg = logging.getLogger() -class TestOtx(TestBase): - - def setUp(self): - super(TestOtx, self).setUp() - self.tx_hash = add_0x(os.urandom(32).hex()) - self.tx = add_0x(os.urandom(128).hex()) - self.nonce = 42 - self.alice = add_0x(os.urandom(20).hex()) - - tx_hash = create(self.nonce, self.alice, self.tx_hash, self.tx, self.chain_spec, session=self.session) - self.assertEqual(tx_hash, self.tx_hash) - +class TestOtx(TestOtxBase): def test_ideal_state_sequence(self): set_ready(self.tx_hash) @@ -131,6 +114,7 @@ class TestOtx(TestBase): otx = Otx.load(self.tx_hash, session=self.session) self.assertFalse(is_alive(otx.status)) self.assertTrue(is_error_status(otx.status)) + self.assertEqual(otx.status & StatusBits.NETWORK_ERROR, StatusBits.NETWORK_ERROR) def test_final_protected(self): @@ -154,10 +138,31 @@ class TestOtx(TestBase): set_cancel(self.tx_hash) self.session.refresh(otx) self.assertEqual(otx.status & StatusBits.OBSOLETE, 0) + + set_cancel(self.tx_hash, manual=True) + self.session.refresh(otx) + self.assertEqual(otx.status & StatusBits.OBSOLETE, 0) with self.assertRaises(TxStateChangeError): set_reserved(self.tx_hash) + with self.assertRaises(TxStateChangeError): + set_waitforgas(self.tx_hash) + + with self.assertRaises(TxStateChangeError): + set_manual(self.tx_hash) + + + def test_manual_persist(self): + set_manual(self.tx_hash) + set_ready(self.tx_hash) + set_reserved(self.tx_hash) + set_sent(self.tx_hash) + set_final(self.tx_hash, block=1042) + + otx = Otx.load(self.tx_hash, session=self.session) + self.assertEqual(otx.status & StatusBits.MANUAL, StatusBits.MANUAL) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_otx_status_log.py b/tests/test_otx_status_log.py new file mode 100644 index 0000000..113badd --- /dev/null +++ b/tests/test_otx_status_log.py @@ -0,0 +1,35 @@ +# standard imports +import unittest + +# local imports +from chainqueue.db.models.otx import Otx +from chainqueue.state import * + +# test imports +from tests.base import TestOtxBase + + +class TestOtxState(TestOtxBase): + + + def setUp(self): + super(TestOtxState, self).setUp() + Otx.tracing = True + logg.debug('state trace') + + + def test_state_log(self): + set_ready(self.tx_hash) + set_reserved(self.tx_hash) + set_sent(self.tx_hash) + set_final(self.tx_hash, block=1042) + + state_log = get_state_log(self.tx_hash) + self.assertEqual(state_log[0][1], StatusEnum.READYSEND) + self.assertEqual(state_log[1][1], StatusEnum.RESERVED) + self.assertEqual(state_log[2][1], StatusEnum.SENT) + self.assertEqual(state_log[3][1], StatusEnum.SUCCESS) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_tx_cache.py b/tests/test_tx_cache.py new file mode 100644 index 0000000..d955440 --- /dev/null +++ b/tests/test_tx_cache.py @@ -0,0 +1,35 @@ +# standard imports +import unittest + +# local imports +from chainqueue.db.models.tx import TxCache +from chainqueue.error import NotLocalTxError +from chainqueue.state import * + +# test imports +from tests.base import TestTxBase + +class TestTxCache(TestTxBase): + + def test_mine(self): + with self.assertRaises(NotLocalTxError): + TxCache.set_final(self.tx_hash, 1024, 13, session=self.session) + + set_ready(self.tx_hash) + set_reserved(self.tx_hash) + set_sent(self.tx_hash) + set_final(self.tx_hash, block=1024) + + with self.assertRaises(NotLocalTxError): + TxCache.set_final(self.tx_hash, 1023, 13, session=self.session) + + TxCache.set_final(self.tx_hash, 1024, 13, session=self.session) + + self.session.commit() + + txc = TxCache.load(self.tx_hash) + self.assertEqual(txc.tx_index, 13) + + +if __name__ == '__main__': + unittest.main()