diff --git a/chainsyncer/backend/base.py b/chainsyncer/backend/base.py new file mode 100644 index 0000000..4b6cd63 --- /dev/null +++ b/chainsyncer/backend/base.py @@ -0,0 +1,22 @@ +# standard imports +import logging + +logg = logging.getLogger().getChild(__name__) + + +class Backend: + + def __init__(self, flags_reversed=False): + self.filter_count = 0 + self.flags_reversed = flags_reversed + + + def check_filter(self, n, flags): + if self.flags_reversed: + try: + v = 1 << flags.bit_length() - 1 + return (v >> n) & flags > 0 + except ValueError: + pass + return False + return flags & (1 << n) > 0 diff --git a/chainsyncer/backend/file.py b/chainsyncer/backend/file.py index 61bcb68..0f9f840 100644 --- a/chainsyncer/backend/file.py +++ b/chainsyncer/backend/file.py @@ -4,6 +4,9 @@ import uuid import shutil import logging +# local imports +from .base import Backend + logg = logging.getLogger().getChild(__name__) base_dir = '/var/lib' @@ -19,9 +22,10 @@ def data_dir_for(chain_spec, object_id, base_dir=base_dir): return os.path.join(chain_dir, object_id) -class SyncerFileBackend: +class FileBackend(Backend): def __init__(self, chain_spec, object_id=None, base_dir=base_dir): + super(FileBackend, self).__init__(flags_reversed=True) self.object_data_dir = data_dir_for(chain_spec, object_id, base_dir=base_dir) self.block_height_offset = 0 @@ -38,7 +42,6 @@ class SyncerFileBackend: self.db_object_filter = None self.chain_spec = chain_spec - self.filter_count = 0 self.filter = b'\x00' self.filter_names = [] @@ -47,7 +50,6 @@ class SyncerFileBackend: self.disconnect() - @staticmethod def create_object(chain_spec, object_id=None, base_dir=base_dir): if object_id == None: @@ -157,7 +159,11 @@ class SyncerFileBackend: def get(self): logg.debug('filter {}'.format(self.filter.hex())) - return ((self.block_height_cursor, self.tx_index_cursor), int.from_bytes(self.filter, 'little')) + return ((self.block_height_cursor, self.tx_index_cursor), self.get_flags()) + + + def get_flags(self): + return int.from_bytes(self.filter, 'little') def set(self, block_height, tx_index): @@ -172,7 +178,7 @@ class SyncerFileBackend: # c += f.write(self.filter[c:]) # f.close() - return ((self.block_height_cursor, self.tx_index_cursor), int.from_bytes(self.filter, 'little')) + return ((self.block_height_cursor, self.tx_index_cursor), self.get_flags()) def __set(self, block_height, tx_index, category): @@ -195,9 +201,9 @@ class SyncerFileBackend: if start_block_height >= target_block_height: raise ValueError('start block height must be lower than target block height') - uu = SyncerFileBackend.create_object(chain_spec, base_dir=base_dir) + uu = FileBackend.create_object(chain_spec, base_dir=base_dir) - o = SyncerFileBackend(chain_spec, uu, base_dir=base_dir) + o = FileBackend(chain_spec, uu, base_dir=base_dir) o.__set(target_block_height, 0, 'target') o.__set(start_block_height, 0, 'offset') @@ -227,7 +233,7 @@ class SyncerFileBackend: logg.debug('found syncer entry {} in {}'.format(object_id, d)) - o = SyncerFileBackend(chain_spec, object_id, base_dir=base_dir) + o = FileBackend(chain_spec, object_id, base_dir=base_dir) entries[o.block_height_offset] = o @@ -240,13 +246,13 @@ class SyncerFileBackend: @staticmethod def resume(chain_spec, base_dir=base_dir): - return SyncerFileBackend.__sorted_entries(chain_spec, base_dir=base_dir) + return FileBackend.__sorted_entries(chain_spec, base_dir=base_dir) @staticmethod def first(chain_spec, base_dir=base_dir): - entries = SyncerFileBackend.__sorted_entries(chain_spec, base_dir=base_dir) + entries = FileBackend.__sorted_entries(chain_spec, base_dir=base_dir) return entries[len(entries)-1] diff --git a/chainsyncer/backend/memory.py b/chainsyncer/backend/memory.py index 9b4d2a0..8bcaf20 100644 --- a/chainsyncer/backend/memory.py +++ b/chainsyncer/backend/memory.py @@ -1,12 +1,16 @@ # standard imports import logging +# local imports +from .base import Backend + logg = logging.getLogger().getChild(__name__) -class MemBackend: +class MemBackend(Backend): def __init__(self, chain_spec, object_id, target_block=None): + super(MemBackend, self).__init__() self.object_id = object_id self.chain_spec = chain_spec self.block_height = 0 @@ -41,6 +45,7 @@ class MemBackend: def register_filter(self, name): self.filter_names.append(name) + self.filter_count += 1 def complete_filter(self, n): @@ -53,6 +58,10 @@ class MemBackend: logg.debug('reset filters') self.flags = 0 + + def get_flags(self): + return flags + def __str__(self): return "syncer membackend chain {} cursor".format(self.get()) diff --git a/chainsyncer/backend/sql.py b/chainsyncer/backend/sql.py index 83487a3..3e224ff 100644 --- a/chainsyncer/backend/sql.py +++ b/chainsyncer/backend/sql.py @@ -9,11 +9,12 @@ from chainlib.chain import ChainSpec from chainsyncer.db.models.sync import BlockchainSync from chainsyncer.db.models.filter import BlockchainSyncFilter from chainsyncer.db.models.base import SessionBase +from .base import Backend logg = logging.getLogger().getChild(__name__) -class SyncerBackend: +class SyncerBackend(Backend): """Interface to block and transaction sync state. :param chain_spec: Chain spec for the chain that syncer is running for. @@ -22,6 +23,7 @@ class SyncerBackend: :type object_id: number """ def __init__(self, chain_spec, object_id): + super(SyncerBackend, self).__init__() self.db_session = None self.db_object = None self.db_object_filter = None diff --git a/chainsyncer/filter.py b/chainsyncer/filter.py index 7e83e9e..f2597dc 100644 --- a/chainsyncer/filter.py +++ b/chainsyncer/filter.py @@ -36,7 +36,8 @@ class SyncFilter: i = 0 (pair, flags) = self.backend.get() for f in self.filters: - if flags & (1 << i) == 0: + if not self.backend.check_filter(i, flags): + #if flags & (1 << i) == 0: logg.debug('applying filter {} {}'.format(str(f), flags)) f.filter(conn, block, tx, session) self.backend.complete_filter(i) diff --git a/tests/test_interrupt.py b/tests/test_interrupt.py index 303ad66..a6e7b4c 100644 --- a/tests/test_interrupt.py +++ b/tests/test_interrupt.py @@ -11,7 +11,7 @@ from chainlib.chain import ChainSpec from chainsyncer.backend.memory import MemBackend from chainsyncer.backend.sql import SyncerBackend from chainsyncer.backend.file import ( - SyncerFileBackend, + FileBackend, data_dir_for, ) @@ -111,22 +111,21 @@ class TestInterrupt(TestBase): self.assertEqual(fltr.c, z) - @unittest.skip('foo') def test_filter_interrupt_memory(self): for vector in self.vectors: self.backend = MemBackend(self.chain_spec, None, target_block=len(vector)) self.assert_filter_interrupt(vector) - def test_filter_interrpt_file(self): - for vector in self.vectors: + def test_filter_interrupt_file(self): + #for vector in self.vectors: + vector = self.vectors.pop() d = tempfile.mkdtemp() #os.makedirs(data_dir_for(self.chain_spec, 'foo', d)) - self.backend = SyncerFileBackend.initial(self.chain_spec, len(vector), base_dir=d) #'foo', base_dir=d) + self.backend = FileBackend.initial(self.chain_spec, len(vector), base_dir=d) #'foo', base_dir=d) self.assert_filter_interrupt(vector) - @unittest.skip('foo') def test_filter_interrupt_sql(self): for vector in self.vectors: self.backend = SyncerBackend.initial(self.chain_spec, len(vector))