diff --git a/chainsyncer/backend.py b/chainsyncer/backend.py index 5953ff2..a5da1f4 100644 --- a/chainsyncer/backend.py +++ b/chainsyncer/backend.py @@ -79,7 +79,7 @@ class SyncerBackend: """ self.connect() pair = self.db_object.cursor() - filter_state = self.db_object_filter.filter() + filter_state = self.db_object_filter.cursor() self.disconnect() return (pair, filter_state,) @@ -95,7 +95,7 @@ class SyncerBackend: """ self.connect() pair = self.db_object.set(block_height, tx_height) - filter_state = self.db_object_filter.filter() + filter_state = self.db_object_filter.cursor() self.disconnect() return (pair, filter_state,) diff --git a/chainsyncer/db/models/filter.py b/chainsyncer/db/models/filter.py index 656ed62..33964b7 100644 --- a/chainsyncer/db/models/filter.py +++ b/chainsyncer/db/models/filter.py @@ -54,11 +54,11 @@ class BlockchainSyncFilter(SessionBase): def start(self): - return self.flags_start + return int.from_bytes(self.flags_start, 'big') def cursor(self): - return self.flags_current + return int.from_bytes(self.flags, 'big') def clear(self): @@ -68,12 +68,19 @@ class BlockchainSyncFilter(SessionBase): def target(self): n = 0 for i in range(self.count): - n |= 2 << i + n |= (1 << self.count) - 1 return n def set(self, n): - if self.flags & n > 0: + if n > self.count: + raise IndexError('bit flag out of range') + + b = 1 << (n % 8) + i = int((n - 1) / 8 + 1) + if self.flags[i] & b > 0: SessionBase.release_session(session) raise AttributeError('Filter bit already set') - r.flags |= n + flags = bytearray(self.flags) + flags[i] |= b + self.flags = flags diff --git a/chainsyncer/db/models/sync.py b/chainsyncer/db/models/sync.py index 4f2f156..aafdd74 100644 --- a/chainsyncer/db/models/sync.py +++ b/chainsyncer/db/models/sync.py @@ -126,6 +126,7 @@ class BlockchainSync(SessionBase): """ self.block_cursor = block_height self.tx_cursor = tx_height + return (self.block_cursor, self.tx_cursor,) def cursor(self): diff --git a/tests/test_database.py b/tests/test_database.py index 22f919b..ac3baf2 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -1,5 +1,6 @@ # standard imports import unittest +import logging # external imports from chainlib.chain import ChainSpec @@ -12,6 +13,9 @@ from chainsyncer.backend import SyncerBackend # testutil imports from tests.base import TestBase +logg = logging.getLogger() + + class TestDatabase(TestBase): @@ -49,7 +53,31 @@ class TestDatabase(TestBase): session = SessionBase.create_session() o = session.query(BlockchainSyncFilter).get(filter_id) self.assertEqual(len(o.flags), 2) + + t = o.target() + self.assertEqual(t, (1 << 9) - 1) + + for i in range(9): + o.set(i) + + c = o.cursor() + self.assertEqual(c, t) + session.close() + + def test_backend_resume(self): + s = SyncerBackend.live(self.chain_spec, 42) + s.register_filter('foo') + s.register_filter('bar') + s.register_filter('baz') + + s.set(42, 13) + + s = SyncerBackend.first(self.chain_spec) + logg.debug('start {}'.format(s)) + self.assertEqual(s.get(), ((42,13), 0)) + + if __name__ == '__main__': unittest.main()