Implement filter integrity test in sql backend

This commit is contained in:
nolash 2021-04-15 15:06:07 +02:00
parent d1077bf87a
commit 987a18fd6b
Signed by: lash
GPG Key ID: 21D2E7BB88C2A746
5 changed files with 58 additions and 34 deletions

View File

@ -44,9 +44,9 @@ class MemBackend:
def complete_filter(self, n):
v = 1 << (n-1)
v = 1 << n
self.flags |= v
logg.debug('set filter {} {}'.format(self.filter_names[n-1], v))
logg.debug('set filter {} {}'.format(self.filter_names[n], v))
def reset_filter(self):

View File

@ -2,7 +2,7 @@
import logging
import uuid
# third-party imports
# imports
from chainlib.chain import ChainSpec
# local imports
@ -56,6 +56,9 @@ class SyncerBackend:
def disconnect(self):
"""Commits state of sync to backend.
"""
if self.db_session == None:
return
if self.db_object_filter != None:
self.db_session.add(self.db_object_filter)
self.db_session.add(self.db_object)
@ -97,7 +100,6 @@ class SyncerBackend:
"""
self.connect()
pair = self.db_object.set(block_height, tx_height)
self.db_object_filter.clear()
(filter_state, count, digest)= self.db_object_filter.cursor()
self.disconnect()
return (pair, filter_state,)
@ -294,5 +296,11 @@ class SyncerBackend:
self.disconnect()
def reset_filter(self):
self.connect()
self.db_object_filter.clear()
self.disconnect()
def __str__(self):
return "syncerbackend chain {} start {} target {}".format(self.chain(), self.start(), self.target())

View File

@ -36,16 +36,15 @@ class SyncFilter:
i = 0
(pair, flags) = self.backend.get()
for f in self.filters:
if flags & (1 << i) == 0:
logg.debug('applying filter {} {}'.format(str(f), flags))
f.filter(conn, block, tx, session)
self.backend.complete_filter(i)
else:
logg.debug('skipping previously applied filter {} {}'.format(str(f), flags))
i += 1
if flags & (1 << (i - 1)) > 0:
logg.debug('skipping previously applied filter {}'.format(str(f)))
continue
logg.debug('applying filter {}'.format(str(f)))
f.filter(conn, block, tx, session)
self.backend.complete_filter(i)
if session != None:
self.backend.disconnect()
self.backend.disconnect()
class NoopFilter:

View File

@ -39,19 +39,21 @@ class TestSyncer(HistorySyncer):
def get(self, conn):
if self.backend.block_height == self.backend.target_block:
(pair, fltr) = self.backend.get()
(target_block, fltr) = self.backend.target()
block_height = pair[0]
if block_height == target_block:
self.running = False
raise NoBlockForYou()
return []
block_txs = []
if self.backend.block_height < len(self.tx_counts):
for i in range(self.tx_counts[self.backend.block_height]):
if block_height < len(self.tx_counts):
for i in range(self.tx_counts[block_height]):
block_txs.append(add_0x(os.urandom(32).hex()))
logg.debug('get tx height {}'.format(self.backend.tx_height))
return MockBlock(self.backend.block_height, block_txs)
return MockBlock(block_height, block_txs)
# TODO: implement mock conn instead, and use HeadSyncer.process
@ -61,4 +63,4 @@ class TestSyncer(HistorySyncer):
self.process_single(conn, block, block.tx(i))
self.backend.reset_filter()
i += 1
self.backend.set(self.backend.block_height + 1, 0)
self.backend.set(block.number + 1, 0)

View File

@ -8,6 +8,7 @@ from chainlib.chain import ChainSpec
# local imports
from chainsyncer.backend.memory import MemBackend
from chainsyncer.backend.sql import SyncerBackend
# test imports
from tests.base import TestBase
@ -54,35 +55,49 @@ class CountFilter:
return '{} {}'.format(self.__class__.__name__, self.name)
class TestInterrupt(unittest.TestCase):
class TestInterrupt(TestBase):
def setUp(self):
self.chain_spec = ChainSpec('foo', 'bar', 42, 'baz')
self.backend = MemBackend(self.chain_spec, None, target_block=4)
self.syncer = TestSyncer(self.backend, [4, 2, 3])
def test_filter_interrupt(self):
fltrs = [
super(TestInterrupt, self).setUp()
self.filters = [
CountFilter('foo'),
CountFilter('bar'),
NaughtyCountExceptionFilter('xyzzy', 3),
CountFilter('baz'),
]
]
self.backend = None
for fltr in fltrs:
self.syncer.add_filter(fltr)
def assert_filter_interrupt(self):
syncer = TestSyncer(self.backend, [4, 2, 3])
for fltr in self.filters:
syncer.add_filter(fltr)
try:
self.syncer.loop(0.1, None)
syncer.loop(0.1, None)
except RuntimeError:
logg.info('caught croak')
pass
self.syncer.loop(0.1, None)
(pair, fltr) = self.backend.get()
self.assertGreater(fltr, 0)
syncer.loop(0.1, None)
for fltr in fltrs:
for fltr in self.filters:
logg.debug('{} {}'.format(str(fltr), fltr.c))
#self.assertEqual(fltr.c, 11)
self.assertEqual(fltr.c, 9)
def test_filter_interrupt_memory(self):
self.backend = MemBackend(self.chain_spec, None, target_block=4)
self.assert_filter_interrupt()
def test_filter_interrupt_sql(self):
self.backend = SyncerBackend.initial(self.chain_spec, 4)
self.assert_filter_interrupt()
if __name__ == '__main__':