107 lines
2.4 KiB
Python
107 lines
2.4 KiB
Python
# standard imports
|
||
import logging
|
||
import unittest
|
||
import os
|
||
|
||
# external imports
|
||
from chainlib.chain import ChainSpec
|
||
|
||
# local imports
|
||
from chainsyncer.backend.memory import MemBackend
|
||
from chainsyncer.backend.sql import SyncerBackend
|
||
|
||
# test imports
|
||
from tests.base import TestBase
|
||
from chainsyncer.unittest.base import (
|
||
MockBlock,
|
||
MockConn,
|
||
TestSyncer,
|
||
)
|
||
|
||
# Configure root logging at DEBUG level when this module is imported
# (logging.basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(level=logging.DEBUG)
# Module-scoped logger; records propagate to the root handlers configured above.
logg = logging.getLogger(__name__)
|
||
|
||
|
||
class NaughtyCountExceptionFilter:
    """Counting filter that raises ``RuntimeError`` exactly once.

    The filter keeps a call tally in ``c``. When ``c`` equals ``croak`` at the
    start of a call, the filter disarms itself (``croak`` becomes -1, a value
    the tally, which starts at 0 and only grows, can never reach) and raises
    instead of counting that call.

    :param name: Label used in the ``str()`` representation.
    :param croak_on: Tally value at which the single exception is raised.
    """

    def __init__(self, name, croak_on):
        self.name = name
        self.croak = croak_on
        self.c = 0

    def filter(self, conn, block, tx, db_session=None):
        """Count one invocation, or raise once when the tally hits ``croak``.

        ``conn``, ``block``, ``tx`` and ``db_session`` are accepted to match
        the filter call signature but are not used.

        :raises RuntimeError: on the call where ``c == croak`` (only once).
        """
        if self.c != self.croak:
            self.c += 1
            return
        # Disarm before raising so the exception fires only once.
        self.croak = -1
        raise RuntimeError('foo')

    def __str__(self):
        return '{} {}'.format(type(self).__name__, self.name)
|
||
|
||
|
||
class CountFilter:
    """Filter that only tallies how many times it has been called.

    :param name: Label used in the ``str()`` representation.
    """

    def __init__(self, name):
        self.name = name
        self.c = 0  # number of filter() invocations so far

    def filter(self, conn, block, tx, db_session=None):
        """Record one invocation; every argument is ignored."""
        self.c = self.c + 1

    def __str__(self):
        label = type(self).__name__
        return '{} {}'.format(label, self.name)
|
||
|
||
|
||
|
||
class TestInterrupt(TestBase):
    """Check that a syncer can be re-run after a filter raises mid-sync.

    Four filters are registered; the third raises ``RuntimeError`` on its
    fourth invocation (``croak_on=3``). After the interrupted run the syncer
    is looped again, and every filter's call tally must then equal the sum of
    the syncer vector.
    """

    def setUp(self):
        super(TestInterrupt, self).setUp()
        # The third filter raises once, interrupting the first syncer loop.
        self.filters = [
            CountFilter('foo'),
            CountFilter('bar'),
            NaughtyCountExceptionFilter('xyzzy', croak_on=3),
            CountFilter('baz'),
        ]
        # Each test assigns its own backend before running the shared check.
        self.backend = None
        self.conn = MockConn()

    def assert_filter_interrupt(self, vector=None):
        """Run an interrupted sync, resume it, and check the filter tallies.

        :param vector: Values passed to ``TestSyncer``; defaults to
            ``[4, 3, 2]``. Presumably the transaction count of each block
            (TODO confirm against ``TestSyncer``). The expected per-filter
            call count is taken to be their sum.
        :type vector: iterable of int, optional
        """
        vector = [4, 3, 2] if vector is None else list(vector)
        # Derived rather than hard-coded so it stays in sync with ``vector``.
        expected_calls = sum(vector)

        syncer = TestSyncer(self.backend, vector)
        for fltr in self.filters:
            syncer.add_filter(fltr)

        # First pass: the naughty filter is expected to abort this loop.
        try:
            syncer.loop(0.1, self.conn)
        except RuntimeError:
            logg.info('caught croak')

        # The second element of backend.get() must be non-zero after the
        # interruption (presumably the backend's record of which filters have
        # already run for the current item -- confirm).
        (_, filter_state) = self.backend.get()
        self.assertGreater(filter_state, 0)

        # Second pass: resume and finish the sync.
        syncer.loop(0.1, self.conn)

        for fltr in self.filters:
            logg.debug('%s %s', fltr, fltr.c)
            self.assertEqual(fltr.c, expected_calls)

    def test_filter_interrupt_memory(self):
        """Interrupt/resume check against the in-memory backend."""
        self.backend = MemBackend(self.chain_spec, None, target_block=4)
        self.assert_filter_interrupt()

    def test_filter_interrupt_sql(self):
        """Interrupt/resume check against the SQL backend."""
        self.backend = SyncerBackend.initial(self.chain_spec, 4)
        self.assert_filter_interrupt()
|
||
|
||
|
||
if __name__ == "__main__":
    # Executing the file directly runs the test cases defined in this module.
    unittest.main(module="__main__")
|