Add interrupt test base
This commit is contained in:
		
							parent
							
								
									c738563d89
								
							
						
					
					
						commit
						908f762cd0
					
				| @ -14,6 +14,8 @@ class MemBackend: | ||||
|         self.flags = 0 | ||||
|         self.target_block = target_block | ||||
|         self.db_session = None | ||||
|         self.filter_names = [] | ||||
|         self.filter_values = [] | ||||
| 
 | ||||
| 
 | ||||
|     def connect(self): | ||||
| @ -28,6 +30,8 @@ class MemBackend: | ||||
|         logg.debug('stateless backend received {} {}'.format(block_height, tx_height)) | ||||
|         self.block_height = block_height | ||||
|         self.tx_height = tx_height | ||||
|         for i in range(len(self.filter_values)): | ||||
|             self.filter_values[i] = False | ||||
| 
 | ||||
| 
 | ||||
|     def get(self): | ||||
| @ -39,11 +43,13 @@ class MemBackend: | ||||
| 
 | ||||
| 
 | ||||
|     def register_filter(self, name): | ||||
|         pass | ||||
|         self.filter_names.append(name) | ||||
|         self.filter_values.append(False) | ||||
| 
 | ||||
| 
 | ||||
|     def complete_filter(self, n): | ||||
|         pass | ||||
|         self.filter_values[n-1] = True | ||||
|         logg.debug('set filter {}'.format(self.filter_names[n-1])) | ||||
| 
 | ||||
| 
 | ||||
|     def __str__(self): | ||||
|  | ||||
| @ -72,6 +72,11 @@ class Syncer: | ||||
|         self.backend.register_filter(str(f)) | ||||
| 
 | ||||
| 
 | ||||
|     def process_single(self, conn, block, tx, block_height, tx_index): | ||||
|         self.backend.set(block_height, tx_index) | ||||
|         self.filter.apply(conn, block, tx) | ||||
| 
 | ||||
| 
 | ||||
| class BlockPollSyncer(Syncer): | ||||
| 
 | ||||
|     def __init__(self, backend, pre_callback=None, block_callback=None, post_callback=None): | ||||
| @ -120,14 +125,16 @@ class HeadSyncer(BlockPollSyncer): | ||||
|         while True: | ||||
|             try: | ||||
|                 tx = block.tx(i) | ||||
|                 rcpt = conn.do(receipt(tx.hash)) | ||||
|                 tx.apply_receipt(rcpt) | ||||
|                 self.backend.set(block.number, i) | ||||
|                 self.filter.apply(conn, block, tx) | ||||
|             except IndexError as e: | ||||
|                 logg.debug('index error syncer rcpt get {}'.format(e)) | ||||
|                 self.backend.set(block.number + 1, 0) | ||||
|                 break | ||||
| 
 | ||||
|             rcpt = conn.do(receipt(tx.hash)) | ||||
|             tx.apply_receipt(rcpt) | ||||
|      | ||||
|             self.process_single(conn, block, tx, block.number, i) | ||||
|                          | ||||
|             i += 1 | ||||
|          | ||||
| 
 | ||||
|  | ||||
							
								
								
									
										2
									
								
								sql_requirements.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								sql_requirements.txt
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,2 @@ | ||||
| psycopg2==2.8.6 | ||||
| SQLAlchemy==1.3.20 | ||||
							
								
								
									
										106
									
								
								tests/test_interrupt.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										106
									
								
								tests/test_interrupt.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,106 @@ | ||||
| # standard imports | ||||
| import logging | ||||
| import unittest | ||||
| import os | ||||
| 
 | ||||
| # external imports | ||||
| from chainlib.chain import ChainSpec | ||||
| from hexathon import add_0x | ||||
| 
 | ||||
| # local imports | ||||
| from chainsyncer.backend.memory import MemBackend | ||||
| from chainsyncer.driver import HeadSyncer | ||||
| from chainsyncer.error import NoBlockForYou | ||||
| 
 | ||||
| # test imports | ||||
| from tests.base import TestBase | ||||
| 
 | ||||
| logging.basicConfig(level=logging.DEBUG) | ||||
| logg = logging.getLogger() | ||||
| 
 | ||||
| 
 | ||||
| class TestSyncer(HeadSyncer): | ||||
| 
 | ||||
| 
 | ||||
|     def __init__(self, backend, tx_counts=[]): | ||||
|         self.tx_counts = tx_counts | ||||
|         super(TestSyncer, self).__init__(backend) | ||||
| 
 | ||||
| 
 | ||||
|     def get(self, conn): | ||||
|         if self.backend.block_height == self.backend.target_block: | ||||
|             raise NoBlockForYou() | ||||
|         if self.backend.block_height > len(self.tx_counts): | ||||
|             return [] | ||||
| 
 | ||||
|         block_txs = [] | ||||
|         for i in range(self.tx_counts[self.backend.block_height]): | ||||
|             block_txs.append(add_0x(os.urandom(32).hex())) | ||||
|        | ||||
|         return block_txs | ||||
| 
 | ||||
| 
 | ||||
|     def process(self, conn, block): | ||||
|         i = 0 | ||||
|         for tx in block: | ||||
|             self.process_single(conn, block, tx, self.backend.block_height, i) | ||||
|             i += 1 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| class NaughtyCountExceptionFilter: | ||||
| 
 | ||||
|     def __init__(self, name, croak_on): | ||||
|         self.c = 0 | ||||
|         self.croak = croak_on | ||||
|         self.name = name | ||||
| 
 | ||||
| 
 | ||||
|     def filter(self, conn, block, tx, db_session=None): | ||||
|         self.c += 1 | ||||
|         if self.c == self.croak: | ||||
|             raise RuntimeError('foo') | ||||
| 
 | ||||
| 
 | ||||
|     def __str__(self): | ||||
|         return '{} {}'.format(self.__class__.__name__, self.name) | ||||
| 
 | ||||
| 
 | ||||
| class CountFilter: | ||||
| 
 | ||||
|     def __init__(self, name): | ||||
|         self.c = 0 | ||||
|         self.name = name | ||||
| 
 | ||||
| 
 | ||||
|     def filter(self, conn, block, tx, db_session=None): | ||||
|         self.c += 1 | ||||
| 
 | ||||
| 
 | ||||
|     def __str__(self): | ||||
|         return '{} {}'.format(self.__class__.__name__, self.name) | ||||
| 
 | ||||
| 
 | ||||
| class TestInterrupt(unittest.TestCase): | ||||
| 
 | ||||
|     def setUp(self): | ||||
|         self.chain_spec = ChainSpec('foo', 'bar', 42, 'baz') | ||||
|         self.backend = MemBackend(self.chain_spec, None, target_block=2) | ||||
|         self.syncer = TestSyncer(self.backend, [4, 2, 3]) | ||||
| 
 | ||||
|     def test_filter_interrupt(self): | ||||
|         | ||||
|         fltrs = [ | ||||
|             CountFilter('foo'), | ||||
|             CountFilter('bar'), | ||||
|             NaughtyCountExceptionFilter('xyzzy', 2), | ||||
|             CountFilter('baz'), | ||||
|                 ] | ||||
| 
 | ||||
|         for fltr in fltrs: | ||||
|             self.syncer.add_filter(fltr) | ||||
| 
 | ||||
|         self.syncer.loop(0.1, None) | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|     unittest.main() | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user