Add filter tests

This commit is contained in:
lash 2022-03-14 21:17:31 +00:00
parent 1ae96670c7
commit 25ac641476
Signed by: lash
GPG Key ID: 21D2E7BB88C2A746
4 changed files with 68 additions and 4 deletions

View File

@ -1,3 +1,6 @@
# standard imports
import logging
# external imports # external imports
from chainlib.error import RPCException from chainlib.error import RPCException
from chainqueue import ( from chainqueue import (
@ -11,6 +14,8 @@ from chainqueue.store.fs import (
) )
from shep.store.file import SimpleFileStoreFactory from shep.store.file import SimpleFileStoreFactory
logg = logging.getLogger(__name__)
class ChaindFsAdapter: class ChaindFsAdapter:
@ -26,8 +31,8 @@ class ChaindFsAdapter:
def put(self, signed_tx): def put(self, signed_tx):
cache_tx = self.deserialize(signed_tx) cache_tx = self.deserialize(signed_tx)
self.store.put(cache_tx.tx_hash, signed_tx) self.store.put(cache_tx.hash, signed_tx)
return cache_tx.tx_hash return cache_tx.hash
def get(self, tx_hash): def get(self, tx_hash):
@ -47,6 +52,14 @@ class ChaindFsAdapter:
return self.store.deferred() return self.store.deferred()
def succeed(self, block, tx):
return self.store.final(tx.hash, block, tx, error=False)
def fail(self, block, tx):
return self.store.final(tx.hash, block, tx, error=True)
def enqueue(self, tx_hash): def enqueue(self, tx_hash):
return self.store.enqueue(tx_hash) return self.store.enqueue(tx_hash)

View File

@ -20,7 +20,7 @@ class QueueDriver:
for i in range(c): for i in range(c):
self.adapter.enqueue(txs[i]) self.adapter.enqueue(txs[i])
if self.throttler != None: if self.throttler != None:
self.throttler.inc() self.throttler.inc(txs[i].hash)
return c return c

19
chaind/filter.py Normal file
View File

@ -0,0 +1,19 @@
# external imports
from chainlib.status import Status as TxStatus
class StateFilter:
def __init__(self, adapter, throttler=None):
self.adapter = adapter
self.throttler = throttler
def filter(self, conn, block, tx, session=None):
cache_tx = self.adapter.get(tx.hash)
if tx.status == TxStatus.SUCCESS:
self.adapter.succeed(block, tx)
else:
self.adapter.fail(block, tx)
if self.throttler != None:
self.throttler.dec(tx.hash)

View File

@ -10,10 +10,12 @@ import hashlib
from chainlib.chain import ChainSpec from chainlib.chain import ChainSpec
from chainqueue.cache import CacheTokenTx from chainqueue.cache import CacheTokenTx
from chainlib.error import RPCException from chainlib.error import RPCException
from chainlib.status import Status as TxStatus
# local imports # local imports
from chaind.adapters.new import ChaindFsAdapter from chaind.adapters.new import ChaindFsAdapter
from chaind.driver import QueueDriver from chaind.driver import QueueDriver
from chaind.filter import StateFilter
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
logg = logging.getLogger() logg = logging.getLogger()
@ -26,7 +28,7 @@ class MockCacheAdapter(CacheTokenTx):
h = hashlib.sha256() h = hashlib.sha256()
h.update(v.encode('utf-8')) h.update(v.encode('utf-8'))
z = h.digest() z = h.digest()
tx.tx_hash = z.hex() tx.hash = z.hex()
return tx return tx
@ -46,6 +48,14 @@ class MockDispatcher:
pass pass
class MockTx:
def __init__(self, tx_hash, status=TxStatus.SUCCESS):
self.hash = tx_hash
self.status = status
class TestChaindFs(unittest.TestCase): class TestChaindFs(unittest.TestCase):
def setUp(self): def setUp(self):
@ -89,5 +99,27 @@ class TestChaindFs(unittest.TestCase):
self.assertEqual(len(txs), 1) self.assertEqual(len(txs), 1)
def test_fs_filter(self):
drv = QueueDriver(self.adapter)
data = os.urandom(128).hex()
hsh = self.adapter.put(data)
fltr = StateFilter(self.adapter)
tx = MockTx(hsh)
fltr.filter(None, None, tx)
def test_fs_filter_fail(self):
drv = QueueDriver(self.adapter)
data = os.urandom(128).hex()
hsh = self.adapter.put(data)
fltr = StateFilter(self.adapter)
tx = MockTx(hsh, TxStatus.ERROR)
fltr.filter(None, None, tx)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()