Implement filter integrity test in sql backend

This commit is contained in:
nolash 2021-04-15 15:06:07 +02:00
parent d1077bf87a
commit 987a18fd6b
Signed by: lash
GPG Key ID: 21D2E7BB88C2A746
5 changed files with 58 additions and 34 deletions

View File

@ -44,9 +44,9 @@ class MemBackend:
def complete_filter(self, n): def complete_filter(self, n):
v = 1 << (n-1) v = 1 << n
self.flags |= v self.flags |= v
logg.debug('set filter {} {}'.format(self.filter_names[n-1], v)) logg.debug('set filter {} {}'.format(self.filter_names[n], v))
def reset_filter(self): def reset_filter(self):

View File

@ -2,7 +2,7 @@
import logging import logging
import uuid import uuid
# third-party imports # imports
from chainlib.chain import ChainSpec from chainlib.chain import ChainSpec
# local imports # local imports
@ -56,6 +56,9 @@ class SyncerBackend:
def disconnect(self): def disconnect(self):
"""Commits state of sync to backend. """Commits state of sync to backend.
""" """
if self.db_session == None:
return
if self.db_object_filter != None: if self.db_object_filter != None:
self.db_session.add(self.db_object_filter) self.db_session.add(self.db_object_filter)
self.db_session.add(self.db_object) self.db_session.add(self.db_object)
@ -97,7 +100,6 @@ class SyncerBackend:
""" """
self.connect() self.connect()
pair = self.db_object.set(block_height, tx_height) pair = self.db_object.set(block_height, tx_height)
self.db_object_filter.clear()
(filter_state, count, digest)= self.db_object_filter.cursor() (filter_state, count, digest)= self.db_object_filter.cursor()
self.disconnect() self.disconnect()
return (pair, filter_state,) return (pair, filter_state,)
@ -294,5 +296,11 @@ class SyncerBackend:
self.disconnect() self.disconnect()
def reset_filter(self):
self.connect()
self.db_object_filter.clear()
self.disconnect()
def __str__(self): def __str__(self):
return "syncerbackend chain {} start {} target {}".format(self.chain(), self.start(), self.target()) return "syncerbackend chain {} start {} target {}".format(self.chain(), self.start(), self.target())

View File

@ -36,16 +36,15 @@ class SyncFilter:
i = 0 i = 0
(pair, flags) = self.backend.get() (pair, flags) = self.backend.get()
for f in self.filters: for f in self.filters:
if flags & (1 << i) == 0:
logg.debug('applying filter {} {}'.format(str(f), flags))
f.filter(conn, block, tx, session)
self.backend.complete_filter(i)
else:
logg.debug('skipping previously applied filter {} {}'.format(str(f), flags))
i += 1 i += 1
if flags & (1 << (i - 1)) > 0:
logg.debug('skipping previously applied filter {}'.format(str(f)))
continue
logg.debug('applying filter {}'.format(str(f)))
f.filter(conn, block, tx, session)
self.backend.complete_filter(i)
if session != None:
self.backend.disconnect()
self.backend.disconnect()
class NoopFilter: class NoopFilter:

View File

@ -39,19 +39,21 @@ class TestSyncer(HistorySyncer):
def get(self, conn): def get(self, conn):
if self.backend.block_height == self.backend.target_block: (pair, fltr) = self.backend.get()
(target_block, fltr) = self.backend.target()
block_height = pair[0]
if block_height == target_block:
self.running = False self.running = False
raise NoBlockForYou() raise NoBlockForYou()
return [] return []
block_txs = [] block_txs = []
if self.backend.block_height < len(self.tx_counts): if block_height < len(self.tx_counts):
for i in range(self.tx_counts[self.backend.block_height]): for i in range(self.tx_counts[block_height]):
block_txs.append(add_0x(os.urandom(32).hex())) block_txs.append(add_0x(os.urandom(32).hex()))
logg.debug('get tx height {}'.format(self.backend.tx_height)) return MockBlock(block_height, block_txs)
return MockBlock(self.backend.block_height, block_txs)
# TODO: implement mock conn instead, and use HeadSyncer.process # TODO: implement mock conn instead, and use HeadSyncer.process
@ -61,4 +63,4 @@ class TestSyncer(HistorySyncer):
self.process_single(conn, block, block.tx(i)) self.process_single(conn, block, block.tx(i))
self.backend.reset_filter() self.backend.reset_filter()
i += 1 i += 1
self.backend.set(self.backend.block_height + 1, 0) self.backend.set(block.number + 1, 0)

View File

@ -8,6 +8,7 @@ from chainlib.chain import ChainSpec
# local imports # local imports
from chainsyncer.backend.memory import MemBackend from chainsyncer.backend.memory import MemBackend
from chainsyncer.backend.sql import SyncerBackend
# test imports # test imports
from tests.base import TestBase from tests.base import TestBase
@ -54,35 +55,49 @@ class CountFilter:
return '{} {}'.format(self.__class__.__name__, self.name) return '{} {}'.format(self.__class__.__name__, self.name)
class TestInterrupt(unittest.TestCase):
class TestInterrupt(TestBase):
def setUp(self): def setUp(self):
self.chain_spec = ChainSpec('foo', 'bar', 42, 'baz') super(TestInterrupt, self).setUp()
self.backend = MemBackend(self.chain_spec, None, target_block=4) self.filters = [
self.syncer = TestSyncer(self.backend, [4, 2, 3])
def test_filter_interrupt(self):
fltrs = [
CountFilter('foo'), CountFilter('foo'),
CountFilter('bar'), CountFilter('bar'),
NaughtyCountExceptionFilter('xyzzy', 3), NaughtyCountExceptionFilter('xyzzy', 3),
CountFilter('baz'), CountFilter('baz'),
] ]
self.backend = None
for fltr in fltrs: def assert_filter_interrupt(self):
self.syncer.add_filter(fltr)
syncer = TestSyncer(self.backend, [4, 2, 3])
for fltr in self.filters:
syncer.add_filter(fltr)
try: try:
self.syncer.loop(0.1, None) syncer.loop(0.1, None)
except RuntimeError: except RuntimeError:
logg.info('caught croak') logg.info('caught croak')
pass pass
self.syncer.loop(0.1, None) (pair, fltr) = self.backend.get()
self.assertGreater(fltr, 0)
syncer.loop(0.1, None)
for fltr in fltrs: for fltr in self.filters:
logg.debug('{} {}'.format(str(fltr), fltr.c)) logg.debug('{} {}'.format(str(fltr), fltr.c))
#self.assertEqual(fltr.c, 11) self.assertEqual(fltr.c, 9)
def test_filter_interrupt_memory(self):
self.backend = MemBackend(self.chain_spec, None, target_block=4)
self.assert_filter_interrupt()
def test_filter_interrupt_sql(self):
self.backend = SyncerBackend.initial(self.chain_spec, 4)
self.assert_filter_interrupt()
if __name__ == '__main__': if __name__ == '__main__':