diff --git a/chainsyncer/backend/memory.py b/chainsyncer/backend/memory.py index 2ac291d..e49102f 100644 --- a/chainsyncer/backend/memory.py +++ b/chainsyncer/backend/memory.py @@ -1,5 +1,6 @@ # standard imports import logging +import uuid # local imports from .base import Backend @@ -20,19 +21,38 @@ class MemBackend(Backend): :type target_block: int """ - def __init__(self, chain_spec, object_id, target_block=None, block_height=0, tx_height=0, flags=0): + def __init__(self, chain_spec, object_id): super(MemBackend, self).__init__(object_id) self.chain_spec = chain_spec - self.block_height_offset = block_height - self.block_height_cursor = block_height - self.tx_height_offset = tx_height - self.tx_height_cursor = tx_height - self.block_height_target = target_block self.db_session = None - self.flags = flags + self.block_height_offset = 0 + self.block_height_cursor = 0 + self.tx_height_offset = 0 + self.tx_height_cursor = 0 + self.block_height_target = None + self.flags = 0 + self.flags_start = 0 + self.flags_target = 0 self.filter_names = [] + @staticmethod + def custom(chain_spec, target_block, block_offset=0, tx_offset=0, flags=0, flags_count=0, *args, **kwargs): + object_id = kwargs.get('object_id', str(uuid.uuid4())) + backend = MemBackend(chain_spec, object_id) + backend.block_height_offset = block_offset + backend.block_height_cursor = block_offset + backend.tx_height_offset = tx_offset + backend.tx_height_cursor = tx_offset + backend.block_height_target = target_block + backend.flags = flags + backend.flags_count = flags_count + backend.flags_start = flags + flags_target = (2 ** flags_count) - 1 + backend.flags_target = flags_target + return backend + + def connect(self): """NOOP as memory backend implements no connection. """ @@ -67,13 +87,22 @@ class MemBackend(Backend): return ((self.block_height_cursor, self.tx_height_cursor), self.flags) + def start(self): + """Get the initial syncer state + + :rtype: tuple + :returns: block height / tx index tuple, and filter flags value + """ + return ((self.block_height_offset, self.tx_height_offset), self.flags_start) + + def target(self): """Returns the syncer target. :rtype: tuple :returns: block height / tx index tuple """ - return (self.block_height_target, self.flags) + return (self.block_height_target, self.flags_target) def register_filter(self, name): diff --git a/chainsyncer/backend/sql.py b/chainsyncer/backend/sql.py index a930ce5..4eaaf9b 100644 --- a/chainsyncer/backend/sql.py +++ b/chainsyncer/backend/sql.py @@ -148,6 +148,30 @@ class SQLBackend(Backend): return (target, filter_target,) + @staticmethod + def custom(chain_spec, target_block, block_offset=0, tx_offset=0, flags=0, flag_count=0, *args, **kwargs): + """ + + :param flags: flags bit field + :type flags: bytes + :param flag_count: number of flags in bit field + :type flag_count: + """ + session = SessionBase.create_session() + o = BlockchainSync(str(chain_spec), block_offset, tx_offset, target_block) + session.add(o) + session.commit() + object_id = o.id + + of = BlockchainSyncFilter(o, flag_count, flags, kwargs.get('flags_digest')) + session.add(of) + session.commit() + + session.close() + + return SQLBackend(chain_spec, object_id) + + @staticmethod def first(chain_spec): """Returns the model object of the most recent syncer in backend. diff --git a/chainsyncer/db/models/filter.py b/chainsyncer/db/models/filter.py index e8b0a12..ee5339b 100644 --- a/chainsyncer/db/models/filter.py +++ b/chainsyncer/db/models/filter.py @@ -38,7 +38,9 @@ class BlockchainSyncFilter(SessionBase): count = Column(Integer) - def __init__(self, chain_sync, count=0, flags=None, digest=zero_digest): + def __init__(self, chain_sync, count=0, flags=None, digest=None): + if digest == None: + digest = zero_digest self.digest = digest self.count = count diff --git a/chainsyncer/driver/poll.py b/chainsyncer/driver/poll.py index 9fb7abf..1cbbf11 100644 --- a/chainsyncer/driver/poll.py +++ b/chainsyncer/driver/poll.py @@ -9,7 +9,8 @@ from chainsyncer.error import ( NoBlockForYou, ) -logg = logging.getLogger(__name__) +#logg = logging.getLogger(__name__) +logg = logging.getLogger() @@ -29,6 +30,7 @@ class BlockPollSyncer(Syncer): :rtype: tuple :returns: See chainsyncer.backend.base.Backend.get """ + raise ValueError() (pair, fltr) = self.backend.get() start_tx = pair[1] diff --git a/chainsyncer/driver/threadpool.py b/chainsyncer/driver/threadpool.py index bedcef4..e62ffb0 100644 --- a/chainsyncer/driver/threadpool.py +++ b/chainsyncer/driver/threadpool.py @@ -115,7 +115,6 @@ class ThreadPoolHistorySyncer(HistorySyncer): pass - #def process(self, conn, block): def get(self, conn): if not self.running: raise SyncDone() diff --git a/chainsyncer/driver/threadrange.py b/chainsyncer/driver/threadrange.py new file mode 100644 index 0000000..d69adc4 --- /dev/null +++ b/chainsyncer/driver/threadrange.py @@ -0,0 +1,68 @@ +# standard imports +import copy +import logging +import multiprocessing + +# local imports +from chainsyncer.driver.history import HistorySyncer +from .threadpool import ThreadPoolTask + +#logg = logging.getLogger(__name__) +logg = logging.getLogger() + + +def range_to_backends(chain_spec, block_offset, tx_offset, block_target, flags, flags_count, backend_class, backend_count): + block_count = block_target - block_offset + if block_count < backend_count: + logg.warning('block count is less than thread count, adjusting thread count to {}'.format(block_count)) + backend_count = block_count + blocks_per_thread = int(block_count / backend_count) + + backends = [] + for i in range(backend_count): + block_target = block_offset + blocks_per_thread + backend = backend_class.custom(chain_spec, block_target - 1, block_offset=block_offset, tx_offset=tx_offset, flags=flags, flags_count=flags_count) + backends.append(backend) + block_offset = block_target + tx_offset = 0 + flags = 0 + + return backends + + +class ThreadPoolRangeTask: + + loop_func = None + + def __init__(self, backend, conn): + pass + + + def foo(self, a, b): + return self.loop_func() + + +class ThreadPoolRangeHistorySyncer(HistorySyncer): + + def __init__(self, conn_factory, thread_count, backends, chain_interface, loop_func=HistorySyncer.loop, pre_callback=None, block_callback=None, post_callback=None, runlevel_callback=None): + if thread_count > len(backends): + raise ValueError('thread count {} is greater than than backend count {}'.format(thread_count, len(backends))) + self.backends = backends + self.thread_count = thread_count + self.conn_factory = conn_factory + self.single_sync_offset = 0 + self.runlevel_callback = None + + ThreadPoolRangeTask.loop_func = loop_func + + + def loop(self, interval, conn): + super_loop = super(ThreadPoolRangeHistorySyncer, self).loop + self.worker_pool = multiprocessing.Pool(processes=self.thread_count) + for backend in self.backends: + conn = self.conn_factory() + task = ThreadPoolRangeTask(backend, conn) + t = self.worker_pool.apply_async(task.foo, (backend, conn,)) + print(t.get()) + self.worker_pool.close() + self.worker_pool.join() diff --git a/chainsyncer/unittest/base.py b/chainsyncer/unittest/base.py index 2b2c120..d3ed7b3 100644 --- a/chainsyncer/unittest/base.py +++ b/chainsyncer/unittest/base.py @@ -12,7 +12,6 @@ from chainsyncer.error import NoBlockForYou logg = logging.getLogger().getChild(__name__) - class MockConn: """Noop connection mocker. diff --git a/tests/test_database.py b/tests/test_database.py index 8978f02..2ea0b77 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -184,5 +184,16 @@ class TestDatabase(TestBase): self.assertEqual(flags, 5) + def test_backend_sql_custom(self): + chain_spec = ChainSpec('evm', 'bloxberg', 8996, 'foo') + flags = 5 + flags_target = 1023 + flag_count = 10 + backend = SQLBackend.custom(chain_spec, 666, 42, 2, flags, flag_count) + self.assertEqual(((42, 2), flags), backend.start()) + self.assertEqual(((42, 2), flags), backend.get()) + self.assertEqual((666, flags_target), backend.target()) + + if __name__ == '__main__': unittest.main() diff --git a/tests/test_mem.py b/tests/test_mem.py new file mode 100644 index 0000000..03d9a8f --- /dev/null +++ b/tests/test_mem.py @@ -0,0 +1,29 @@ +# standard imports +import unittest + +# external imports +from chainlib.chain import ChainSpec + +# local imports +from chainsyncer.backend.memory import MemBackend + +# testutil imports +from tests.chainsyncer_base import TestBase + + +class TestMem(TestBase): + + def test_backend_mem_custom(self): + chain_spec = ChainSpec('evm', 'bloxberg', 8996, 'foo') + flags = int(5).to_bytes(2, 'big') + #flags_target = int(1024-1).to_bytes(2, 'big') + flag_count = 10 + backend = MemBackend.custom(chain_spec, 666, 42, 2, flags, flag_count, object_id='xyzzy') + self.assertEqual(((42, 2), flags), backend.start()) + self.assertEqual(((42, 2), flags), backend.get()) + self.assertEqual((666, flags), backend.target()) + self.assertEqual(backend.object_id, 'xyzzy') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_thread.py b/tests/test_thread.py new file mode 100644 index 0000000..5c8893a --- /dev/null +++ b/tests/test_thread.py @@ -0,0 +1,12 @@ +# standard imports +import logging +import unittest + +# test imports +from tests.chainsyncer_base import TestBase + + +class TestThreadRange(TestBase): + + def test_hello(self): + ThreadPoolRangeHistorySyncer(None, 3) diff --git a/tests/test_thread_range.py b/tests/test_thread_range.py new file mode 100644 index 0000000..5a57646 --- /dev/null +++ b/tests/test_thread_range.py @@ -0,0 +1,56 @@ +# standard imports +import unittest +import logging + +# external imports +from chainlib.chain import ChainSpec + +# local imports +from chainsyncer.backend.memory import MemBackend +from chainsyncer.driver.threadrange import ( + range_to_backends, + ThreadPoolRangeHistorySyncer, + ) +from chainsyncer.unittest.base import MockConn + +# testutil imports +from tests.chainsyncer_base import TestBase + +logging.basicConfig(level=logging.DEBUG) +logg = logging.getLogger() + + +class TestThreadRange(TestBase): + + def test_range_split_even(self): + chain_spec = ChainSpec('evm', 'bloxberg', 8996, 'foo') + backends = range_to_backends(chain_spec, 5, 3, 20, 5, 10, MemBackend, 3) + self.assertEqual(len(backends), 3) + self.assertEqual(((5, 3), 5), backends[0].start()) + self.assertEqual((9, 1023), backends[0].target()) + self.assertEqual(((10, 0), 0), backends[1].start()) + self.assertEqual((14, 1023), backends[1].target()) + self.assertEqual(((15, 0), 0), backends[2].start()) + self.assertEqual((19, 1023), backends[2].target()) + + + def test_range_split_underflow(self): + chain_spec = ChainSpec('evm', 'bloxberg', 8996, 'foo') + backends = range_to_backends(chain_spec, 5, 3, 7, 5, 10, MemBackend, 3) + self.assertEqual(len(backends), 2) + self.assertEqual(((5, 3), 5), backends[0].start()) + self.assertEqual((5, 1023), backends[0].target()) + self.assertEqual(((6, 0), 0), backends[1].start()) + self.assertEqual((6, 1023), backends[1].target()) + + + def test_range_syncer(self): + chain_spec = ChainSpec('evm', 'bloxberg', 8996, 'foo') + backends = range_to_backends(chain_spec, 5, 3, 20, 5, 10, MemBackend, 3) + + syncer = ThreadPoolRangeHistorySyncer(MockConn, 3, backends, self.interface) + syncer.loop(1, None) + + +if __name__ == '__main__': + unittest.main()