Add lock flags to model, backend

This commit is contained in:
nolash 2021-09-26 19:32:08 +02:00
parent 9f4362ad07
commit db6128f823
Signed by: lash
GPG Key ID: 21D2E7BB88C2A746
11 changed files with 104 additions and 19 deletions

View File

@ -86,7 +86,7 @@ class MemBackend(Backend):
self.filter_count += 1
def complete_filter(self, n):
def begin_filter(self, n):
"""Set filter at index as completed for the current block / tx state.
:param n: Filter index
@ -97,6 +97,10 @@ class MemBackend(Backend):
logg.debug('set filter {} {}'.format(self.filter_names[n], v))
def complete_filter(self, n):
pass
def reset_filter(self):
"""Set all filters to unprocessed for the current block / tx state.
"""
@ -104,11 +108,5 @@ class MemBackend(Backend):
self.flags = 0
# def get_flags(self):
# """Returns flags
# """
# return self.flags
def __str__(self):
return "syncer membackend {} chain {} cursor {}".format(self.object_id, self.chain(), self.get())

View File

@ -314,8 +314,8 @@ class SQLBackend(Backend):
self.disconnect()
def complete_filter(self, n):
"""Sets the filter at the given index as completed.
def begin_filter(self, n):
"""Marks start of execution of the filter indexed by the corresponding bit.
:param n: Filter index
:type n: int
@ -327,6 +327,14 @@ class SQLBackend(Backend):
self.disconnect()
def complete_filter(self, n):
self.connect()
self.db_object_filter.release(check_bit=n)
self.db_session.add(self.db_object_filter)
self.db_session.commit()
self.disconnect()
def reset_filter(self):
"""Reset all filter states.
"""

View File

@ -21,6 +21,7 @@ def upgrade():
sa.Column('id', sa.Integer, primary_key=True),
sa.Column('chain_sync_id', sa.Integer, sa.ForeignKey('chain_sync.id'), nullable=True),
sa.Column('flags', sa.LargeBinary, nullable=True),
sa.Column('flags_lock', sa.Integer, nullable=False, default=0),
sa.Column('flags_start', sa.LargeBinary, nullable=True),
sa.Column('count', sa.Integer, nullable=False, default=0),
sa.Column('digest', sa.String(64), nullable=False),

View File

@ -9,6 +9,7 @@ from sqlalchemy.ext.hybrid import hybrid_property, hybrid_method
# local imports
from .base import SessionBase
from .sync import BlockchainSync
from chainsyncer.error import LockError
zero_digest = bytes(32).hex()
logg = logging.getLogger(__name__)
@ -32,6 +33,7 @@ class BlockchainSyncFilter(SessionBase):
chain_sync_id = Column(Integer, ForeignKey('chain_sync.id'))
flags_start = Column(LargeBinary)
flags = Column(LargeBinary)
flags_lock = Column(Integer)
digest = Column(String(64))
count = Column(Integer)
@ -47,10 +49,20 @@ class BlockchainSyncFilter(SessionBase):
flags = flags.to_bytes(bytecount, 'big')
self.flags_start = flags
self.flags = flags
self.flags_lock = 0
self.chain_sync_id = chain_sync.id
@staticmethod
def load(sync_id, session=None):
q = session.query(BlockchainSyncFilter)
q = q.filter(BlockchainSyncFilter.chain_sync_id==sync_id)
o = q.first()
if o.is_locked():
raise LockError('locked state for flag {} of sync id {} must be manually resolved'.format(o.flags_lock))
def add(self, name):
"""Add a new filter to the syncer record.
@ -106,9 +118,16 @@ class BlockchainSyncFilter(SessionBase):
return (n, self.count, self.digest)
def is_locked(self):
return self.flags_lock > 0
def clear(self):
"""Set current filter flag value to zero.
"""
if self.is_locked():
raise LockError('flag clear attempted when lock set at {}'.format(self.flags_lock))
self.flags = bytearray(len(self.flags))
@ -120,9 +139,14 @@ class BlockchainSyncFilter(SessionBase):
:raises IndexError: Invalid flag index
:raises AttributeError: Flag at index already set
"""
if self.is_locked():
raise LockError('flag set attempted when lock set at {}'.format(self.flags_lock))
if n > self.count:
raise IndexError('bit flag out of range')
self.flags_lock = n
b = 1 << (n % 8)
i = int(n / 8)
byte_idx = len(self.flags)-1-i
@ -131,3 +155,10 @@ class BlockchainSyncFilter(SessionBase):
flags = bytearray(self.flags)
flags[byte_idx] |= b
self.flags = flags
def release(self, check_bit=0):
if check_bit > 0:
if self.flags_lock > 0 and self.flags_lock != check_bit:
raise LockError('release attemped on explicit bit {}, but bit {} was locked'.format(check_bit, self.flags_lock))
self.flags_lock = 0

View File

@ -21,6 +21,12 @@ class BackendError(Exception):
pass
class LockError(Exception):
"""Base exception for attempting to manipulate a locked property
"""
pass
#class AbortTx(Exception):
# """
# """

View File

@ -36,6 +36,7 @@ class SyncFilter:
def __apply_one(self, fltr, idx, conn, block, tx, session):
self.backend.begin_filter(idx)
fltr.filter(conn, block, tx, session)
self.backend.complete_filter(idx)

View File

@ -1,4 +1,4 @@
confini>=0.3.6rc3,<0.5.0
semver==2.13.0
hexathon~=0.0.1a8
chainlib>=0.0.9a2,<=0.1.0
chainlib>=0.0.9a11,<=0.1.0

View File

@ -1,6 +1,6 @@
[metadata]
name = chainsyncer
version = 0.0.6a3
version = 0.0.7a1
description = Generic blockchain syncer driver
author = Louis Holbrook
author_email = dev@holbrook.no

View File

@ -1,4 +1,4 @@
chainlib-eth~=0.0.9a4
chainlib-eth~=0.0.9a14
psycopg2==2.8.6
SQLAlchemy==1.3.20
alembic==1.4.2

View File

@ -9,6 +9,7 @@ from chainlib.chain import ChainSpec
from chainsyncer.db.models.base import SessionBase
from chainsyncer.db.models.filter import BlockchainSyncFilter
from chainsyncer.backend.sql import SQLBackend
from chainsyncer.error import LockError
# testutil imports
from tests.chainsyncer_base import TestBase
@ -31,6 +32,35 @@ class TestDatabase(TestBase):
self.assertIsNone(sync_id)
def test_backend_filter_lock(self):
s = SQLBackend.live(self.chain_spec, 42)
s.connect()
filter_id = s.db_object_filter.id
s.disconnect()
session = SessionBase.create_session()
o = session.query(BlockchainSyncFilter).get(filter_id)
self.assertEqual(len(o.flags), 0)
session.close()
s.register_filter(str(0))
s.register_filter(str(1))
s.connect()
filter_id = s.db_object_filter.id
s.disconnect()
session = SessionBase.create_session()
o = session.query(BlockchainSyncFilter).get(filter_id)
o.set(1)
with self.assertRaises(LockError):
o.set(2)
o.release()
o.set(2)
def test_backend_filter(self):
s = SQLBackend.live(self.chain_spec, 42)
@ -59,6 +89,7 @@ class TestDatabase(TestBase):
for i in range(9):
o.set(i)
o.release()
(f, c, d) = o.cursor()
self.assertEqual(f, t)
@ -144,8 +175,8 @@ class TestDatabase(TestBase):
s.register_filter('baz')
s.set(43, 13)
s.complete_filter(0)
s.complete_filter(2)
s.begin_filter(0)
s.begin_filter(2)
s = SQLBackend.resume(self.chain_spec, 666)
(pair, flags) = s[0].get()

View File

@ -14,6 +14,7 @@ from chainsyncer.backend.file import (
FileBackend,
data_dir_for,
)
from chainsyncer.error import LockError
# test imports
from tests.chainsyncer_base import TestBase
@ -36,10 +37,10 @@ class NaughtyCountExceptionFilter:
def filter(self, conn, block, tx, db_session=None):
self.c += 1
if self.c == self.croak:
self.croak = -1
raise RuntimeError('foo')
self.c += 1
def __str__(self):
@ -75,6 +76,7 @@ class TestInterrupt(TestBase):
[6, 5, 2],
[6, 4, 3],
]
self.track_complete = True
def assert_filter_interrupt(self, vector, chain_interface):
@ -100,11 +102,17 @@ class TestInterrupt(TestBase):
try:
syncer.loop(0.1, self.conn)
except RuntimeError:
self.croaked = 2
logg.info('caught croak')
pass
(pair, fltr) = self.backend.get()
self.assertGreater(fltr, 0)
syncer.loop(0.1, self.conn)
try:
syncer.loop(0.1, self.conn)
except LockError:
self.backend.complete_filter(2)
syncer.loop(0.1, self.conn)
for fltr in filters:
logg.debug('{} {}'.format(str(fltr), fltr.c))
@ -112,11 +120,13 @@ class TestInterrupt(TestBase):
def test_filter_interrupt_memory(self):
self.track_complete = True
for vector in self.vectors:
self.backend = MemBackend(self.chain_spec, None, target_block=len(vector))
self.assert_filter_interrupt(vector, self.interface)
#TODO: implement flag lock in file backend
@unittest.expectedFailure
def test_filter_interrupt_file(self):
#for vector in self.vectors:
vector = self.vectors.pop()
@ -127,12 +137,11 @@ class TestInterrupt(TestBase):
def test_filter_interrupt_sql(self):
self.track_complete = True
for vector in self.vectors:
self.backend = SQLBackend.initial(self.chain_spec, len(vector))
self.assert_filter_interrupt(vector, self.interface)
if __name__ == '__main__':
unittest.main()