Add test for resume tx, add flag state to resume

This commit is contained in:
nolash 2021-02-22 18:40:41 +01:00
parent 5cf8b4f296
commit 37d0a36303
Signed by: lash
GPG Key ID: 21D2E7BB88C2A746
4 changed files with 111 additions and 42 deletions

View File

@ -79,7 +79,7 @@ class SyncerBackend:
""" """
self.connect() self.connect()
pair = self.db_object.cursor() pair = self.db_object.cursor()
filter_state = self.db_object_filter.cursor() (filter_state, count, digest) = self.db_object_filter.cursor()
self.disconnect() self.disconnect()
return (pair, filter_state,) return (pair, filter_state,)
@ -95,7 +95,7 @@ class SyncerBackend:
""" """
self.connect() self.connect()
pair = self.db_object.set(block_height, tx_height) pair = self.db_object.set(block_height, tx_height)
filter_state = self.db_object_filter.cursor() (filter_state, count, digest)= self.db_object_filter.cursor()
self.disconnect() self.disconnect()
return (pair, filter_state,) return (pair, filter_state,)
@ -108,7 +108,7 @@ class SyncerBackend:
""" """
self.connect() self.connect()
pair = self.db_object.start() pair = self.db_object.start()
filter_state = self.db_object_filter.start() (filter_state, count, digest) = self.db_object_filter.start()
self.disconnect() self.disconnect()
return (pair, filter_state,) return (pair, filter_state,)
@ -121,7 +121,7 @@ class SyncerBackend:
""" """
self.connect() self.connect()
target = self.db_object.target() target = self.db_object.target()
filter_state = self.db_object_filter.target() (filter_state, count, digest) = self.db_object_filter.target()
self.disconnect() self.disconnect()
return (target, filter_target,) return (target, filter_target,)
@ -144,7 +144,7 @@ class SyncerBackend:
@staticmethod @staticmethod
def initial(chain, block_height): def initial(chain_spec, target_block_height, start_block_height=0):
"""Creates a new syncer session and commit its initial state to backend. """Creates a new syncer session and commit its initial state to backend.
:param chain: Chain spec of chain that syncer is running for. :param chain: Chain spec of chain that syncer is running for.
@ -154,24 +154,31 @@ class SyncerBackend:
:returns: New syncer object :returns: New syncer object
:rtype: cic_eth.db.models.BlockchainSync :rtype: cic_eth.db.models.BlockchainSync
""" """
if start_block_height >= target_block_height:
raise ValueError('start block height must be lower than target block height')
object_id = None object_id = None
session = SessionBase.create_session() session = SessionBase.create_session()
o = BlockchainSync(chain, 0, 0, block_height) o = BlockchainSync(str(chain_spec), start_block_height, 0, target_block_height)
session.add(o) session.add(o)
session.commit() session.commit()
object_id = o.id object_id = o.id
of = BlockchainSyncFilter(o)
session.add(of)
session.commit()
session.close() session.close()
return SyncerBackend(chain, object_id) return SyncerBackend(chain_spec, object_id)
@staticmethod @staticmethod
def resume(chain, block_height): def resume(chain_spec, block_height):
"""Retrieves and returns all previously unfinished syncer sessions. """Retrieves and returns all previously unfinished syncer sessions.
:param chain: Chain spec of chain that syncer is running for. :param chain_spec: Chain spec of chain that syncer is running for.
:type chain: cic_registry.chain.ChainSpec :type chain_spec: cic_registry.chain.ChainSpec
:param block_height: Target block height :param block_height: Target block height
:type block_height: number :type block_height: number
:returns: Syncer objects of unfinished syncs :returns: Syncer objects of unfinished syncs
@ -185,16 +192,39 @@ class SyncerBackend:
for object_id in BlockchainSync.get_unsynced(session=session): for object_id in BlockchainSync.get_unsynced(session=session):
logg.debug('block syncer resume added previously unsynced sync entry id {}'.format(object_id)) logg.debug('block syncer resume added previously unsynced sync entry id {}'.format(object_id))
syncers.append(SyncerBackend(chain, object_id)) syncers.append(SyncerBackend(chain_spec, object_id))
(block_resume, tx_resume) = BlockchainSync.get_last_live_height(block_height, session=session) last_live_id = BlockchainSync.get_last_live(block_height, session=session)
if block_height != block_resume: logg.debug('last_live_id {}'.format(last_live_id))
o = BlockchainSync(chain, block_resume, tx_resume, block_height) if last_live_id != None:
session.add(o)
session.commit() q = session.query(BlockchainSync)
object_id = o.id o = q.get(last_live_id)
syncers.append(SyncerBackend(chain, object_id))
logg.debug('block syncer resume added new sync entry from previous run id {}, start{}:{} target {}'.format(object_id, block_resume, tx_resume, block_height)) (block_resume, tx_resume) = o.cursor()
session.flush()
if block_height != block_resume:
q = session.query(BlockchainSyncFilter)
q = q.filter(BlockchainSyncFilter.chain_sync_id==last_live_id)
of = q.first()
(flags, count, digest) = of.cursor()
session.flush()
o = BlockchainSync(str(chain_spec), block_resume, tx_resume, block_height)
session.add(o)
session.flush()
object_id = o.id
of = BlockchainSyncFilter(o, count, flags, digest)
session.add(of)
session.commit()
syncers.append(SyncerBackend(chain_spec, object_id))
logg.debug('block syncer resume added new sync entry from previous run id {}, start{}:{} target {}'.format(object_id, block_resume, tx_resume, block_height))
session.close() session.close()

View File

@ -31,6 +31,9 @@ class BlockchainSyncFilter(SessionBase):
if flags == None: if flags == None:
flags = bytearray(0) flags = bytearray(0)
else: # TODO: handle bytes too
bytecount = int((count - 1) / 8 + 1) + 1
flags = flags.to_bytes(bytecount, 'big')
self.flags_start = flags self.flags_start = flags
self.flags = flags self.flags = flags
@ -54,22 +57,22 @@ class BlockchainSyncFilter(SessionBase):
def start(self): def start(self):
return int.from_bytes(self.flags_start, 'big') return (int.from_bytes(self.flags_start, 'big'), self.count, self.digest)
def cursor(self): def cursor(self):
return int.from_bytes(self.flags, 'big') return (int.from_bytes(self.flags, 'big'), self.count, self.digest)
def clear(self):
self.flags = 0
def target(self): def target(self):
n = 0 n = 0
for i in range(self.count): for i in range(self.count):
n |= (1 << self.count) - 1 n |= (1 << self.count) - 1
return n return (n, self.count, self.digest)
def clear(self):
self.flags = 0
def set(self, n): def set(self, n):

View File

@ -61,7 +61,7 @@ class BlockchainSync(SessionBase):
@staticmethod @staticmethod
def get_last_live_height(current, session=None): def get_last_live(current, session=None):
"""Get the most recent open-ended ("live") syncer record. """Get the most recent open-ended ("live") syncer record.
:param current: Current block number :param current: Current block number
@ -71,21 +71,19 @@ class BlockchainSync(SessionBase):
:returns: Block and transaction number, respectively :returns: Block and transaction number, respectively
:rtype: tuple :rtype: tuple
""" """
local_session = False session = SessionBase.bind_session(session)
if session == None:
session = SessionBase.create_session() q = session.query(BlockchainSync.id)
local_session = True
q = session.query(BlockchainSync)
q = q.filter(BlockchainSync.block_target==None) q = q.filter(BlockchainSync.block_target==None)
q = q.order_by(BlockchainSync.date_created.desc()) q = q.order_by(BlockchainSync.date_created.desc())
o = q.first() object_id = q.first()
if local_session:
session.close()
if o == None: SessionBase.release_session(session)
return (0, 0)
return (o.block_cursor, o.tx_cursor) if object_id == None:
return None
return object_id[0]
@staticmethod @staticmethod

View File

@ -54,19 +54,21 @@ class TestDatabase(TestBase):
o = session.query(BlockchainSyncFilter).get(filter_id) o = session.query(BlockchainSyncFilter).get(filter_id)
self.assertEqual(len(o.flags), 2) self.assertEqual(len(o.flags), 2)
t = o.target() (t, c, d) = o.target()
self.assertEqual(t, (1 << 9) - 1) self.assertEqual(t, (1 << 9) - 1)
for i in range(9): for i in range(9):
o.set(i) o.set(i)
c = o.cursor() (f, c, d) = o.cursor()
self.assertEqual(c, t) self.assertEqual(f, t)
self.assertEqual(c, 9)
self.assertEqual(d, o.digest)
session.close() session.close()
def test_backend_resume(self): def test_backend_retrieve(self):
s = SyncerBackend.live(self.chain_spec, 42) s = SyncerBackend.live(self.chain_spec, 42)
s.register_filter('foo') s.register_filter('foo')
s.register_filter('bar') s.register_filter('bar')
@ -79,5 +81,41 @@ class TestDatabase(TestBase):
self.assertEqual(s.get(), ((42,13), 0)) self.assertEqual(s.get(), ((42,13), 0))
def test_backend_initial(self):
with self.assertRaises(ValueError):
s = SyncerBackend.initial(self.chain_spec, 42, 42)
with self.assertRaises(ValueError):
s = SyncerBackend.initial(self.chain_spec, 42, 43)
s = SyncerBackend.initial(self.chain_spec, 42, 13)
s.set(43, 13)
s = SyncerBackend.first(self.chain_spec)
self.assertEqual(s.get(), ((43,13), 0))
self.assertEqual(s.start(), ((13,0), 0))
def test_backend_resume(self):
s = SyncerBackend.resume(self.chain_spec, 666)
self.assertEqual(len(s), 0)
s = SyncerBackend.live(self.chain_spec, 42)
original_id = s.object_id
s = SyncerBackend.resume(self.chain_spec, 666)
self.assertEqual(len(s), 1)
resumed_id = s[0].object_id
self.assertEqual(resumed_id, original_id + 1)
def test_backend_resume_several(self):
s = SyncerBackend.live(self.chain_spec, 42)
s.set(43, 13)
s = SyncerBackend.resume(self.chain_spec, 666)
s[0].set(123, 2)
s = SyncerBackend.resume(self.chain_spec, 1024)
self.assertEqual(len(s), 2)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()