diff --git a/chainsyncer/backend/memory.py b/chainsyncer/backend/memory.py index 14f7fdf..2e75496 100644 --- a/chainsyncer/backend/memory.py +++ b/chainsyncer/backend/memory.py @@ -14,6 +14,8 @@ class MemBackend: self.flags = 0 self.target_block = target_block self.db_session = None + self.filter_names = [] + self.filter_values = [] def connect(self): @@ -28,6 +30,8 @@ class MemBackend: logg.debug('stateless backend received {} {}'.format(block_height, tx_height)) self.block_height = block_height self.tx_height = tx_height + for i in range(len(self.filter_values)): + self.filter_values[i] = False def get(self): @@ -39,11 +43,13 @@ class MemBackend: def register_filter(self, name): - pass + self.filter_names.append(name) + self.filter_values.append(False) def complete_filter(self, n): - pass + self.filter_values[n-1] = True + logg.debug('set filter {}'.format(self.filter_names[n-1])) def __str__(self): diff --git a/chainsyncer/driver.py b/chainsyncer/driver.py index 661c0f2..440e9ae 100644 --- a/chainsyncer/driver.py +++ b/chainsyncer/driver.py @@ -72,6 +72,11 @@ class Syncer: self.backend.register_filter(str(f)) + def process_single(self, conn, block, tx, block_height, tx_index): + self.backend.set(block_height, tx_index) + self.filter.apply(conn, block, tx) + + class BlockPollSyncer(Syncer): def __init__(self, backend, pre_callback=None, block_callback=None, post_callback=None): @@ -120,14 +125,16 @@ class HeadSyncer(BlockPollSyncer): while True: try: tx = block.tx(i) - rcpt = conn.do(receipt(tx.hash)) - tx.apply_receipt(rcpt) - self.backend.set(block.number, i) - self.filter.apply(conn, block, tx) except IndexError as e: logg.debug('index error syncer rcpt get {}'.format(e)) self.backend.set(block.number + 1, 0) break + + rcpt = conn.do(receipt(tx.hash)) + tx.apply_receipt(rcpt) + + self.process_single(conn, block, tx, block.number, i) + i += 1 diff --git a/sql_requirements.txt b/sql_requirements.txt new file mode 100644 index 0000000..602c2c5 --- /dev/null +++ b/sql_requirements.txt @@ -0,0 +1,2 @@ +psycopg2==2.8.6 +SQLAlchemy==1.3.20 diff --git a/tests/test_interrupt.py b/tests/test_interrupt.py new file mode 100644 index 0000000..4cd61c4 --- /dev/null +++ b/tests/test_interrupt.py @@ -0,0 +1,106 @@ +# standard imports +import logging +import unittest +import os + +# external imports +from chainlib.chain import ChainSpec +from hexathon import add_0x + +# local imports +from chainsyncer.backend.memory import MemBackend +from chainsyncer.driver import HeadSyncer +from chainsyncer.error import NoBlockForYou + +# test imports +from tests.base import TestBase + +logging.basicConfig(level=logging.DEBUG) +logg = logging.getLogger() + + +class TestSyncer(HeadSyncer): + + + def __init__(self, backend, tx_counts=[]): + self.tx_counts = tx_counts + super(TestSyncer, self).__init__(backend) + + + def get(self, conn): + if self.backend.block_height == self.backend.target_block: + raise NoBlockForYou() + if self.backend.block_height > len(self.tx_counts): + return [] + + block_txs = [] + for i in range(self.tx_counts[self.backend.block_height]): + block_txs.append(add_0x(os.urandom(32).hex())) + + return block_txs + + + def process(self, conn, block): + i = 0 + for tx in block: + self.process_single(conn, block, tx, self.backend.block_height, i) + i += 1 + + + +class NaughtyCountExceptionFilter: + + def __init__(self, name, croak_on): + self.c = 0 + self.croak = croak_on + self.name = name + + + def filter(self, conn, block, tx, db_session=None): + self.c += 1 + if self.c == self.croak: + raise RuntimeError('foo') + + + def __str__(self): + return '{} {}'.format(self.__class__.__name__, self.name) + + +class CountFilter: + + def __init__(self, name): + self.c = 0 + self.name = name + + + def filter(self, conn, block, tx, db_session=None): + self.c += 1 + + + def __str__(self): + return '{} {}'.format(self.__class__.__name__, self.name) + + +class TestInterrupt(unittest.TestCase): + + def setUp(self): + self.chain_spec = ChainSpec('foo', 'bar', 42, 'baz') + self.backend = MemBackend(self.chain_spec, None, target_block=2) + self.syncer = TestSyncer(self.backend, [4, 2, 3]) + + def test_filter_interrupt(self): + + fltrs = [ + CountFilter('foo'), + CountFilter('bar'), + NaughtyCountExceptionFilter('xyzzy', 2), + CountFilter('baz'), + ] + + for fltr in fltrs: + self.syncer.add_filter(fltr) + + self.syncer.loop(0.1, None) + +if __name__ == '__main__': + unittest.main()