diff --git a/chainsyncer/driver.py b/chainsyncer/driver.py index 0dc9306..6f52c86 100644 --- a/chainsyncer/driver.py +++ b/chainsyncer/driver.py @@ -124,6 +124,7 @@ class SyncDriver: def process_single(self, conn, block, tx): + logg.debug('single') self.session.filter(conn, block, tx) diff --git a/chainsyncer/unittest/base.py b/chainsyncer/unittest/base.py index 2674d85..632228c 100644 --- a/chainsyncer/unittest/base.py +++ b/chainsyncer/unittest/base.py @@ -15,6 +15,10 @@ from chainsyncer.driver import SyncDriver logg = logging.getLogger().getChild(__name__) +class MockFilterError(Exception): + pass + + class MockConn: """Noop connection mocker. @@ -82,7 +86,7 @@ class MockStore(State): class MockFilter: - def __init__(self, name, brk=False, z=None): + def __init__(self, name, brk=None, brk_hard=None, z=None): self.name = name if z == None: h = hashlib.sha256() @@ -90,6 +94,7 @@ class MockFilter: z = h.digest() self.z = z self.brk = brk + self.brk_hard = brk_hard self.contents = [] @@ -102,8 +107,21 @@ class MockFilter: def filter(self, conn, block, tx): + r = False self.contents.append((block.number, tx.index, tx.hash,)) - return self.brk + if self.brk_hard != None: + r = True + if self.brk_hard > 0: + r = True + self.brk_hard -= 1 + if r: + raise MockFilterError() + if self.brk != None: + if self.brk > 0: + r = True + self.brk -= 1 + logg.debug('filter {} r {}'.format(self.common_name(), r)) + return r class MockDriver(SyncDriver): diff --git a/tests/test_session.py b/tests/test_session.py index cfb91dd..0c57834 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -21,6 +21,7 @@ from chainsyncer.unittest import ( MockTx, MockBlock, MockDriver, + MockFilterError, ) from chainsyncer.driver import SyncDriver @@ -81,5 +82,36 @@ class TestFilter(unittest.TestCase): self.assertEqual(len(fltr_one.contents), 3) + def test_driver_interrupt(self): + drv = MockDriver(self.store, target=1) + + tx_hash = os.urandom(32).hex() + tx = MockTx(0, tx_hash) + block = MockBlock(0, [tx_hash]) + drv.add_block(block) + + fltr_one = MockFilter('foo', brk_hard=1) + self.store.register(fltr_one) + fltr_two = MockFilter('bar') + self.store.register(fltr_two) + + with self.assertRaises(MockFilterError): + drv.run(self.conn) + + store = SyncFsStore(self.path) + drv = MockDriver(store, target=1) + drv.add_block(block) + + tx_hash_one = os.urandom(32).hex() + tx = MockTx(0, tx_hash_one) + tx_hash_two = os.urandom(32).hex() + tx = MockTx(1, tx_hash_two) + block = MockBlock(1, [tx_hash_one, tx_hash_two]) + drv.add_block(block) + + with self.assertRaises(SyncDone): + drv.run(self.conn) + + if __name__ == '__main__': unittest.main()