diff --git a/chainsyncer/backend.py b/chainsyncer/backend.py index a5da1f4..fdafda9 100644 --- a/chainsyncer/backend.py +++ b/chainsyncer/backend.py @@ -79,7 +79,7 @@ class SyncerBackend: """ self.connect() pair = self.db_object.cursor() - filter_state = self.db_object_filter.cursor() + (filter_state, count, digest) = self.db_object_filter.cursor() self.disconnect() return (pair, filter_state,) @@ -95,7 +95,7 @@ class SyncerBackend: """ self.connect() pair = self.db_object.set(block_height, tx_height) - filter_state = self.db_object_filter.cursor() + (filter_state, count, digest)= self.db_object_filter.cursor() self.disconnect() return (pair, filter_state,) @@ -108,7 +108,7 @@ class SyncerBackend: """ self.connect() pair = self.db_object.start() - filter_state = self.db_object_filter.start() + (filter_state, count, digest) = self.db_object_filter.start() self.disconnect() return (pair, filter_state,) @@ -121,7 +121,7 @@ class SyncerBackend: """ self.connect() target = self.db_object.target() - filter_state = self.db_object_filter.target() + (filter_state, count, digest) = self.db_object_filter.target() self.disconnect() return (target, filter_target,) @@ -144,7 +144,7 @@ class SyncerBackend: @staticmethod - def initial(chain, block_height): + def initial(chain_spec, target_block_height, start_block_height=0): """Creates a new syncer session and commit its initial state to backend. :param chain: Chain spec of chain that syncer is running for. @@ -154,24 +154,31 @@ class SyncerBackend: :returns: New syncer object :rtype: cic_eth.db.models.BlockchainSync """ + if start_block_height >= target_block_height: + raise ValueError('start block height must be lower than target block height') object_id = None session = SessionBase.create_session() - o = BlockchainSync(chain, 0, 0, block_height) + o = BlockchainSync(str(chain_spec), start_block_height, 0, target_block_height) session.add(o) session.commit() object_id = o.id + + of = BlockchainSyncFilter(o) + session.add(of) + session.commit() + session.close() - return SyncerBackend(chain, object_id) + return SyncerBackend(chain_spec, object_id) @staticmethod - def resume(chain, block_height): + def resume(chain_spec, block_height): """Retrieves and returns all previously unfinished syncer sessions. - :param chain: Chain spec of chain that syncer is running for. - :type chain: cic_registry.chain.ChainSpec + :param chain_spec: Chain spec of chain that syncer is running for. + :type chain_spec: cic_registry.chain.ChainSpec :param block_height: Target block height :type block_height: number :returns: Syncer objects of unfinished syncs @@ -185,16 +192,39 @@ class SyncerBackend: for object_id in BlockchainSync.get_unsynced(session=session): logg.debug('block syncer resume added previously unsynced sync entry id {}'.format(object_id)) - syncers.append(SyncerBackend(chain, object_id)) + syncers.append(SyncerBackend(chain_spec, object_id)) - (block_resume, tx_resume) = BlockchainSync.get_last_live_height(block_height, session=session) - if block_height != block_resume: - o = BlockchainSync(chain, block_resume, tx_resume, block_height) - session.add(o) - session.commit() - object_id = o.id - syncers.append(SyncerBackend(chain, object_id)) - logg.debug('block syncer resume added new sync entry from previous run id {}, start{}:{} target {}'.format(object_id, block_resume, tx_resume, block_height)) + last_live_id = BlockchainSync.get_last_live(block_height, session=session) + logg.debug('last_live_id {}'.format(last_live_id)) + if last_live_id != None: + + q = session.query(BlockchainSync) + o = q.get(last_live_id) + + (block_resume, tx_resume) = o.cursor() + session.flush() + + if block_height != block_resume: + + q = session.query(BlockchainSyncFilter) + q = q.filter(BlockchainSyncFilter.chain_sync_id==last_live_id) + of = q.first() + (flags, count, digest) = of.cursor() + + session.flush() + + o = BlockchainSync(str(chain_spec), block_resume, tx_resume, block_height) + session.add(o) + session.flush() + object_id = o.id + + of = BlockchainSyncFilter(o, count, flags, digest) + session.add(of) + session.commit() + + syncers.append(SyncerBackend(chain_spec, object_id)) + + logg.debug('block syncer resume added new sync entry from previous run id {}, start{}:{} target {}'.format(object_id, block_resume, tx_resume, block_height)) session.close() diff --git a/chainsyncer/db/models/filter.py b/chainsyncer/db/models/filter.py index 33964b7..93cf4c2 100644 --- a/chainsyncer/db/models/filter.py +++ b/chainsyncer/db/models/filter.py @@ -31,6 +31,9 @@ class BlockchainSyncFilter(SessionBase): if flags == None: flags = bytearray(0) + else: # TODO: handle bytes too + bytecount = int((count - 1) / 8 + 1) + 1 + flags = flags.to_bytes(bytecount, 'big') self.flags_start = flags self.flags = flags @@ -54,22 +57,22 @@ class BlockchainSyncFilter(SessionBase): def start(self): - return int.from_bytes(self.flags_start, 'big') + return (int.from_bytes(self.flags_start, 'big'), self.count, self.digest) def cursor(self): - return int.from_bytes(self.flags, 'big') - - - def clear(self): - self.flags = 0 + return (int.from_bytes(self.flags, 'big'), self.count, self.digest) def target(self): n = 0 for i in range(self.count): n |= (1 << self.count) - 1 - return n + return (n, self.count, self.digest) + + + def clear(self): + self.flags = 0 def set(self, n): diff --git a/chainsyncer/db/models/sync.py b/chainsyncer/db/models/sync.py index aafdd74..01c28c1 100644 --- a/chainsyncer/db/models/sync.py +++ b/chainsyncer/db/models/sync.py @@ -61,7 +61,7 @@ class BlockchainSync(SessionBase): @staticmethod - def get_last_live_height(current, session=None): + def get_last_live(current, session=None): """Get the most recent open-ended ("live") syncer record. :param current: Current block number @@ -71,21 +71,19 @@ class BlockchainSync(SessionBase): :returns: Block and transaction number, respectively :rtype: tuple """ - local_session = False - if session == None: - session = SessionBase.create_session() - local_session = True - q = session.query(BlockchainSync) + session = SessionBase.bind_session(session) + + q = session.query(BlockchainSync.id) q = q.filter(BlockchainSync.block_target==None) q = q.order_by(BlockchainSync.date_created.desc()) - o = q.first() - if local_session: - session.close() + object_id = q.first() - if o == None: - return (0, 0) + SessionBase.release_session(session) - return (o.block_cursor, o.tx_cursor) + if object_id == None: + return None + + return object_id[0] @staticmethod diff --git a/tests/test_database.py b/tests/test_database.py index ac3baf2..5066f69 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -54,19 +54,21 @@ class TestDatabase(TestBase): o = session.query(BlockchainSyncFilter).get(filter_id) self.assertEqual(len(o.flags), 2) - t = o.target() + (t, c, d) = o.target() self.assertEqual(t, (1 << 9) - 1) for i in range(9): o.set(i) - c = o.cursor() - self.assertEqual(c, t) + (f, c, d) = o.cursor() + self.assertEqual(f, t) + self.assertEqual(c, 9) + self.assertEqual(d, o.digest) session.close() - def test_backend_resume(self): + def test_backend_retrieve(self): s = SyncerBackend.live(self.chain_spec, 42) s.register_filter('foo') s.register_filter('bar') @@ -77,7 +79,43 @@ class TestDatabase(TestBase): s = SyncerBackend.first(self.chain_spec) logg.debug('start {}'.format(s)) self.assertEqual(s.get(), ((42,13), 0)) + + + def test_backend_initial(self): + with self.assertRaises(ValueError): + s = SyncerBackend.initial(self.chain_spec, 42, 42) + with self.assertRaises(ValueError): + s = SyncerBackend.initial(self.chain_spec, 42, 43) + + s = SyncerBackend.initial(self.chain_spec, 42, 13) + + s.set(43, 13) + + s = SyncerBackend.first(self.chain_spec) + self.assertEqual(s.get(), ((43,13), 0)) + self.assertEqual(s.start(), ((13,0), 0)) + + + def test_backend_resume(self): + s = SyncerBackend.resume(self.chain_spec, 666) + self.assertEqual(len(s), 0) + + s = SyncerBackend.live(self.chain_spec, 42) + original_id = s.object_id + s = SyncerBackend.resume(self.chain_spec, 666) + self.assertEqual(len(s), 1) + resumed_id = s[0].object_id + self.assertEqual(resumed_id, original_id + 1) + + + def test_backend_resume_several(self): + s = SyncerBackend.live(self.chain_spec, 42) + s.set(43, 13) + s = SyncerBackend.resume(self.chain_spec, 666) + s[0].set(123, 2) + s = SyncerBackend.resume(self.chain_spec, 1024) + self.assertEqual(len(s), 2) if __name__ == '__main__': unittest.main()