New filter interface, add state step stubs

This commit is contained in:
lash 2022-03-17 10:09:12 +00:00
parent 2ba87de195
commit af47e31cc8
Signed by: lash
GPG Key ID: 21D2E7BB88C2A746
4 changed files with 90 additions and 121 deletions

View File

@ -1,96 +1,12 @@
# standard imports
import logging
# local imports
from .error import BackendError
logg = logging.getLogger(__name__)
class SyncFilter: class SyncFilter:
"""Manages the collection of filters on behalf of a specific backend.
A filter is a pluggable piece of code to execute for every transaction retrieved by the syncer. Filters are executed in the sequence they were added to the instance. def common_name(self):
raise NotImplementedError()
:param backend: Syncer backend to apply filter state changes to
:type backend: chainsyncer.backend.base.Backend implementation
"""
def __init__(self, backend):
self.filters = []
self.backend = backend
def add(self, fltr): def sum(self):
"""Add a filter instance. raise NotImplementedError()
:param fltr: Filter instance.
:type fltr: Object instance implementing signature as in chainsyncer.filter.NoopFilter.filter
:raises ValueError: Object instance is incorrect implementation
"""
if getattr(fltr, 'filter') == None:
raise ValueError('filter object must implement have method filter')
logg.debug('added filter "{}"'.format(str(fltr)))
self.filters.append(fltr)
def __apply_one(self, fltr, idx, conn, block, tx, session): def filter(self, conn, block, tx):
self.backend.begin_filter(idx) raise NotImplementedError()
fltr.filter(conn, block, tx, session)
self.backend.complete_filter(idx)
def apply(self, conn, block, tx):
"""Apply all registered filters on the given transaction.
:param conn: RPC Connection, will be passed to the filter method
:type conn: chainlib.connection.RPCConnection
:param block: Block object
:type block: chainlib.block.Block
:param tx: Transaction object
:type tx: chainlib.tx.Tx
:raises BackendError: Backend connection failed
"""
session = None
try:
session = self.backend.connect()
except TimeoutError as e:
self.backend.disconnect()
raise BackendError('database connection fail: {}'.format(e))
i = 0
(pair, flags) = self.backend.get()
for f in self.filters:
if not self.backend.check_filter(i, flags):
logg.debug('applying filter {} {}'.format(str(f), flags))
self.__apply_one(f, i, conn, block, tx, session)
else:
logg.debug('skipping previously applied filter {} {}'.format(str(f), flags))
i += 1
self.backend.disconnect()
class NoopFilter:
"""A noop implemenation of a sync filter.
Logs the filter inputs at debug log level.
"""
def filter(self, conn, block, tx, db_session=None):
"""Filter method implementation:
:param conn: RPC Connection, will be passed to the filter method
:type conn: chainlib.connection.RPCConnection
:param block: Block object
:type block: chainlib.block.Block
:param tx: Transaction object
:type tx: chainlib.tx.Tx
:param db_session: Backend session object
:type db_session: varies
"""
logg.debug('noop filter :received\n{} {} {}'.format(block, tx, id(db_session)))
def __str__(self):
return 'noopfilter'

View File

@ -4,20 +4,33 @@ import uuid
class SyncSession: class SyncSession:
def __init__(self, state_store, session_id=None, is_default=False): def __init__(self, session_store, sync_state, session_id=None, is_default=False):
self.session_store = session_store
if session_id == None: if session_id == None:
session_id = str(uuid.uuid4()) session_id = str(uuid.uuid4())
is_default = True is_default = True
self.session_id = session_id self.session_id = session_id
self.is_default = is_default self.is_default = is_default
self.state_store = state_store self.sync_state = sync_state
self.filters = [] self.filters = []
self.started = False
def add_filter(self, fltr): def add_filter(self, fltr):
self.state_store.register(fltr) if self.started:
raise RuntimeError('filters cannot be changed after syncer start')
self.sync_state.register(fltr)
self.filters.append(fltr) self.filters.append(fltr)
def start(self): def start(self):
self.state_store.start() self.started = True
def filter(self, conn, block, tx):
self.sync_state.connect()
for fltr in filters:
self.sync_start.lock()
self.sync_start.unlock()
self.sync_start.disconnect()

View File

@ -1,13 +1,18 @@
# standard imports # standard imports
import hashlib import hashlib
class SyncState: class SyncState:
def __init__(self, state_store): def __init__(self, state_store):
self.store = state_store self.state_store = state_store
self.digest = b'\x00' * 32 self.digest = b'\x00' * 32
self.summed = False self.summed = False
self.synced = {} self.__syncs = {}
self.synced = False
self.connected = False
self.state_store.add('INTERRUPT')
self.state_store.add('LOCK')
def __verify_sum(self, v): def __verify_sum(self, v):
@ -24,8 +29,7 @@ class SyncState:
self.__verify_sum(z) self.__verify_sum(z)
self.digest += z self.digest += z
s = fltr.common_name() s = fltr.common_name()
self.store.add('i_' + s) self.state_store.add(s)
self.store.add('o_' + s)
def sum(self): def sum(self):
@ -36,7 +40,22 @@ class SyncState:
return self.digest return self.digest
def start(self): def connect(self):
for v in self.store.all(): if not self.synced:
self.store.sync(v) for v in self.state_store.all():
self.synced[v] = True self.state_store.sync(v)
self.__syncs[v] = True
self.synced = True
self.connected = True
def disconnect(self):
self.connected = False
def lock(self):
pass
def unlock(self):
pass

View File

@ -11,15 +11,20 @@ from chainsyncer.session import SyncSession
class MockStore(State): class MockStore(State):
def __init__(self, bits): def __init__(self, bits=0):
super(MockStore, self).__init__(bits, check_alias=False) super(MockStore, self).__init__(bits, check_alias=False)
class MockFilter: class MockFilter:
def __init__(self, z, name): def __init__(self, name, brk=False, z=None):
self.z = z
self.name = name self.name = name
if z == None:
h = hashlib.sha256()
h.update(self.name.encode('utf-8'))
z = h.digest()
self.z = z
self.brk = brk
def sum(self): def sum(self):
@ -30,43 +35,59 @@ class MockFilter:
return self.name return self.name
def filter(self, conn, block, tx):
return self.brk
class TestSync(unittest.TestCase): class TestSync(unittest.TestCase):
def setUp(self):
self.store = MockStore(6)
self.state = SyncState(self.store)
def test_basic(self): def test_basic(self):
session = SyncSession(self.state) store = MockStore(6)
state = SyncState(store)
session = SyncSession(None, state)
self.assertTrue(session.is_default) self.assertTrue(session.is_default)
session = SyncSession(self.state, session_id='foo') session = SyncSession(None, state, session_id='foo')
self.assertFalse(session.is_default) self.assertFalse(session.is_default)
def test_sum(self): def test_sum(self):
store = MockStore(4)
state = SyncState(store)
b = b'\x2a' * 32 b = b'\x2a' * 32
fltr = MockFilter(b, name='foo') fltr = MockFilter('foo', z=b)
self.state.register(fltr) state.register(fltr)
b = b'\x0d' * 31 b = b'\x0d' * 31
fltr = MockFilter(b, name='bar') fltr = MockFilter('bar', z=b)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
self.state.register(fltr) state.register(fltr)
b = b'\x0d' * 32 b = b'\x0d' * 32
fltr = MockFilter(b, name='bar') fltr = MockFilter('bar', z=b)
self.state.register(fltr) state.register(fltr)
v = self.state.sum() v = state.sum()
self.assertEqual(v.hex(), 'a24abf9fec112b4e0210ae874b4a371f8657b1ee0d923ad6d974aef90bad8550') self.assertEqual(v.hex(), 'a24abf9fec112b4e0210ae874b4a371f8657b1ee0d923ad6d974aef90bad8550')
def test_session_start(self): def test_session_start(self):
session = SyncSession(self.state) store = MockStore(6)
state = SyncState(store)
session = SyncSession(None, state)
session.start() session.start()
def test_state_dynamic(self):
store = MockStore()
state = SyncState(store)
b = b'\x0d' * 32
fltr = MockFilter(name='foo', z=b)
state.register(fltr)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()