diff --git a/chainsyncer/error.py b/chainsyncer/error.py
index daf42c4..fab66ce 100644
--- a/chainsyncer/error.py
+++ b/chainsyncer/error.py
@@ -3,6 +3,7 @@ class SyncDone(Exception):
     """
     pass
 
+
 class NoBlockForYou(Exception):
     """Exception raised when attempt to retrieve a block from network that does not (yet) exist.
     """
@@ -27,6 +28,20 @@ class LockError(Exception):
     pass
 
 
+class FilterDone(Exception):
+    """Exception raised when all registered filters have been executed
+    """
+
+
+class InterruptError(FilterDone):
+    """Exception for interrupting or attempting to use an interrupted sync
+    """
+
+
+class IncompleteFilterError(Exception):
+    """Exception raised if filter reset is executed prematurely
+    """
+
 #class AbortTx(Exception):
 #    """
 #    """
diff --git a/chainsyncer/session.py b/chainsyncer/session.py
index aa2c08e..680984e 100644
--- a/chainsyncer/session.py
+++ b/chainsyncer/session.py
@@ -7,24 +7,28 @@ class SyncSession:
     def __init__(self, session_store):
         self.session_store = session_store
         self.filters = []
-        self.started = False
+        self.start = self.session_store.start
+        self.get = self.session_store.get
+        self.started = self.session_store.started
 
 
-    def add_filter(self, fltr):
+    def register(self, fltr):
         if self.started:
             raise RuntimeError('filters cannot be changed after syncer start')
         self.session_store.register(fltr)
         self.filters.append(fltr)
 
 
-    def start(self):
-        self.started = True
-
-
     def filter(self, conn, block, tx):
         self.sync_state.connect()
         for fltr in filters:
-            self.sync_start.lock()
-            self.sync_start.unlock()
+            try:
+                self.sync_start.advance()
+            except FilterDone:
+                break
+            interrupt = fltr(conn, block, tx)
+            try:
+                self.sync_start.release(interrupt=interrupt)
+            except FilterDone:
+                break
         self.sync_start.disconnect()
-
diff --git a/chainsyncer/state/base.py b/chainsyncer/state/base.py
index d0b370c..a8616a0 100644
--- a/chainsyncer/state/base.py
+++ b/chainsyncer/state/base.py
@@ -1,5 +1,8 @@
 # standard imports
 import hashlib
+import logging
+
+logg = logging.getLogger(__name__)
 
 
 class SyncState:
@@ -11,8 +14,20 @@ class SyncState:
         self.__syncs = {}
         self.synced = False
         self.connected = False
-        self.state_store.add('INTERRUPT')
+        self.state_store.add('DONE')
         self.state_store.add('LOCK')
+        self.state_store.add('INTERRUPT')
+        self.state_store.add('RESET')
+        self.state = self.state_store.state
+        self.put = self.state_store.put
+        self.set = self.state_store.set
+        self.next = self.state_store.next
+        self.move = self.state_store.move
+        self.unset = self.state_store.unset
+        self.from_name = self.state_store.from_name
+        self.state_store.sync()
+        self.all = self.state_store.all
+        self.started = False
 
 
     def __verify_sum(self, v):
@@ -30,6 +45,9 @@ class SyncState:
             self.digest += z
         s = fltr.common_name()
         self.state_store.add(s)
+        n = self.state_store.from_name(s)
+        logg.debug('add {} {} {}'.format(s, n, self))
+
 
 
     def sum(self):
@@ -53,9 +71,10 @@ class SyncState:
         self.connected = False
 
 
-    def lock(self):
-        pass
+    def start(self):
+        self.state_store.start()
+        self.started = True
 
 
-    def unlock(self):
-        pass
+    def get(self, k):
+        raise NotImplementedError()
diff --git a/chainsyncer/store/fs.py b/chainsyncer/store/fs.py
index 3083173..191ce48 100644
--- a/chainsyncer/store/fs.py
+++ b/chainsyncer/store/fs.py
@@ -6,32 +6,92 @@ import logging
 # external imports
 from shep.store.file import SimpleFileStoreFactory
 from shep.persist import PersistedState
+from shep.error import StateInvalid
 
 # local imports
 from chainsyncer.state import SyncState
-
+from chainsyncer.error import (
+        LockError,
+        FilterDone,
+        InterruptError,
+        IncompleteFilterError,
+        )
 
 logg = logging.getLogger(__name__)
 
 
+# NOT thread safe
 class SyncFsItem:
 
-    def __init__(self, offset, target, sync_state, filter_state, started=False):
+    def __init__(self, offset, target, sync_state, filter_state, started=False, ignore_invalid=False):
         self.offset = offset
         self.target = target
         self.sync_state = sync_state
         self.filter_state = filter_state
 
-        s = str(offset)
+        self.state_key = str(offset)
         match_state = self.sync_state.NEW
         if started:
             match_state = self.sync_state.SYNC
-        v = self.sync_state.get(s)
+        v = self.sync_state.get(self.state_key)
         self.cursor = int.from_bytes(v, 'big')
 
+        if self.filter_state.state(self.state_key) & self.filter_state.from_name('LOCK') and not ignore_invalid:
+            raise LockError(self.state_key)
 
-    def next(self):
-        pass
+        self.count = len(self.filter_state.all(pure=True)) - 3
+        self.skip_filter = False
+        if self.count == 0:
+            self.skip_filter = True
+        else:
+            self.filter_state.move(self.state_key, self.filter_state.from_name('RESET'))
 
+
+    def __check_done(self):
+        if self.filter_state.state(self.state_key) & self.filter_state.from_name('INTERRUPT') > 0:
+            raise InterruptError(self.state_key)
+        if self.filter_state.state(self.state_key) & self.filter_state.from_name('DONE') > 0:
+            raise FilterDone(self.state_key)
+
+
+    def reset(self):
+        if self.filter_state.state(self.state_key) & self.filter_state.from_name('LOCK') > 0:
+            raise LockError('reset attempt on {} when state locked'.format(self.state_key))
+        if self.filter_state.state(self.state_key) & self.filter_state.from_name('DONE') == 0:
+            raise IncompleteFilterError('reset attempt on {} when incomplete'.format(self.state_key))
+        self.filter_state.move(self.state_key, self.filter_state.from_name('RESET'))
+
+
+    def advance(self):
+        if self.skip_filter:
+            raise FilterDone()
+        self.__check_done()
+
+        if self.filter_state.state(self.state_key) & self.filter_state.from_name('LOCK') > 0:
+            raise LockError('advance attempt on {} when state locked'.format(self.state_key))
+        done = False
+        try:
+            self.filter_state.next(self.state_key)
+        except StateInvalid:
+            done = True
+        if done:
+            self.filter_state.set(self.state_key, self.filter_state.from_name('DONE'))
+            raise FilterDone()
+        self.filter_state.set(self.state_key, self.filter_state.from_name('LOCK'))
+
+
+    def release(self, interrupt=False):
+        if self.skip_filter:
+            raise FilterDone()
+        if interrupt:
+            self.filter_state.set(self.state_key, self.filter_state.from_name('INTERRUPT'))
+            self.filter_state.set(self.state_key, self.filter_state.from_name('DONE'))
+            return
+
+        state = self.filter_state.state(self.state_key)
+        if state & self.filter_state.from_name('LOCK') == 0:
+            raise LockError('release attempt on {} when state unlocked'.format(self.state_key))
+        self.filter_state.unset(self.state_key, self.filter_state.from_name('LOCK'))
+
 
     def __str__(self):
         return 'syncitem offset {} target {} cursor {}'.format(self.offset, self.target, self.cursor)
@@ -46,6 +106,7 @@ class SyncFsStore:
         self.first = False
         self.target = None
         self.items = {}
+        self.started = False
 
         default_path = os.path.join(base_path, 'default')
 
@@ -76,10 +137,10 @@ class SyncFsStore:
 
         base_filter_path = os.path.join(self.session_path, 'filter')
         factory = SimpleFileStoreFactory(base_filter_path, binary=True)
-        filter_state_backend = PersistedState(factory, 0)
+        filter_state_backend = PersistedState(factory.add, 0, check_alias=False)
         self.filter_state = SyncState(filter_state_backend)
         self.register = self.filter_state.register
-    
+
 
     def __create_path(self, base_path, default_path, session_id=None):
         logg.debug('fs store path {} does not exist, creating'.format(self.session_path))
@@ -144,12 +205,22 @@ class SyncFsStore:
         if self.first:
             block_number = offset
             block_number_bytes = block_number.to_bytes(4, 'big')
-            self.state.put(str(block_number), block_number_bytes)
+            block_number_str = str(block_number)
+            self.state.put(block_number_str, block_number_bytes)
+            self.filter_state.put(block_number_str)
+            o = SyncFsItem(block_number, target, self.state, self.filter_state)
+            self.items[block_number] = o
         elif offset > 0:
             logg.warning('block number argument {} for start ignored for already initiated sync {}'.format(offset, self.session_id))
+        self.started = True
+
 
     def stop(self):
         if self.target == 0:
             block_number = self.height + 1
             block_number_bytes = block_number.to_bytes(4, 'big')
             self.state.put(str(block_number), block_number_bytes)
+
+
+    def get(self, k):
+        return self.items[k]
diff --git a/chainsyncer/unittest/__init__.py b/chainsyncer/unittest/__init__.py
new file mode 100644
index 0000000..9b5ed21
--- /dev/null
+++ b/chainsyncer/unittest/__init__.py
@@ -0,0 +1 @@
+from .base import *
diff --git a/chainsyncer/unittest/base.py b/chainsyncer/unittest/base.py
index d3ed7b3..2ed0f11 100644
--- a/chainsyncer/unittest/base.py
+++ b/chainsyncer/unittest/base.py
@@ -1,12 +1,14 @@
 # standard imports
 import os
 import logging
+import hashlib
 
 # external imports
 from hexathon import add_0x
+from shep.state import State
 
 # local imports
-from chainsyncer.driver.history import HistorySyncer
+#from chainsyncer.driver.history import HistorySyncer
 from chainsyncer.error import NoBlockForYou
 
 logg = logging.getLogger().getChild(__name__)
@@ -67,42 +69,77 @@ class MockBlock:
         return MockTx(i, self.txs[i])
 
 
-class TestSyncer(HistorySyncer):
-    """Unittest extension of history syncer driver.
+class MockStore(State):
 
-    :param backend: Syncer backend
-    :type backend: chainsyncer.backend.base.Backend implementation
-    :param chain_interface: Chain interface
-    :type chain_interface: chainlib.interface.ChainInterface implementation
-    :param tx_counts: List of integer values defining how many mock transactions to generate per block. Mock blocks will be generated for each element in list.
-    :type tx_counts: list
-    """
+    def __init__(self, bits=0):
+        super(MockStore, self).__init__(bits, check_alias=False)
 
-    def __init__(self, backend, chain_interface, tx_counts=[]):
-        self.tx_counts = tx_counts
-        super(TestSyncer, self).__init__(backend, chain_interface)
+
+    def start(self):
+        pass
 
 
-    def get(self, conn):
-        """Implements the block getter of chainsyncer.driver.base.Syncer.
+class MockFilter:
 
-        :param conn: RPC connection
-        :type conn: chainlib.connection.RPCConnection
-        :raises NoBlockForYou: End of mocked block array reached
-        :rtype: chainsyncer.unittest.base.MockBlock
-        :returns: Mock block.
-        """
-        (pair, fltr) = self.backend.get()
-        (target_block, fltr) = self.backend.target()
-        block_height = pair[0]
+    def __init__(self, name, brk=False, z=None):
+        self.name = name
+        if z == None:
+            h = hashlib.sha256()
+            h.update(self.name.encode('utf-8'))
+            z = h.digest()
+        self.z = z
+        self.brk = brk
 
-        if block_height == target_block:
-            self.running = False
-            raise NoBlockForYou()
 
-        block_txs = []
-        if block_height < len(self.tx_counts):
-            for i in range(self.tx_counts[block_height]):
-                block_txs.append(add_0x(os.urandom(32).hex()))
-
-        return MockBlock(block_height, block_txs)
+    def sum(self):
+        return self.z
+
+
+    def common_name(self):
+        return self.name
+
+
+    def filter(self, conn, block, tx):
+        return self.brk
+
+
+
+#class TestSyncer(HistorySyncer):
+#    """Unittest extension of history syncer driver.
+#
+#    :param backend: Syncer backend
+#    :type backend: chainsyncer.backend.base.Backend implementation
+#    :param chain_interface: Chain interface
+#    :type chain_interface: chainlib.interface.ChainInterface implementation
+#    :param tx_counts: List of integer values defining how many mock transactions to generate per block. Mock blocks will be generated for each element in list.
+#    :type tx_counts: list
+#    """
+#
+#    def __init__(self, backend, chain_interface, tx_counts=[]):
+#        self.tx_counts = tx_counts
+#        super(TestSyncer, self).__init__(backend, chain_interface)
+#
+#
+#    def get(self, conn):
+#        """Implements the block getter of chainsyncer.driver.base.Syncer.
+#
+#        :param conn: RPC connection
+#        :type conn: chainlib.connection.RPCConnection
+#        :raises NoBlockForYou: End of mocked block array reached
+#        :rtype: chainsyncer.unittest.base.MockBlock
+#        :returns: Mock block.
+#        """
+#        (pair, fltr) = self.backend.get()
+#        (target_block, fltr) = self.backend.target()
+#        block_height = pair[0]
+#
+#        if block_height == target_block:
+#            self.running = False
+#            raise NoBlockForYou()
+#
+#        block_txs = []
+#        if block_height < len(self.tx_counts):
+#            for i in range(self.tx_counts[block_height]):
+#                block_txs.append(add_0x(os.urandom(32).hex()))
+#
+#        return MockBlock(block_height, block_txs)
diff --git a/tests/test_basic.py b/tests/test_basic.py
index ac65c25..990d501 100644
--- a/tests/test_basic.py
+++ b/tests/test_basic.py
@@ -1,50 +1,21 @@
 # standard imports
 import unittest
-import hashlib
 import tempfile
 import shutil
 import logging
 
-# external imports
-from shep.state import State
-
 # local imports
 from chainsyncer.session import SyncSession
 from chainsyncer.state import SyncState
 from chainsyncer.store.fs import SyncFsStore
+from chainsyncer.unittest import (
+        MockStore,
+        MockFilter,
+        )
 
 logging.basicConfig(level=logging.DEBUG)
 logg = logging.getLogger()
 
 
-class MockStore(State):
-
-    def __init__(self, bits=0):
-        super(MockStore, self).__init__(bits, check_alias=False)
-
-
-class MockFilter:
-
-    def __init__(self, name, brk=False, z=None):
-        self.name = name
-        if z == None:
-            h = hashlib.sha256()
-            h.update(self.name.encode('utf-8'))
-            z = h.digest()
-        self.z = z
-        self.brk = brk
-
-
-    def sum(self):
-        return self.z
-
-
-    def common_name(self):
-        return self.name
-
-
-    def filter(self, conn, block, tx):
-        return self.brk
-
 
 class TestSync(unittest.TestCase):
 
@@ -64,7 +35,7 @@ class TestSync(unittest.TestCase):
 
 
     def test_sum(self):
-        store = MockStore(4)
+        store = MockStore(6)
         state = SyncState(store)
 
         b = b'\x2a' * 32
diff --git a/tests/test_fs.py b/tests/test_fs.py
index f41d8d7..125c073 100644
--- a/tests/test_fs.py
+++ b/tests/test_fs.py
@@ -8,6 +8,13 @@ import os
 
 # local imports
 from chainsyncer.store.fs import SyncFsStore
+from chainsyncer.session import SyncSession
+from chainsyncer.error import (
+        LockError,
+        FilterDone,
+        IncompleteFilterError,
+        )
+from chainsyncer.unittest import MockFilter
 
 logging.basicConfig(level=logging.DEBUG)
 logg = logging.getLogger()
@@ -69,8 +76,97 @@ class TestFs(unittest.TestCase):
         store = SyncFsStore(self.path)
         store.start(13)
         self.assertTrue(store.first)
+        # todo not done
 
 
+    def test_sync_process_nofilter(self):
+        store = SyncFsStore(self.path)
+        session = SyncSession(store)
+        session.start()
+        o = session.get(0)
+        with self.assertRaises(FilterDone):
+            o.advance()
+
+
+    def test_sync_process_onefilter(self):
+        store = SyncFsStore(self.path)
+        session = SyncSession(store)
+
+        fltr_one = MockFilter('foo')
+        session.register(fltr_one)
+
+        session.start()
+        o = session.get(0)
+        o.advance()
+        o.release()
+
+
+    def test_sync_process_outoforder(self):
+        store = SyncFsStore(self.path)
+        session = SyncSession(store)
+
+        fltr_one = MockFilter('foo')
+        session.register(fltr_one)
+        fltr_two = MockFilter('two')
+        session.register(fltr_two)
+
+        session.start()
+        o = session.get(0)
+        o.advance()
+        with self.assertRaises(LockError):
+            o.advance()
+
+        o.release()
+        with self.assertRaises(LockError):
+            o.release()
+
+        o.advance()
+        o.release()
+
+
+    def test_sync_process_interrupt(self):
+        store = SyncFsStore(self.path)
+        session = SyncSession(store)
+
+        fltr_one = MockFilter('foo')
+        session.register(fltr_one)
+        fltr_two = MockFilter('bar')
+        session.register(fltr_two)
+
+        session.start()
+        o = session.get(0)
+        o.advance()
+        o.release(interrupt=True)
+        with self.assertRaises(FilterDone):
+            o.advance()
+
+
+    def test_sync_process_reset(self):
+        store = SyncFsStore(self.path)
+        session = SyncSession(store)
+
+        fltr_one = MockFilter('foo')
+        session.register(fltr_one)
+        fltr_two = MockFilter('bar')
+        session.register(fltr_two)
+
+        session.start()
+        o = session.get(0)
+        o.advance()
+        with self.assertRaises(LockError):
+            o.reset()
+        o.release()
+        with self.assertRaises(IncompleteFilterError):
+            o.reset()
+
+        o.advance()
+        o.release()
+
+        with self.assertRaises(FilterDone):
+            o.advance()
+
+        o.reset()
+
 
 if __name__ == '__main__':
     unittest.main()
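
A minimal usage sketch of the filter lifecycle introduced above, pieced together from the new tests in tests/test_fs.py. Every name used here comes from this patch; the explicit advance/release loop stands in for what SyncSession.filter() is intended to drive per block, and MockFilter stands in for a real filter implementation.

    import tempfile

    from chainsyncer.store.fs import SyncFsStore
    from chainsyncer.session import SyncSession
    from chainsyncer.error import FilterDone
    from chainsyncer.unittest import MockFilter

    # a session wraps a filesystem-backed sync store
    store = SyncFsStore(tempfile.mkdtemp())
    session = SyncSession(store)

    # filters must be registered before the session starts
    session.register(MockFilter('foo'))
    session.register(MockFilter('bar'))
    session.start()

    # each sync item walks one LOCK/release cycle per registered filter;
    # advance() raises FilterDone once every filter has been executed
    item = session.get(0)
    while True:
        try:
            item.advance()
        except FilterDone:
            break
        # a real driver would execute the filter between advance() and
        # release(), passing its return value as release(interrupt=...)
        item.release(interrupt=False)

    # with all filters done, the item can be reset for the next block
    item.reset()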