Add thread range generator
This commit is contained in:
parent
db6128f823
commit
a49c152e24
@ -1,5 +1,6 @@
|
|||||||
# standard imports
|
# standard imports
|
||||||
import logging
|
import logging
|
||||||
|
import uuid
|
||||||
|
|
||||||
# local imports
|
# local imports
|
||||||
from .base import Backend
|
from .base import Backend
|
||||||
@ -20,19 +21,38 @@ class MemBackend(Backend):
|
|||||||
:type target_block: int
|
:type target_block: int
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, chain_spec, object_id, target_block=None, block_height=0, tx_height=0, flags=0):
|
def __init__(self, chain_spec, object_id):
|
||||||
super(MemBackend, self).__init__(object_id)
|
super(MemBackend, self).__init__(object_id)
|
||||||
self.chain_spec = chain_spec
|
self.chain_spec = chain_spec
|
||||||
self.block_height_offset = block_height
|
|
||||||
self.block_height_cursor = block_height
|
|
||||||
self.tx_height_offset = tx_height
|
|
||||||
self.tx_height_cursor = tx_height
|
|
||||||
self.block_height_target = target_block
|
|
||||||
self.db_session = None
|
self.db_session = None
|
||||||
self.flags = flags
|
self.block_height_offset = 0
|
||||||
|
self.block_height_cursor = 0
|
||||||
|
self.tx_height_offset = 0
|
||||||
|
self.tx_height_cursor = 0
|
||||||
|
self.block_height_target = None
|
||||||
|
self.flags = 0
|
||||||
|
self.flags_start = 0
|
||||||
|
self.flags_target = 0
|
||||||
self.filter_names = []
|
self.filter_names = []
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def custom(chain_spec, target_block, block_offset=0, tx_offset=0, flags=0, flags_count=0, *args, **kwargs):
|
||||||
|
object_id = kwargs.get('object_id', str(uuid.uuid4()))
|
||||||
|
backend = MemBackend(chain_spec, object_id)
|
||||||
|
backend.block_height_offset = block_offset
|
||||||
|
backend.block_height_cursor = block_offset
|
||||||
|
backend.tx_height_offset = tx_offset
|
||||||
|
backend.tx_height_cursor = tx_offset
|
||||||
|
backend.block_height_target = target_block
|
||||||
|
backend.flags = flags
|
||||||
|
backend.flags_count = flags_count
|
||||||
|
backend.flags_start = flags
|
||||||
|
flags_target = (2 ** flags_count) - 1
|
||||||
|
backend.flags_target = flags_target
|
||||||
|
return backend
|
||||||
|
|
||||||
|
|
||||||
def connect(self):
|
def connect(self):
|
||||||
"""NOOP as memory backend implements no connection.
|
"""NOOP as memory backend implements no connection.
|
||||||
"""
|
"""
|
||||||
@ -67,13 +87,22 @@ class MemBackend(Backend):
|
|||||||
return ((self.block_height_cursor, self.tx_height_cursor), self.flags)
|
return ((self.block_height_cursor, self.tx_height_cursor), self.flags)
|
||||||
|
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
"""Get the initial syncer state
|
||||||
|
|
||||||
|
:rtype: tuple
|
||||||
|
:returns: block height / tx index tuple, and filter flags value
|
||||||
|
"""
|
||||||
|
return ((self.block_height_offset, self.tx_height_offset), self.flags_start)
|
||||||
|
|
||||||
|
|
||||||
def target(self):
|
def target(self):
|
||||||
"""Returns the syncer target.
|
"""Returns the syncer target.
|
||||||
|
|
||||||
:rtype: tuple
|
:rtype: tuple
|
||||||
:returns: block height / tx index tuple
|
:returns: block height / tx index tuple
|
||||||
"""
|
"""
|
||||||
return (self.block_height_target, self.flags)
|
return (self.block_height_target, self.flags_target)
|
||||||
|
|
||||||
|
|
||||||
def register_filter(self, name):
|
def register_filter(self, name):
|
||||||
|
@ -148,6 +148,30 @@ class SQLBackend(Backend):
|
|||||||
return (target, filter_target,)
|
return (target, filter_target,)
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def custom(chain_spec, target_block, block_offset=0, tx_offset=0, flags=0, flag_count=0, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
|
||||||
|
:param flags: flags bit field
|
||||||
|
:type flags: bytes
|
||||||
|
:param flag_count: number of flags in bit field
|
||||||
|
:type flag_count:
|
||||||
|
"""
|
||||||
|
session = SessionBase.create_session()
|
||||||
|
o = BlockchainSync(str(chain_spec), block_offset, tx_offset, target_block)
|
||||||
|
session.add(o)
|
||||||
|
session.commit()
|
||||||
|
object_id = o.id
|
||||||
|
|
||||||
|
of = BlockchainSyncFilter(o, flag_count, flags, kwargs.get('flags_digest'))
|
||||||
|
session.add(of)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
return SQLBackend(chain_spec, object_id)
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def first(chain_spec):
|
def first(chain_spec):
|
||||||
"""Returns the model object of the most recent syncer in backend.
|
"""Returns the model object of the most recent syncer in backend.
|
||||||
|
@ -38,7 +38,9 @@ class BlockchainSyncFilter(SessionBase):
|
|||||||
count = Column(Integer)
|
count = Column(Integer)
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, chain_sync, count=0, flags=None, digest=zero_digest):
|
def __init__(self, chain_sync, count=0, flags=None, digest=None):
|
||||||
|
if digest == None:
|
||||||
|
digest = zero_digest
|
||||||
self.digest = digest
|
self.digest = digest
|
||||||
self.count = count
|
self.count = count
|
||||||
|
|
||||||
|
@ -9,7 +9,8 @@ from chainsyncer.error import (
|
|||||||
NoBlockForYou,
|
NoBlockForYou,
|
||||||
)
|
)
|
||||||
|
|
||||||
logg = logging.getLogger(__name__)
|
#logg = logging.getLogger(__name__)
|
||||||
|
logg = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -29,6 +30,7 @@ class BlockPollSyncer(Syncer):
|
|||||||
:rtype: tuple
|
:rtype: tuple
|
||||||
:returns: See chainsyncer.backend.base.Backend.get
|
:returns: See chainsyncer.backend.base.Backend.get
|
||||||
"""
|
"""
|
||||||
|
raise ValueError()
|
||||||
(pair, fltr) = self.backend.get()
|
(pair, fltr) = self.backend.get()
|
||||||
start_tx = pair[1]
|
start_tx = pair[1]
|
||||||
|
|
||||||
|
@ -115,7 +115,6 @@ class ThreadPoolHistorySyncer(HistorySyncer):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
#def process(self, conn, block):
|
|
||||||
def get(self, conn):
|
def get(self, conn):
|
||||||
if not self.running:
|
if not self.running:
|
||||||
raise SyncDone()
|
raise SyncDone()
|
||||||
|
68
chainsyncer/driver/threadrange.py
Normal file
68
chainsyncer/driver/threadrange.py
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
# standard imports
|
||||||
|
import copy
|
||||||
|
import logging
|
||||||
|
import multiprocessing
|
||||||
|
|
||||||
|
# local imports
|
||||||
|
from chainsyncer.driver.history import HistorySyncer
|
||||||
|
from .threadpool import ThreadPoolTask
|
||||||
|
|
||||||
|
#logg = logging.getLogger(__name__)
|
||||||
|
logg = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
def range_to_backends(chain_spec, block_offset, tx_offset, block_target, flags, flags_count, backend_class, backend_count):
|
||||||
|
block_count = block_target - block_offset
|
||||||
|
if block_count < backend_count:
|
||||||
|
logg.warning('block count is less than thread count, adjusting thread count to {}'.format(block_count))
|
||||||
|
backend_count = block_count
|
||||||
|
blocks_per_thread = int(block_count / backend_count)
|
||||||
|
|
||||||
|
backends = []
|
||||||
|
for i in range(backend_count):
|
||||||
|
block_target = block_offset + blocks_per_thread
|
||||||
|
backend = backend_class.custom(chain_spec, block_target - 1, block_offset=block_offset, tx_offset=tx_offset, flags=flags, flags_count=flags_count)
|
||||||
|
backends.append(backend)
|
||||||
|
block_offset = block_target
|
||||||
|
tx_offset = 0
|
||||||
|
flags = 0
|
||||||
|
|
||||||
|
return backends
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadPoolRangeTask:
|
||||||
|
|
||||||
|
loop_func = None
|
||||||
|
|
||||||
|
def __init__(self, backend, conn):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def foo(self, a, b):
|
||||||
|
return self.loop_func()
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadPoolRangeHistorySyncer(HistorySyncer):
|
||||||
|
|
||||||
|
def __init__(self, conn_factory, thread_count, backends, chain_interface, loop_func=HistorySyncer.loop, pre_callback=None, block_callback=None, post_callback=None, runlevel_callback=None):
|
||||||
|
if thread_count > len(backends):
|
||||||
|
raise ValueError('thread count {} is greater than than backend count {}'.format(thread_count, len(backends)))
|
||||||
|
self.backends = backends
|
||||||
|
self.thread_count = thread_count
|
||||||
|
self.conn_factory = conn_factory
|
||||||
|
self.single_sync_offset = 0
|
||||||
|
self.runlevel_callback = None
|
||||||
|
|
||||||
|
ThreadPoolRangeTask.loop_func = loop_func
|
||||||
|
|
||||||
|
|
||||||
|
def loop(self, interval, conn):
|
||||||
|
super_loop = super(ThreadPoolRangeHistorySyncer, self).loop
|
||||||
|
self.worker_pool = multiprocessing.Pool(processes=self.thread_count)
|
||||||
|
for backend in self.backends:
|
||||||
|
conn = self.conn_factory()
|
||||||
|
task = ThreadPoolRangeTask(backend, conn)
|
||||||
|
t = self.worker_pool.apply_async(task.foo, (backend, conn,))
|
||||||
|
print(t.get())
|
||||||
|
self.worker_pool.close()
|
||||||
|
self.worker_pool.join()
|
@ -12,7 +12,6 @@ from chainsyncer.error import NoBlockForYou
|
|||||||
logg = logging.getLogger().getChild(__name__)
|
logg = logging.getLogger().getChild(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class MockConn:
|
class MockConn:
|
||||||
"""Noop connection mocker.
|
"""Noop connection mocker.
|
||||||
|
|
||||||
|
@ -184,5 +184,16 @@ class TestDatabase(TestBase):
|
|||||||
self.assertEqual(flags, 5)
|
self.assertEqual(flags, 5)
|
||||||
|
|
||||||
|
|
||||||
|
def test_backend_sql_custom(self):
|
||||||
|
chain_spec = ChainSpec('evm', 'bloxberg', 8996, 'foo')
|
||||||
|
flags = 5
|
||||||
|
flags_target = 1023
|
||||||
|
flag_count = 10
|
||||||
|
backend = SQLBackend.custom(chain_spec, 666, 42, 2, flags, flag_count)
|
||||||
|
self.assertEqual(((42, 2), flags), backend.start())
|
||||||
|
self.assertEqual(((42, 2), flags), backend.get())
|
||||||
|
self.assertEqual((666, flags_target), backend.target())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
29
tests/test_mem.py
Normal file
29
tests/test_mem.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
# standard imports
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
# external imports
|
||||||
|
from chainlib.chain import ChainSpec
|
||||||
|
|
||||||
|
# local imports
|
||||||
|
from chainsyncer.backend.memory import MemBackend
|
||||||
|
|
||||||
|
# testutil imports
|
||||||
|
from tests.chainsyncer_base import TestBase
|
||||||
|
|
||||||
|
|
||||||
|
class TestMem(TestBase):
|
||||||
|
|
||||||
|
def test_backend_mem_custom(self):
|
||||||
|
chain_spec = ChainSpec('evm', 'bloxberg', 8996, 'foo')
|
||||||
|
flags = int(5).to_bytes(2, 'big')
|
||||||
|
#flags_target = int(1024-1).to_bytes(2, 'big')
|
||||||
|
flag_count = 10
|
||||||
|
backend = MemBackend.custom(chain_spec, 666, 42, 2, flags, flag_count, object_id='xyzzy')
|
||||||
|
self.assertEqual(((42, 2), flags), backend.start())
|
||||||
|
self.assertEqual(((42, 2), flags), backend.get())
|
||||||
|
self.assertEqual((666, flags), backend.target())
|
||||||
|
self.assertEqual(backend.object_id, 'xyzzy')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
12
tests/test_thread.py
Normal file
12
tests/test_thread.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
# standard imports
|
||||||
|
import logging
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
# test imports
|
||||||
|
from tests.chainsyncer_base import TestBase
|
||||||
|
|
||||||
|
|
||||||
|
class TestThreadRange(TestBase):
|
||||||
|
|
||||||
|
def test_hello(self):
|
||||||
|
ThreadPoolRangeHistorySyncer(None, 3)
|
56
tests/test_thread_range.py
Normal file
56
tests/test_thread_range.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
# standard imports
|
||||||
|
import unittest
|
||||||
|
import logging
|
||||||
|
|
||||||
|
# external imports
|
||||||
|
from chainlib.chain import ChainSpec
|
||||||
|
|
||||||
|
# local imports
|
||||||
|
from chainsyncer.backend.memory import MemBackend
|
||||||
|
from chainsyncer.driver.threadrange import (
|
||||||
|
range_to_backends,
|
||||||
|
ThreadPoolRangeHistorySyncer,
|
||||||
|
)
|
||||||
|
from chainsyncer.unittest.base import MockConn
|
||||||
|
|
||||||
|
# testutil imports
|
||||||
|
from tests.chainsyncer_base import TestBase
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
logg = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
class TestThreadRange(TestBase):
|
||||||
|
|
||||||
|
def test_range_split_even(self):
|
||||||
|
chain_spec = ChainSpec('evm', 'bloxberg', 8996, 'foo')
|
||||||
|
backends = range_to_backends(chain_spec, 5, 3, 20, 5, 10, MemBackend, 3)
|
||||||
|
self.assertEqual(len(backends), 3)
|
||||||
|
self.assertEqual(((5, 3), 5), backends[0].start())
|
||||||
|
self.assertEqual((9, 1023), backends[0].target())
|
||||||
|
self.assertEqual(((10, 0), 0), backends[1].start())
|
||||||
|
self.assertEqual((14, 1023), backends[1].target())
|
||||||
|
self.assertEqual(((15, 0), 0), backends[2].start())
|
||||||
|
self.assertEqual((19, 1023), backends[2].target())
|
||||||
|
|
||||||
|
|
||||||
|
def test_range_split_underflow(self):
|
||||||
|
chain_spec = ChainSpec('evm', 'bloxberg', 8996, 'foo')
|
||||||
|
backends = range_to_backends(chain_spec, 5, 3, 7, 5, 10, MemBackend, 3)
|
||||||
|
self.assertEqual(len(backends), 2)
|
||||||
|
self.assertEqual(((5, 3), 5), backends[0].start())
|
||||||
|
self.assertEqual((5, 1023), backends[0].target())
|
||||||
|
self.assertEqual(((6, 0), 0), backends[1].start())
|
||||||
|
self.assertEqual((6, 1023), backends[1].target())
|
||||||
|
|
||||||
|
|
||||||
|
def test_range_syncer(self):
|
||||||
|
chain_spec = ChainSpec('evm', 'bloxberg', 8996, 'foo')
|
||||||
|
backends = range_to_backends(chain_spec, 5, 3, 20, 5, 10, MemBackend, 3)
|
||||||
|
|
||||||
|
syncer = ThreadPoolRangeHistorySyncer(MockConn, 3, backends, self.interface)
|
||||||
|
syncer.loop(1, None)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
Loading…
Reference in New Issue
Block a user