Introduce driver object

This commit is contained in:
lash 2022-03-17 23:48:23 +00:00
parent dcf095cc86
commit 18f16d878f
Signed by: lash
GPG Key ID: 21D2E7BB88C2A746
15 changed files with 280 additions and 694 deletions

View File

@ -3,29 +3,55 @@ import logging
import time
# local imports
from .base import Syncer
from chainsyncer.error import (
SyncDone,
NoBlockForYou,
)
from chainsyncer.session import SyncSession
logg = logging.getLogger(__name__)
NS_DIV = 1000000000
class BlockPollSyncer(Syncer):
"""Syncer driver implementation of chainsyncer.driver.base.Syncer that retrieves new blocks through polling.
"""
class SyncDriver:
name = 'blockpoll'
running_global = True
def __init__(self, backend, chain_interface, pre_callback=None, block_callback=None, post_callback=None, idle_callback=None):
super(BlockPollSyncer, self).__init__(backend, chain_interface, pre_callback=pre_callback, block_callback=block_callback, post_callback=post_callback)
def __init__(self, conn, store, pre_callback=None, post_callback=None, block_callback=None, idle_callback=None):
self.store = store
self.running = True
self.pre_callback = pre_callback
self.post_callback = post_callback
self.block_callback = block_callback
self.idle_callback = idle_callback
self.last_start = 0
self.clock_id = time.CLOCK_MONOTONIC_RAW
self.session = SyncSession(self.store)
def __sig_terminate(self, sig, frame):
logg.warning('got signal {}'.format(sig))
self.terminate()
def terminate(self):
"""Set syncer to terminate as soon as possible.
"""
logg.info('termination requested!')
SyncDriver.running_global = False
self.running = False
def run(self):
while self.running_global:
item = self.store.next_item()
logg.debug('item {}'.format(item))
if item == None:
self.running = False
self.running_global = False
break
self.loop(item)
def idle(self, interval):
@ -54,21 +80,8 @@ class BlockPollSyncer(Syncer):
time.sleep(interval)
def loop(self, interval, conn):
"""Indefinite loop polling the given RPC connection for new blocks in the given interval.
:param interval: Seconds to wait for next poll after processing of previous poll has been completed.
:type interval: int
:param conn: RPC connection
:type conn: chainlib.connection.RPCConnection
:rtype: tuple
:returns: See chainsyncer.backend.base.Backend.get
"""
(pair, fltr) = self.backend.get()
start_tx = pair[1]
while self.running and Syncer.running_global:
def loop(self, item):
while self.running and SyncDriver.running_global:
self.last_start = time.clock_gettime_ns(self.clock_id)
if self.pre_callback != None:
self.pre_callback()
@ -97,3 +110,14 @@ class BlockPollSyncer(Syncer):
self.post_callback()
self.idle(interval)
def process_single(self, conn, block, tx):
self.session.filter(conn, block, tx)
def process(self, conn, block):
raise NotImplementedError()
def get(self, conn):
raise NotImplementedError()

View File

@ -1 +0,0 @@
from .base import Syncer

View File

@ -1,126 +0,0 @@
# standard imports
import uuid
import logging
import time
import signal
import json
# external imports
from chainlib.error import JSONRPCException
# local imports
from chainsyncer.filter import SyncFilter
from chainsyncer.error import (
SyncDone,
NoBlockForYou,
)
logg = logging.getLogger(__name__)
def noop_callback(block, tx):
"""Logger-only callback for pre- and post processing.
:param block: Block object
:type block: chainlib.block.Block
:param tx: Transaction object
:type tx: chainlib.tx.Tx
"""
logg.debug('noop callback ({},{})'.format(block, tx))
class Syncer:
"""Base class for syncer implementations.
:param backend: Syncer state backend
:type backend: chainsyncer.backend.base.Backend implementation
:param chain_interface: Chain interface implementation
:type chain_interface: chainlib.interface.ChainInterface implementation
:param pre_callback: Function to call before polling. Function will receive no arguments.
:type pre_callback: function
:param block_callback: Function to call before processing txs in a retrieved block. Function should have signature as chainsyncer.driver.base.noop_callback
:type block_callback: function
:param post_callback: Function to call after polling. Function will receive no arguments.
:type post_callback: function
"""
running_global = True
"""If set to false syncer will terminate polling loop."""
yield_delay=0.005
"""Delay between each processed block."""
signal_request = [signal.SIGINT, signal.SIGTERM]
"""Signals to catch to request shutdown."""
signal_set = False
"""Whether shutdown signal has been received."""
name = 'base'
"""Syncer name, to be overriden for each extended implementation."""
def __init__(self, backend, chain_interface, pre_callback=None, block_callback=None, post_callback=None):
self.chain_interface = chain_interface
self.cursor = None
self.running = True
self.backend = backend
self.filter = SyncFilter(backend)
self.block_callback = block_callback
self.pre_callback = pre_callback
self.post_callback = post_callback
if not Syncer.signal_set:
for sig in Syncer.signal_request:
signal.signal(sig, self.__sig_terminate)
Syncer.signal_set = True
def __sig_terminate(self, sig, frame):
logg.warning('got signal {}'.format(sig))
self.terminate()
def terminate(self):
"""Set syncer to terminate as soon as possible.
"""
logg.info('termination requested!')
Syncer.running_global = False
self.running = False
def add_filter(self, f):
"""Add filter to be processed for each transaction.
:param f: Filter
:type f: Object instance implementing signature as in chainsyncer.filter.NoopFilter.filter
"""
self.filter.add(f)
self.backend.register_filter(str(f))
def process_single(self, conn, block, tx):
"""Set syncer backend cursor to the given transaction index and block height, and apply all registered filters on transaction.
:param conn: RPC connection instance
:type conn: chainlib.connection.RPCConnection
:param block: Block object
:type block: chainlib.block.Block
:param block: Transaction object
:type block: chainlib.tx.Tx
"""
self.backend.set(block.number, tx.index)
self.filter.apply(conn, block, tx)
def loop(self, interval, conn):
raise NotImplementedError()
def process(self, conn, block):
raise NotImplementedError()
def get(self, conn):
raise NotImplementedError()
def __str__(self):
return 'syncer "{}" {}'.format(
self.name,
self.backend,
)

View File

@ -1,86 +0,0 @@
# standard imports
import logging
# external imports
from chainlib.eth.tx import (
transaction,
Tx,
)
from chainlib.error import RPCException
# local imports
from chainsyncer.error import NoBlockForYou
from .poll import BlockPollSyncer
logg = logging.getLogger(__name__)
class HeadSyncer(BlockPollSyncer):
"""Extends the block poller, implementing an open-ended syncer.
"""
name = 'head'
def process(self, conn, block):
"""Process a single block using the given RPC connection.
Processing means that all filters are executed on all transactions in the block.
If the block object does not contain the transaction details, the details will be retrieved from the network (incurring the corresponding performance penalty).
:param conn: RPC connection
:type conn: chainlib.connection.RPCConnection
:param block: Block object
:type block: chainlib.block.Block
"""
(pair, fltr) = self.backend.get()
logg.debug('process block {} (backend {}:{})'.format(block, pair, fltr))
i = pair[1] # set tx index from previous
tx_src = None
while True:
# handle block objects regardless of whether the tx data is embedded or not
try:
tx = block.tx(i)
except AttributeError:
o = transaction(block.txs[i])
r = conn.do(o)
tx_src = Tx.src_normalize(r)
tx = self.chain_interface.tx_from_src(tx_src, block=block)
#except IndexError as e:
# logg.debug('index error syncer tx get {}'.format(e))
# break
rcpt = conn.do(self.chain_interface.tx_receipt(tx.hash))
if rcpt != None:
tx.apply_receipt(self.chain_interface.src_normalize(rcpt))
self.process_single(conn, block, tx)
self.backend.reset_filter()
i += 1
def get(self, conn):
"""Retrieve the block currently defined by the syncer cursor from the RPC provider.
:param conn: RPC connection
:type conn: chainlib.connectin.RPCConnection
:raises NoBlockForYou: Block at the given height does not exist
:rtype: chainlib.block.Block
:returns: Block object
"""
(height, flags) = self.backend.get()
block_number = height[0]
block_hash = []
o = self.chain_interface.block_by_number(block_number)
try:
r = conn.do(o)
except RPCException:
r = None
if r == None:
raise NoBlockForYou()
b = self.chain_interface.block_from_src(r)
b.txs = b.txs[height[1]:]
return b

View File

@ -1,56 +0,0 @@
# standard imports
import logging
# external imports
from chainlib.error import RPCException
# local imports
from .head import HeadSyncer
from chainsyncer.error import SyncDone
from chainlib.error import RPCException
logg = logging.getLogger(__name__)
class HistorySyncer(HeadSyncer):
"""Bounded syncer implementation of the block poller. Reuses the head syncer process method implementation.
"""
name = 'history'
def __init__(self, backend, chain_interface, pre_callback=None, block_callback=None, post_callback=None):
super(HeadSyncer, self).__init__(backend, chain_interface, pre_callback, block_callback, post_callback)
self.block_target = None
(block_number, flags) = self.backend.target()
if block_number == None:
raise AttributeError('backend has no future target. Use HeadSyner instead')
self.block_target = block_number
def get(self, conn):
"""Retrieve the block currently defined by the syncer cursor from the RPC provider.
:param conn: RPC connection
:type conn: chainlib.connectin.RPCConnection
:raises SyncDone: Block target reached (at which point the syncer should terminate).
:rtype: chainlib.block.Block
:returns: Block object
:todo: DRY against HeadSyncer
"""
(height, flags) = self.backend.get()
if self.block_target < height[0]:
raise SyncDone(self.block_target)
block_number = height[0]
block_hash = []
o = self.chain_interface.block_by_number(block_number, include_tx=True)
try:
r = conn.do(o)
# TODO: Disambiguate whether error is temporary or permanent, if permanent, SyncDone should be raised, because a historical sync is attempted into the future
except RPCException:
r = None
if r == None:
raise SyncDone()
b = self.chain_interface.block_from_src(r)
return b

View File

@ -1,133 +0,0 @@
# standard imports
import logging
#import threading
import multiprocessing
import queue
# external imports
from chainlib.error import RPCException
# local imports
from .history import HistorySyncer
from chainsyncer.error import SyncDone
logg = logging.getLogger(__name__)
class ThreadedHistorySyncer(HistorySyncer):
def __init__(self, conn_factory, thread_limit, backend, chain_interface, pre_callback=None, block_callback=None, post_callback=None, conn_limit=0):
super(ThreadedHistorySyncer, self).__init__(backend, chain_interface, pre_callback, block_callback, post_callback)
self.workers = []
if conn_limit == 0:
conn_limit = thread_limit
#self.conn_pool = queue.Queue(conn_limit)
#self.queue = queue.Queue(thread_limit)
#self.quit_queue = queue.Queue(1)
self.conn_pool = multiprocessing.Queue(conn_limit)
self.queue = multiprocessing.Queue(thread_limit)
self.quit_queue = multiprocessing.Queue(1)
#self.lock = threading.Lock()
self.lock = multiprocessing.Lock()
for i in range(thread_limit):
#w = threading.Thread(target=self.worker)
w = multiprocessing.Process(target=self.worker)
self.workers.append(w)
for i in range(conn_limit):
self.conn_pool.put(conn_factory())
def terminate(self):
self.quit_queue.put(())
super(ThreadedHistorySyncer, self).terminate()
def worker(self):
while True:
block_number = None
try:
block_number = self.queue.get(timeout=0.01)
except queue.Empty:
if self.quit_queue.qsize() > 0:
#logg.debug('{} received quit'.format(threading.current_thread().getName()))
logg.debug('{} received quit'.format(multiprocessing.current_process().name))
return
continue
conn = self.conn_pool.get()
try:
logg.debug('processing parent {} {}'.format(conn, block_number))
self.process_parent(conn, block_number)
except IndexError:
pass
except RPCException as e:
logg.error('RPC failure for block {}, resubmitting to queue: {}'.format(block, e))
self.queue.put(block_number)
conn = self.conn_pool.put(conn)
def process_parent(self, conn, block_number):
logg.debug('getting block {}'.format(block_number))
o = self.chain_interface.block_by_number(block_number)
r = conn.do(o)
block = self.chain_interface.block_from_src(r)
logg.debug('got block typ {}'.format(type(block)))
super(ThreadedHistorySyncer, self).process(conn, block)
def process_single(self, conn, block, tx):
self.filter.apply(conn, block, tx)
def process(self, conn, block):
pass
#def process(self, conn, block):
def get(self, conn):
if not self.running:
raise SyncDone()
block_number = None
tx_index = None
flags = None
((block_number, tx_index), flags) = self.backend.get()
try:
#logg.debug('putting {}'.format(block.number))
#self.queue.put((conn, block_number,), timeout=0.1)
self.queue.put(block_number, timeout=0.1)
except queue.Full:
#logg.debug('queue full, try again')
return
target, flags = self.backend.target()
next_block = block_number + 1
if next_block > target:
self.quit_queue.put(())
raise SyncDone()
self.backend.set(self.backend.block_height + 1, 0)
# def get(self, conn):
# try:
# r = super(ThreadedHistorySyncer, self).get(conn)
# return r
# except SyncDone as e:
# self.quit_queue.put(())
# raise e
def loop(self, interval, conn):
for w in self.workers:
w.start()
r = super(ThreadedHistorySyncer, self).loop(interval, conn)
for w in self.workers:
w.join()
while True:
try:
self.quit_queue.get_nowait()
except queue.Empty:
break
logg.info('workers done {}'.format(r))

View File

@ -1,170 +0,0 @@
# standard imports
import logging
#import threading
import multiprocessing
import queue
import time
# external imports
from chainlib.error import RPCException
# local imports
from .history import HistorySyncer
from chainsyncer.error import SyncDone
logg = logging.getLogger(__name__)
def foobarcb(v):
logg.debug('foooz {}'.format(v))
class ThreadPoolTask:
process_func = None
chain_interface = None
def poolworker(self, block_number, conn):
# conn = args[1].get()
try:
logg.debug('processing parent {} {}'.format(conn, block_number))
#self.process_parent(self.conn, block_number)
self.process_parent(conn, block_number)
except IndexError:
pass
except RPCException as e:
logg.error('RPC failure for block {}, resubmitting to queue: {}'.format(block, e))
raise e
#self.queue.put(block_number)
# conn = self.conn_pool.put(conn)
def process_parent(self, conn, block_number):
logg.debug('getting block {}'.format(block_number))
o = self.chain_interface.block_by_number(block_number)
r = conn.do(o)
block = self.chain_interface.block_from_src(r)
logg.debug('got block typ {}'.format(type(block)))
#super(ThreadedHistorySyncer, self).process(conn, block)
self.process_func(conn, block)
class ThreadPoolHistorySyncer(HistorySyncer):
def __init__(self, conn_factory, thread_limit, backend, chain_interface, pre_callback=None, block_callback=None, post_callback=None, conn_limit=0):
super(ThreadPoolHistorySyncer, self).__init__(backend, chain_interface, pre_callback, block_callback, post_callback)
self.workers = []
self.thread_limit = thread_limit
if conn_limit == 0:
self.conn_limit = self.thread_limit
#self.conn_pool = queue.Queue(conn_limit)
#self.queue = queue.Queue(thread_limit)
#self.quit_queue = queue.Queue(1)
#self.conn_pool = multiprocessing.Queue(conn_limit)
#self.queue = multiprocessing.Queue(thread_limit)
#self.quit_queue = multiprocessing.Queue(1)
#self.lock = threading.Lock()
#self.lock = multiprocessing.Lock()
ThreadPoolTask.process_func = super(ThreadPoolHistorySyncer, self).process
ThreadPoolTask.chain_interface = chain_interface
#for i in range(thread_limit):
#w = threading.Thread(target=self.worker)
# w = multiprocessing.Process(target=self.worker)
# self.workers.append(w)
#for i in range(conn_limit):
# self.conn_pool.put(conn_factory())
self.conn_factory = conn_factory
self.worker_pool = None
def terminate(self):
#self.quit_queue.put(())
super(ThreadPoolHistorySyncer, self).terminate()
# def worker(self):
# while True:
# block_number = None
# try:
# block_number = self.queue.get(timeout=0.01)
# except queue.Empty:
# if self.quit_queue.qsize() > 0:
# #logg.debug('{} received quit'.format(threading.current_thread().getName()))
# logg.debug('{} received quit'.format(multiprocessing.current_process().name))
# return
# continue
# conn = self.conn_pool.get()
# try:
# logg.debug('processing parent {} {}'.format(conn, block_number))
# self.process_parent(conn, block_number)
# except IndexError:
# pass
# except RPCException as e:
# logg.error('RPC failure for block {}, resubmitting to queue: {}'.format(block, e))
# self.queue.put(block_number)
# conn = self.conn_pool.put(conn)
#
def process_single(self, conn, block, tx):
self.filter.apply(conn, block, tx)
def process(self, conn, block):
pass
def get(self, conn):
if not self.running:
raise SyncDone()
block_number = None
tx_index = None
flags = None
((block_number, tx_index), flags) = self.backend.get()
#try:
#logg.debug('putting {}'.format(block.number))
#self.queue.put((conn, block_number,), timeout=0.1)
#self.queue.put(block_number, timeout=0.1)
#except queue.Full:
#logg.debug('queue full, try again')
# return
task = ThreadPoolTask()
conn = self.conn_factory()
self.worker_pool.apply_async(task.poolworker, (block_number, conn,), {}, foobarcb)
target, flags = self.backend.target()
next_block = block_number + 1
if next_block > target:
#self.quit_queue.put(())
self.worker_pool.close()
raise SyncDone()
self.backend.set(self.backend.block_height + 1, 0)
# def get(self, conn):
# try:
# r = super(ThreadedHistorySyncer, self).get(conn)
# return r
# except SyncDone as e:
# self.quit_queue.put(())
# raise e
def loop(self, interval, conn):
self.worker_pool = multiprocessing.Pool(self.thread_limit)
#for w in self.workers:
# w.start()
r = super(ThreadPoolHistorySyncer, self).loop(interval, conn)
#for w in self.workers:
# w.join()
#while True:
# try:
# self.quit_queue.get_nowait()
# except queue.Empty:
# break
time.sleep(1)
self.worker_pool.join()
logg.info('workers done {}'.format(r))

View File

@ -1,81 +0,0 @@
# standard imports
import copy
import logging
import multiprocessing
import os
# external iports
from chainlib.eth.connection import RPCConnection
# local imports
from chainsyncer.driver.history import HistorySyncer
from chainsyncer.driver.base import Syncer
from .threadpool import ThreadPoolTask
logg = logging.getLogger(__name__)
def sync_split(block_offset, block_target, count):
block_count = block_target - block_offset
if block_count < count:
logg.warning('block count is less than thread count, adjusting thread count to {}'.format(block_count))
count = block_count
blocks_per_thread = int(block_count / count)
ranges = []
for i in range(count):
block_target = block_offset + blocks_per_thread
offset = block_offset
target = block_target -1
ranges.append((offset, target,))
block_offset = block_target
return ranges
class ThreadPoolRangeTask:
def __init__(self, backend, sync_range, chain_interface, syncer_factory=HistorySyncer, filters=[]):
backend_start = backend.start()
backend_target = backend.target()
backend_class = backend.__class__
tx_offset = 0
flags = 0
if sync_range[0] == backend_start[0][0]:
tx_offset = backend_start[0][1]
flags = backend_start[1]
self.backend = backend_class.custom(backend.chain_spec, sync_range[1], block_offset=sync_range[0], tx_offset=tx_offset, flags=flags, flags_count=0)
self.syncer = syncer_factory(self.backend, chain_interface)
for fltr in filters:
self.syncer.add_filter(fltr)
def start_loop(self, interval):
conn = RPCConnection.connect(self.backend.chain_spec)
return self.syncer.loop(interval, conn)
class ThreadPoolRangeHistorySyncer:
def __init__(self, thread_count, backend, chain_interface, pre_callback=None, block_callback=None, post_callback=None, runlevel_callback=None):
self.src_backend = backend
self.thread_count = thread_count
self.single_sync_offset = 0
self.runlevel_callback = None
backend_start = backend.start()
backend_target = backend.target()
self.ranges = sync_split(backend_start[0][0], backend_target[0], thread_count)
self.chain_interface = chain_interface
self.filters = []
def add_filter(self, f):
self.filters.append(f)
def loop(self, interval, conn):
self.worker_pool = multiprocessing.Pool(processes=self.thread_count)
for sync_range in self.ranges:
task = ThreadPoolRangeTask(self.src_backend, sync_range, self.chain_interface, filters=self.filters)
t = self.worker_pool.apply_async(task.start_loop, (0.1,))
logg.debug('result of worker {}: {}'.format(t, t.get()))
self.worker_pool.close()
self.worker_pool.join()

View File

@ -1,15 +1,19 @@
# standard imports
import uuid
# local imports
from chainsyncer.error import FilterDone
class SyncSession:
def __init__(self, session_store):
self.session_store = session_store
self.filters = []
self.start = self.session_store.start
self.get = self.session_store.get
self.started = self.session_store.started
self.get = self.session_store.get
self.next = self.session_store.next_item
self.item = None
def register(self, fltr):
@ -18,17 +22,24 @@ class SyncSession:
self.session_store.register(fltr)
self.filters.append(fltr)
def start(self, offset=0, target=-1):
self.session_store.start(offset=offset, target=target)
self.item = self.session_store.next_item()
def filter(self, conn, block, tx):
self.sync_state.connect()
for fltr in filters:
self.session_store.connect()
for fltr in self.filters:
try:
self.sync_start.advance()
self.item.advance()
except FilterDone:
break
interrupt = fltr(conn, block, tx)
try:
self.sync_start.release(interrupt=interrupt)
except FilterDone:
break
self.sync_start.disconnect()
interrupt = fltr.filter(conn, block, tx)
self.item.release(interrupt=interrupt)
try:
self.item.advance()
raise BackendError('filter state inconsitent with filter list')
except FilterDone:
self.item.reset()
self.session_store.disconnect()

View File

@ -5,6 +5,7 @@ import logging
logg = logging.getLogger(__name__)
# TODO: properly clarify interface shared with syncfsstore, move to filter module?
class SyncState:
def __init__(self, state_store):
@ -61,7 +62,8 @@ class SyncState:
def connect(self):
if not self.synced:
for v in self.state_store.all():
self.state_store.sync(v)
k = self.state_store.from_name(v)
self.state_store.sync(k)
self.__syncs[v] = True
self.synced = True
self.connected = True
@ -71,10 +73,14 @@ class SyncState:
self.connected = False
def start(self):
self.state_store.start()
def start(self, offset=0, target=-1):
self.state_store.start(offset=offset, target=target)
self.started = True
def get(self, k):
raise NotImplementedError()
return None
def next_item(self):
return None

View File

@ -15,6 +15,7 @@ from chainsyncer.error import (
FilterDone,
InterruptError,
IncompleteFilterError,
SyncDone,
)
logg = logging.getLogger(__name__)
@ -58,6 +59,15 @@ class SyncFsItem:
if self.filter_state.state(self.state_key) & self.filter_state.from_name('DONE') == 0:
raise IncompleteFilterError('reset attempt on {} when incomplete'.format(self.state_key))
self.filter_state.move(self.state_key, self.filter_state.from_name('RESET'))
v = self.sync_state.get(self.state_key)
block_number = int.from_bytes(v, 'big')
block_number += 1
if self.target >= 0 and block_number > self.target:
raise SyncDone(self.target)
v = block_number.to_bytes(4, 'big')
self.sync_state.replace(self.state_key, v)
def advance(self):
@ -82,6 +92,7 @@ class SyncFsItem:
if self.skip_filter:
raise FilterDone()
if interrupt:
self.filter_state.unset(self.state_key, self.filter_state.from_name('LOCK'))
self.filter_state.set(self.state_key, self.filter_state.from_name('INTERRUPT'))
self.filter_state.set(self.state_key, self.filter_state.from_name('DONE'))
return
@ -106,6 +117,7 @@ class SyncFsStore:
self.first = False
self.target = None
self.items = {}
self.item_keys = []
self.started = False
default_path = os.path.join(base_path, 'default')
@ -183,6 +195,7 @@ class SyncFsStore:
item_target = thresholds[i+1]
o = SyncFsItem(block_number, item_target, self.state, self.filter_state, started=True)
self.items[block_number] = o
self.item_keys.append(block_number)
logg.info('added {}'.format(o))
fp = os.path.join(self.session_path, str(target))
@ -198,8 +211,10 @@ class SyncFsStore:
f.close()
self.target = int(v)
logg.debug('target {}'.format(self.target))
def start(self, offset=0, target=0):
def start(self, offset=0, target=-1):
self.__load(target)
if self.first:
@ -210,10 +225,13 @@ class SyncFsStore:
self.filter_state.put(block_number_str)
o = SyncFsItem(block_number, target, self.state, self.filter_state)
self.items[block_number] = o
self.item_keys.append(block_number)
elif offset > 0:
logg.warning('block number argument {} for start ignored for already initiated sync {}'.format(offset, self.session_id))
self.started = True
self.item_keys.sort()
def stop(self):
if self.target == 0:
@ -224,3 +242,19 @@ class SyncFsStore:
def get(self, k):
return self.items[k]
def next_item(self):
try:
k = self.item_keys.pop(0)
except IndexError:
return None
return self.items[k]
def connect(self):
self.filter_state.connect()
def disconnect(self):
self.filter_state.disconnect()

View File

@ -75,7 +75,7 @@ class MockStore(State):
super(MockStore, self).__init__(bits, check_alias=False)
def start(self):
def start(self, offset=0, target=-1):
pass
@ -89,6 +89,7 @@ class MockFilter:
z = h.digest()
self.z = z
self.brk = brk
self.contents = []
def sum(self):
@ -100,6 +101,7 @@ class MockFilter:
def filter(self, conn, block, tx):
self.contents.append((block.number, tx.index, tx.hash,))
return self.brk

75
tests/test_filter.py Normal file
View File

@ -0,0 +1,75 @@
# standard imports
import unittest
import tempfile
import shutil
import logging
import stat
import os
# local imports
from chainsyncer.store.fs import SyncFsStore
from chainsyncer.session import SyncSession
from chainsyncer.error import (
LockError,
FilterDone,
IncompleteFilterError,
)
from chainsyncer.unittest import (
MockFilter,
MockConn,
MockTx,
MockBlock,
)
logging.basicConfig(level=logging.DEBUG)
logg = logging.getLogger()
class TestFilter(unittest.TestCase):
def setUp(self):
self.path = tempfile.mkdtemp()
self.store = SyncFsStore(self.path)
self.session = SyncSession(self.store)
self.session.start()
self.conn = MockConn()
def tearDown(self):
shutil.rmtree(self.path)
def test_filter_basic(self):
fltr_one = MockFilter('foo')
self.session.register(fltr_one)
fltr_two = MockFilter('bar')
self.session.register(fltr_two)
tx_hash = os.urandom(32).hex()
tx = MockTx(42, tx_hash)
block = MockBlock(13, [tx_hash])
self.session.filter(self.conn, block, tx)
self.assertEqual(len(fltr_one.contents), 1)
self.assertEqual(len(fltr_two.contents), 1)
def test_filter_interrupt(self):
fltr_one = MockFilter('foo', brk=True)
self.session.register(fltr_one)
fltr_two = MockFilter('bar')
self.session.register(fltr_two)
tx_hash = os.urandom(32).hex()
tx = MockTx(42, tx_hash)
block = MockBlock(13, [tx_hash])
self.session.filter(self.conn, block, tx)
self.assertEqual(len(fltr_one.contents), 1)
self.assertEqual(len(fltr_two.contents), 0)
if __name__ == '__main__':
unittest.main()

View File

@ -13,6 +13,7 @@ from chainsyncer.error import (
LockError,
FilterDone,
IncompleteFilterError,
SyncDone,
)
from chainsyncer.unittest import MockFilter
@ -25,6 +26,7 @@ class TestFs(unittest.TestCase):
def setUp(self):
self.path = tempfile.mkdtemp()
def tearDown(self):
shutil.rmtree(self.path)
@ -167,6 +169,23 @@ class TestFs(unittest.TestCase):
o.reset()
def test_sync_process_done(self):
store = SyncFsStore(self.path)
session = SyncSession(store)
fltr_one = MockFilter('foo')
session.register(fltr_one)
session.start(target=0)
o = session.get(0)
o.advance()
o.release()
with self.assertRaises(FilterDone):
o.advance()
with self.assertRaises(SyncDone):
o.reset()
if __name__ == '__main__':
unittest.main()

68
tests/test_session.py Normal file
View File

@ -0,0 +1,68 @@
# standard imports
import unittest
import tempfile
import shutil
import logging
import stat
import os
# local imports
from chainsyncer.store.fs import SyncFsStore
from chainsyncer.session import SyncSession
from chainsyncer.error import (
LockError,
FilterDone,
IncompleteFilterError,
SyncDone,
)
from chainsyncer.unittest import (
MockFilter,
MockConn,
MockTx,
MockBlock,
)
from chainsyncer.driver import SyncDriver
logging.basicConfig(level=logging.DEBUG)
logg = logging.getLogger()
class TestFilter(unittest.TestCase):
def setUp(self):
self.path = tempfile.mkdtemp()
self.store = SyncFsStore(self.path)
self.conn = MockConn()
def tearDown(self):
shutil.rmtree(self.path)
def test_filter_basic(self):
session = SyncSession(self.store)
session.start(target=1)
fltr_one = MockFilter('foo')
session.register(fltr_one)
tx_hash = os.urandom(32).hex()
tx = MockTx(42, tx_hash)
block = MockBlock(0, [tx_hash])
session.filter(self.conn, block, tx)
tx_hash = os.urandom(32).hex()
tx = MockTx(42, tx_hash)
block = MockBlock(1, [tx_hash])
with self.assertRaises(SyncDone):
session.filter(self.conn, block, tx)
self.assertEqual(len(fltr_one.contents), 2)
def test_driver(self):
drv = SyncDriver(self.conn, self.store)
drv.run()
if __name__ == '__main__':
unittest.main()