WIP Add queue driver logic

This commit is contained in:
lash 2022-03-14 19:53:29 +00:00
parent 9c641892ab
commit 1ae96670c7
Signed by: lash
GPG Key ID: 21D2E7BB88C2A746
3 changed files with 118 additions and 4 deletions

View File

@ -1,4 +1,5 @@
# external imports
from chainlib.error import RPCException
from chainqueue import (
Status,
Store as QueueStore,
@ -13,16 +14,17 @@ from shep.store.file import SimpleFileStoreFactory
class ChaindFsAdapter:
def __init__(self, chain_spec, path, deserializer, cache=None, pending_retry_threshold=0, error_retry_threshold=0, digest_bytes=32):
def __init__(self, chain_spec, path, deserializer, dispatcher, cache=None, pending_retry_threshold=0, error_retry_threshold=0, digest_bytes=32):
factory = SimpleFileStoreFactory(path).add
state_store = Status(factory)
index_store = IndexStore(path, digest_bytes=digest_bytes)
counter_store = CounterStore(path)
self.store = QueueStore(chain_spec, state_store, index_store, counter_store, cache=cache)
self.deserialize = deserializer
self.dispatcher = dispatcher
def add(self, signed_tx):
def put(self, signed_tx):
cache_tx = self.deserialize(signed_tx)
self.store.put(cache_tx.tx_hash, signed_tx)
return cache_tx.tx_hash
@ -31,3 +33,34 @@ class ChaindFsAdapter:
def get(self, tx_hash):
v = self.store.get(tx_hash)
return v[1]
def upcoming(self):
return self.store.upcoming()
def pending(self):
return self.store.pending()
def deferred(self):
return self.store.deferred()
def enqueue(self, tx_hash):
return self.store.enqueue(tx_hash)
def dispatch(self, tx_hash):
entry = self.store.send_start(tx_hash)
tx_wire = entry.serialize()
r = None
try:
r = self.dispatcher.send(tx_wire)
except RPCException:
self.store.fail(tx_hash)
return False
self.store.send_end(tx_hash)
return True

39
chaind/driver.py Normal file
View File

@ -0,0 +1,39 @@
# standard imports
import logging
logg = logging.getLogger(__name__)
class QueueDriver:
def __init__(self, adapter, throttler=None):
self.adapter = adapter
self.throttler = throttler
def __enqueue(self, txs):
c = len(txs)
if self.throttler != None:
r = self.throttler.count()
if r < c:
c = r
for i in range(c):
self.adapter.enqueue(txs[i])
if self.throttler != None:
self.throttler.inc()
return c
def process(self):
total = 0
txs = self.adapter.pending()
r = self.__enqueue(txs)
total += r
logg.debug('pending enqueued {} total {}'.format(r, total))
txs = self.adapter.deferred()
r = self.__enqueue(txs)
total += r
logg.debug('deferred enqueued {} total {}'.format(r, total))
return txs

View File

@ -9,9 +9,11 @@ import hashlib
# external imports
from chainlib.chain import ChainSpec
from chainqueue.cache import CacheTokenTx
from chainlib.error import RPCException
# local imports
from chaind.adapters.new import ChaindFsAdapter
from chaind.driver import QueueDriver
logging.basicConfig(level=logging.DEBUG)
logg = logging.getLogger()
@ -28,12 +30,29 @@ class MockCacheAdapter(CacheTokenTx):
return tx
class MockDispatcher:
def __init__(self):
self.fails = []
def add_fail(self, v):
self.fails.append(v)
def send(self, v):
if v not in self.fails:
raise RPCException('{} is in fails'.format(v))
pass
class TestChaindFs(unittest.TestCase):
def setUp(self):
self.chain_spec = ChainSpec('foo', 'bar', 42, 'baz')
self.path = tempfile.mkdtemp()
self.adapter = ChaindFsAdapter(self.chain_spec, self.path, MockCacheAdapter().deserialize)
self.dispatcher = MockDispatcher()
self.adapter = ChaindFsAdapter(self.chain_spec, self.path, MockCacheAdapter().deserialize, self.dispatcher)
def tearDown(self):
@ -42,10 +61,33 @@ class TestChaindFs(unittest.TestCase):
def test_fs_setup(self):
data = os.urandom(128).hex()
hsh = self.adapter.add(data)
hsh = self.adapter.put(data)
v = self.adapter.get(hsh)
self.assertEqual(data, v)
def test_fs_defer(self):
data = os.urandom(128).hex()
hsh = self.adapter.put(data)
self.dispatcher.add_fail(hsh)
self.adapter.dispatch(hsh)
txs = self.adapter.deferred()
self.assertEqual(len(txs), 1)
def test_fs_process(self):
drv = QueueDriver(self.adapter)
data = os.urandom(128).hex()
hsh = self.adapter.put(data)
txs = self.adapter.upcoming()
self.assertEqual(len(txs), 0)
drv.process()
txs = self.adapter.upcoming()
self.assertEqual(len(txs), 1)
if __name__ == '__main__':
unittest.main()