diff --git a/chainsyncer/driver.py b/chainsyncer/driver.py index 440e9ae..e652a3b 100644 --- a/chainsyncer/driver.py +++ b/chainsyncer/driver.py @@ -72,8 +72,8 @@ class Syncer: self.backend.register_filter(str(f)) - def process_single(self, conn, block, tx, block_height, tx_index): - self.backend.set(block_height, tx_index) + def process_single(self, conn, block, tx): + self.backend.set(block.number, tx.index) self.filter.apply(conn, block, tx) @@ -133,7 +133,7 @@ class HeadSyncer(BlockPollSyncer): rcpt = conn.do(receipt(tx.hash)) tx.apply_receipt(rcpt) - self.process_single(conn, block, tx, block.number, i) + self.process_single(conn, block, tx) i += 1 diff --git a/chainsyncer/unittest/base.py b/chainsyncer/unittest/base.py new file mode 100644 index 0000000..9e320e5 --- /dev/null +++ b/chainsyncer/unittest/base.py @@ -0,0 +1,59 @@ +# standard imports +import os + +# external imports +from hexathon import add_0x + +# local imports +from chainsyncer.driver import HistorySyncer +from chainsyncer.error import NoBlockForYou + + +class MockTx: + + def __init__(self, tx_hash, index): + self.hash = tx_hash + self.index = index + + +class MockBlock: + + def __init__(self, number, txs): + self.number = number + self.txs = txs + + + def tx(self, i): + return MockTx(i, self.txs[i]) + + +class TestSyncer(HistorySyncer): + + + def __init__(self, backend, tx_counts=[]): + self.tx_counts = tx_counts + super(TestSyncer, self).__init__(backend) + + + def get(self, conn): + if self.backend.block_height == self.backend.target_block: + self.running = False + raise NoBlockForYou() + if self.backend.block_height > len(self.tx_counts): + return [] + + block_txs = [] + for i in range(self.tx_counts[self.backend.block_height]): + block_txs.append(add_0x(os.urandom(32).hex())) + + return MockBlock(self.backend.block_height, block_txs) + + + def process(self, conn, block): + i = 0 + for tx in block.txs: + self.process_single(conn, block, tx) + i += 1 + self.backend.set(self.backend.block_height + 1, 0) + + diff --git a/tests/test_interrupt.py b/tests/test_interrupt.py index 4cd61c4..9f4203d 100644 --- a/tests/test_interrupt.py +++ b/tests/test_interrupt.py @@ -9,45 +9,18 @@ from hexathon import add_0x # local imports from chainsyncer.backend.memory import MemBackend -from chainsyncer.driver import HeadSyncer -from chainsyncer.error import NoBlockForYou # test imports from tests.base import TestBase +from chainsyncer.unittest.base import ( + MockBlock, + TestSyncer, + ) logging.basicConfig(level=logging.DEBUG) logg = logging.getLogger() -class TestSyncer(HeadSyncer): - - - def __init__(self, backend, tx_counts=[]): - self.tx_counts = tx_counts - super(TestSyncer, self).__init__(backend) - - - def get(self, conn): - if self.backend.block_height == self.backend.target_block: - raise NoBlockForYou() - if self.backend.block_height > len(self.tx_counts): - return [] - - block_txs = [] - for i in range(self.tx_counts[self.backend.block_height]): - block_txs.append(add_0x(os.urandom(32).hex())) - - return block_txs - - - def process(self, conn, block): - i = 0 - for tx in block: - self.process_single(conn, block, tx, self.backend.block_height, i) - i += 1 - - - class NaughtyCountExceptionFilter: def __init__(self, name, croak_on): @@ -100,7 +73,16 @@ class TestInterrupt(unittest.TestCase): for fltr in fltrs: self.syncer.add_filter(fltr) + try: + self.syncer.loop(0.1, None) + except RuntimeError: + pass self.syncer.loop(0.1, None) + for fltr in fltrs: + logg.debug('{} {}'.format(str(fltr), fltr.c)) + self.assertEqual(fltr.c, 9) + + if __name__ == '__main__': unittest.main()