Complete syncitem filter advance

This commit is contained in:
lash 2022-03-17 22:07:19 +00:00
parent 5968a19042
commit dcf095cc86
Signed by: lash
GPG Key ID: 21D2E7BB88C2A746
8 changed files with 304 additions and 90 deletions

View File

@ -3,6 +3,7 @@ class SyncDone(Exception):
""" """
pass pass
class NoBlockForYou(Exception): class NoBlockForYou(Exception):
"""Exception raised when attempt to retrieve a block from network that does not (yet) exist. """Exception raised when attempt to retrieve a block from network that does not (yet) exist.
""" """
@ -27,6 +28,20 @@ class LockError(Exception):
pass pass
class FilterDone(Exception):
"""Exception raised when all registered filters have been executed
"""
class InterruptError(FilterDone):
"""Exception for interrupting or attempting to use an interrupted sync
"""
class IncompleteFilterError(Exception):
"""Exception raised if filter reset is executed prematurely
"""
#class AbortTx(Exception): #class AbortTx(Exception):
# """ # """
# """ # """

View File

@ -7,24 +7,28 @@ class SyncSession:
def __init__(self, session_store): def __init__(self, session_store):
self.session_store = session_store self.session_store = session_store
self.filters = [] self.filters = []
self.started = False self.start = self.session_store.start
self.get = self.session_store.get
self.started = self.session_store.started
def add_filter(self, fltr): def register(self, fltr):
if self.started: if self.started:
raise RuntimeError('filters cannot be changed after syncer start') raise RuntimeError('filters cannot be changed after syncer start')
self.session_store.register(fltr) self.session_store.register(fltr)
self.filters.append(fltr) self.filters.append(fltr)
def start(self):
self.started = True
def filter(self, conn, block, tx): def filter(self, conn, block, tx):
self.sync_state.connect() self.sync_state.connect()
for fltr in filters: for fltr in filters:
self.sync_start.lock() try:
self.sync_start.unlock() self.sync_start.advance()
except FilterDone:
break
interrupt = fltr(conn, block, tx)
try:
self.sync_start.release(interrupt=interrupt)
except FilterDone:
break
self.sync_start.disconnect() self.sync_start.disconnect()

View File

@ -1,5 +1,8 @@
# standard imports # standard imports
import hashlib import hashlib
import logging
logg = logging.getLogger(__name__)
class SyncState: class SyncState:
@ -11,8 +14,20 @@ class SyncState:
self.__syncs = {} self.__syncs = {}
self.synced = False self.synced = False
self.connected = False self.connected = False
self.state_store.add('INTERRUPT') self.state_store.add('DONE')
self.state_store.add('LOCK') self.state_store.add('LOCK')
self.state_store.add('INTERRUPT')
self.state_store.add('RESET')
self.state = self.state_store.state
self.put = self.state_store.put
self.set = self.state_store.set
self.next = self.state_store.next
self.move = self.state_store.move
self.unset = self.state_store.unset
self.from_name = self.state_store.from_name
self.state_store.sync()
self.all = self.state_store.all
self.started = False
def __verify_sum(self, v): def __verify_sum(self, v):
@ -30,6 +45,9 @@ class SyncState:
self.digest += z self.digest += z
s = fltr.common_name() s = fltr.common_name()
self.state_store.add(s) self.state_store.add(s)
n = self.state_store.from_name(s)
logg.debug('add {} {} {}'.format(s, n, self))
def sum(self): def sum(self):
@ -53,9 +71,10 @@ class SyncState:
self.connected = False self.connected = False
def lock(self): def start(self):
pass self.state_store.start()
self.started = True
def unlock(self): def get(self, k):
pass raise NotImplementedError()

View File

@ -6,30 +6,90 @@ import logging
# external imports # external imports
from shep.store.file import SimpleFileStoreFactory from shep.store.file import SimpleFileStoreFactory
from shep.persist import PersistedState from shep.persist import PersistedState
from shep.error import StateInvalid
# local imports # local imports
from chainsyncer.state import SyncState from chainsyncer.state import SyncState
from chainsyncer.error import (
LockError,
FilterDone,
InterruptError,
IncompleteFilterError,
)
logg = logging.getLogger(__name__) logg = logging.getLogger(__name__)
# NOT thread safe
class SyncFsItem: class SyncFsItem:
def __init__(self, offset, target, sync_state, filter_state, started=False): def __init__(self, offset, target, sync_state, filter_state, started=False, ignore_invalid=False):
self.offset = offset self.offset = offset
self.target = target self.target = target
self.sync_state = sync_state self.sync_state = sync_state
self.filter_state = filter_state self.filter_state = filter_state
s = str(offset) self.state_key = str(offset)
match_state = self.sync_state.NEW match_state = self.sync_state.NEW
if started: if started:
match_state = self.sync_state.SYNC match_state = self.sync_state.SYNC
v = self.sync_state.get(s) v = self.sync_state.get(self.state_key)
self.cursor = int.from_bytes(v, 'big') self.cursor = int.from_bytes(v, 'big')
if self.filter_state.state(self.state_key) & self.filter_state.from_name('LOCK') and not ignore_invalid:
raise LockError(s)
def next(self): self.count = len(self.filter_state.all(pure=True)) - 3
pass self.skip_filter = False
if self.count == 0:
self.skip_filter = True
else:
self.filter_state.move(self.state_key, self.filter_state.from_name('RESET'))
def __check_done(self):
if self.filter_state.state(self.state_key) & self.filter_state.from_name('INTERRUPT') > 0:
raise InterruptError(self.state_key)
if self.filter_state.state(self.state_key) & self.filter_state.from_name('DONE') > 0:
raise FilterDone(self.state_key)
def reset(self):
if self.filter_state.state(self.state_key) & self.filter_state.from_name('LOCK') > 0:
raise LockError('reset attempt on {} when state locked'.format(self.state_key))
if self.filter_state.state(self.state_key) & self.filter_state.from_name('DONE') == 0:
raise IncompleteFilterError('reset attempt on {} when incomplete'.format(self.state_key))
self.filter_state.move(self.state_key, self.filter_state.from_name('RESET'))
def advance(self):
if self.skip_filter:
raise FilterDone()
self.__check_done()
if self.filter_state.state(self.state_key) & self.filter_state.from_name('LOCK') > 0:
raise LockError('advance attempt on {} when state locked'.format(self.state_key))
done = False
try:
self.filter_state.next(self.state_key)
except StateInvalid:
done = True
if done:
self.filter_state.set(self.state_key, self.filter_state.from_name('DONE'))
raise FilterDone()
self.filter_state.set(self.state_key, self.filter_state.from_name('LOCK'))
def release(self, interrupt=False):
if self.skip_filter:
raise FilterDone()
if interrupt:
self.filter_state.set(self.state_key, self.filter_state.from_name('INTERRUPT'))
self.filter_state.set(self.state_key, self.filter_state.from_name('DONE'))
return
state = self.filter_state.state(self.state_key)
if state & self.filter_state.from_name('LOCK') == 0:
raise LockError('release attempt on {} when state unlocked'.format(self.state_key))
self.filter_state.unset(self.state_key, self.filter_state.from_name('LOCK'))
def __str__(self): def __str__(self):
@ -46,6 +106,7 @@ class SyncFsStore:
self.first = False self.first = False
self.target = None self.target = None
self.items = {} self.items = {}
self.started = False
default_path = os.path.join(base_path, 'default') default_path = os.path.join(base_path, 'default')
@ -76,7 +137,7 @@ class SyncFsStore:
base_filter_path = os.path.join(self.session_path, 'filter') base_filter_path = os.path.join(self.session_path, 'filter')
factory = SimpleFileStoreFactory(base_filter_path, binary=True) factory = SimpleFileStoreFactory(base_filter_path, binary=True)
filter_state_backend = PersistedState(factory, 0) filter_state_backend = PersistedState(factory.add, 0, check_alias=False)
self.filter_state = SyncState(filter_state_backend) self.filter_state = SyncState(filter_state_backend)
self.register = self.filter_state.register self.register = self.filter_state.register
@ -144,12 +205,22 @@ class SyncFsStore:
if self.first: if self.first:
block_number = offset block_number = offset
block_number_bytes = block_number.to_bytes(4, 'big') block_number_bytes = block_number.to_bytes(4, 'big')
self.state.put(str(block_number), block_number_bytes) block_number_str = str(block_number)
self.state.put(block_number_str, block_number_bytes)
self.filter_state.put(block_number_str)
o = SyncFsItem(block_number, target, self.state, self.filter_state)
self.items[block_number] = o
elif offset > 0: elif offset > 0:
logg.warning('block number argument {} for start ignored for already initiated sync {}'.format(offset, self.session_id)) logg.warning('block number argument {} for start ignored for already initiated sync {}'.format(offset, self.session_id))
self.started = True
def stop(self): def stop(self):
if self.target == 0: if self.target == 0:
block_number = self.height + 1 block_number = self.height + 1
block_number_bytes = block_number.to_bytes(4, 'big') block_number_bytes = block_number.to_bytes(4, 'big')
self.state.put(str(block_number), block_number_bytes) self.state.put(str(block_number), block_number_bytes)
def get(self, k):
return self.items[k]

View File

@ -0,0 +1 @@
from .base import *

View File

@ -1,12 +1,14 @@
# standard imports # standard imports
import os import os
import logging import logging
import hashlib
# external imports # external imports
from hexathon import add_0x from hexathon import add_0x
from shep.state import State
# local imports # local imports
from chainsyncer.driver.history import HistorySyncer #from chainsyncer.driver.history import HistorySyncer
from chainsyncer.error import NoBlockForYou from chainsyncer.error import NoBlockForYou
logg = logging.getLogger().getChild(__name__) logg = logging.getLogger().getChild(__name__)
@ -67,42 +69,77 @@ class MockBlock:
return MockTx(i, self.txs[i]) return MockTx(i, self.txs[i])
class TestSyncer(HistorySyncer): class MockStore(State):
"""Unittest extension of history syncer driver.
:param backend: Syncer backend def __init__(self, bits=0):
:type backend: chainsyncer.backend.base.Backend implementation super(MockStore, self).__init__(bits, check_alias=False)
:param chain_interface: Chain interface
:type chain_interface: chainlib.interface.ChainInterface implementation
:param tx_counts: List of integer values defining how many mock transactions to generate per block. Mock blocks will be generated for each element in list.
:type tx_counts: list
"""
def __init__(self, backend, chain_interface, tx_counts=[]):
self.tx_counts = tx_counts
super(TestSyncer, self).__init__(backend, chain_interface)
def get(self, conn): def start(self):
"""Implements the block getter of chainsyncer.driver.base.Syncer. pass
:param conn: RPC connection
:type conn: chainlib.connection.RPCConnection
:raises NoBlockForYou: End of mocked block array reached
:rtype: chainsyncer.unittest.base.MockBlock
:returns: Mock block.
"""
(pair, fltr) = self.backend.get()
(target_block, fltr) = self.backend.target()
block_height = pair[0]
if block_height == target_block: class MockFilter:
self.running = False
raise NoBlockForYou()
block_txs = [] def __init__(self, name, brk=False, z=None):
if block_height < len(self.tx_counts): self.name = name
for i in range(self.tx_counts[block_height]): if z == None:
block_txs.append(add_0x(os.urandom(32).hex())) h = hashlib.sha256()
h.update(self.name.encode('utf-8'))
z = h.digest()
self.z = z
self.brk = brk
return MockBlock(block_height, block_txs)
def sum(self):
return self.z
def common_name(self):
return self.name
def filter(self, conn, block, tx):
return self.brk
#class TestSyncer(HistorySyncer):
# """Unittest extension of history syncer driver.
#
# :param backend: Syncer backend
# :type backend: chainsyncer.backend.base.Backend implementation
# :param chain_interface: Chain interface
# :type chain_interface: chainlib.interface.ChainInterface implementation
# :param tx_counts: List of integer values defining how many mock transactions to generate per block. Mock blocks will be generated for each element in list.
# :type tx_counts: list
# """
#
# def __init__(self, backend, chain_interface, tx_counts=[]):
# self.tx_counts = tx_counts
# super(TestSyncer, self).__init__(backend, chain_interface)
#
#
# def get(self, conn):
# """Implements the block getter of chainsyncer.driver.base.Syncer.
#
# :param conn: RPC connection
# :type conn: chainlib.connection.RPCConnection
# :raises NoBlockForYou: End of mocked block array reached
# :rtype: chainsyncer.unittest.base.MockBlock
# :returns: Mock block.
# """
# (pair, fltr) = self.backend.get()
# (target_block, fltr) = self.backend.target()
# block_height = pair[0]
#
# if block_height == target_block:
# self.running = False
# raise NoBlockForYou()
#
# block_txs = []
# if block_height < len(self.tx_counts):
# for i in range(self.tx_counts[block_height]):
# block_txs.append(add_0x(os.urandom(32).hex()))
#
# return MockBlock(block_height, block_txs)

View File

@ -1,50 +1,21 @@
# standard imports # standard imports
import unittest import unittest
import hashlib
import tempfile import tempfile
import shutil import shutil
import logging import logging
# external imports
from shep.state import State
# local imports # local imports
from chainsyncer.session import SyncSession from chainsyncer.session import SyncSession
from chainsyncer.state import SyncState from chainsyncer.state import SyncState
from chainsyncer.store.fs import SyncFsStore from chainsyncer.store.fs import SyncFsStore
from chainsyncer.unittest import (
MockStore,
MockFilter,
)
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
logg = logging.getLogger() logg = logging.getLogger()
class MockStore(State):
def __init__(self, bits=0):
super(MockStore, self).__init__(bits, check_alias=False)
class MockFilter:
def __init__(self, name, brk=False, z=None):
self.name = name
if z == None:
h = hashlib.sha256()
h.update(self.name.encode('utf-8'))
z = h.digest()
self.z = z
self.brk = brk
def sum(self):
return self.z
def common_name(self):
return self.name
def filter(self, conn, block, tx):
return self.brk
class TestSync(unittest.TestCase): class TestSync(unittest.TestCase):
@ -64,7 +35,7 @@ class TestSync(unittest.TestCase):
def test_sum(self): def test_sum(self):
store = MockStore(4) store = MockStore(6)
state = SyncState(store) state = SyncState(store)
b = b'\x2a' * 32 b = b'\x2a' * 32

View File

@ -8,6 +8,13 @@ import os
# local imports # local imports
from chainsyncer.store.fs import SyncFsStore from chainsyncer.store.fs import SyncFsStore
from chainsyncer.session import SyncSession
from chainsyncer.error import (
LockError,
FilterDone,
IncompleteFilterError,
)
from chainsyncer.unittest import MockFilter
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
logg = logging.getLogger() logg = logging.getLogger()
@ -69,8 +76,97 @@ class TestFs(unittest.TestCase):
store = SyncFsStore(self.path) store = SyncFsStore(self.path)
store.start(13) store.start(13)
self.assertTrue(store.first) self.assertTrue(store.first)
# todo not done
def test_sync_process_nofilter(self):
store = SyncFsStore(self.path)
session = SyncSession(store)
session.start()
o = session.get(0)
with self.assertRaises(FilterDone):
o.advance()
def test_sync_process_onefilter(self):
store = SyncFsStore(self.path)
session = SyncSession(store)
fltr_one = MockFilter('foo')
session.register(fltr_one)
session.start()
o = session.get(0)
o.advance()
o.release()
def test_sync_process_outoforder(self):
store = SyncFsStore(self.path)
session = SyncSession(store)
fltr_one = MockFilter('foo')
session.register(fltr_one)
fltr_two = MockFilter('two')
session.register(fltr_two)
session.start()
o = session.get(0)
o.advance()
with self.assertRaises(LockError):
o.advance()
o.release()
with self.assertRaises(LockError):
o.release()
o.advance()
o.release()
def test_sync_process_interrupt(self):
store = SyncFsStore(self.path)
session = SyncSession(store)
fltr_one = MockFilter('foo')
session.register(fltr_one)
fltr_two = MockFilter('bar')
session.register(fltr_two)
session.start()
o = session.get(0)
o.advance()
o.release(interrupt=True)
with self.assertRaises(FilterDone):
o.advance()
def test_sync_process_reset(self):
store = SyncFsStore(self.path)
session = SyncSession(store)
fltr_one = MockFilter('foo')
session.register(fltr_one)
fltr_two = MockFilter('bar')
session.register(fltr_two)
session.start()
o = session.get(0)
o.advance()
with self.assertRaises(LockError):
o.reset()
o.release()
with self.assertRaises(IncompleteFilterError):
o.reset()
o.advance()
o.release()
with self.assertRaises(FilterDone):
o.advance()
o.reset()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()