Add lock flags to model, backend

This commit is contained in:
nolash 2021-09-26 19:32:08 +02:00
parent 9f4362ad07
commit db6128f823
Signed by: lash
GPG Key ID: 21D2E7BB88C2A746
11 changed files with 104 additions and 19 deletions

View File

@ -86,7 +86,7 @@ class MemBackend(Backend):
self.filter_count += 1 self.filter_count += 1
def complete_filter(self, n): def begin_filter(self, n):
"""Set filter at index as completed for the current block / tx state. """Set filter at index as completed for the current block / tx state.
:param n: Filter index :param n: Filter index
@ -97,6 +97,10 @@ class MemBackend(Backend):
logg.debug('set filter {} {}'.format(self.filter_names[n], v)) logg.debug('set filter {} {}'.format(self.filter_names[n], v))
def complete_filter(self, n):
pass
def reset_filter(self): def reset_filter(self):
"""Set all filters to unprocessed for the current block / tx state. """Set all filters to unprocessed for the current block / tx state.
""" """
@ -104,11 +108,5 @@ class MemBackend(Backend):
self.flags = 0 self.flags = 0
# def get_flags(self):
# """Returns flags
# """
# return self.flags
def __str__(self): def __str__(self):
return "syncer membackend {} chain {} cursor {}".format(self.object_id, self.chain(), self.get()) return "syncer membackend {} chain {} cursor {}".format(self.object_id, self.chain(), self.get())

View File

@ -314,8 +314,8 @@ class SQLBackend(Backend):
self.disconnect() self.disconnect()
def complete_filter(self, n): def begin_filter(self, n):
"""Sets the filter at the given index as completed. """Marks start of execution of the filter indexed by the corresponding bit.
:param n: Filter index :param n: Filter index
:type n: int :type n: int
@ -327,6 +327,14 @@ class SQLBackend(Backend):
self.disconnect() self.disconnect()
def complete_filter(self, n):
self.connect()
self.db_object_filter.release(check_bit=n)
self.db_session.add(self.db_object_filter)
self.db_session.commit()
self.disconnect()
def reset_filter(self): def reset_filter(self):
"""Reset all filter states. """Reset all filter states.
""" """

View File

@ -21,6 +21,7 @@ def upgrade():
sa.Column('id', sa.Integer, primary_key=True), sa.Column('id', sa.Integer, primary_key=True),
sa.Column('chain_sync_id', sa.Integer, sa.ForeignKey('chain_sync.id'), nullable=True), sa.Column('chain_sync_id', sa.Integer, sa.ForeignKey('chain_sync.id'), nullable=True),
sa.Column('flags', sa.LargeBinary, nullable=True), sa.Column('flags', sa.LargeBinary, nullable=True),
sa.Column('flags_lock', sa.Integer, nullable=False, default=0),
sa.Column('flags_start', sa.LargeBinary, nullable=True), sa.Column('flags_start', sa.LargeBinary, nullable=True),
sa.Column('count', sa.Integer, nullable=False, default=0), sa.Column('count', sa.Integer, nullable=False, default=0),
sa.Column('digest', sa.String(64), nullable=False), sa.Column('digest', sa.String(64), nullable=False),

View File

@ -9,6 +9,7 @@ from sqlalchemy.ext.hybrid import hybrid_property, hybrid_method
# local imports # local imports
from .base import SessionBase from .base import SessionBase
from .sync import BlockchainSync from .sync import BlockchainSync
from chainsyncer.error import LockError
zero_digest = bytes(32).hex() zero_digest = bytes(32).hex()
logg = logging.getLogger(__name__) logg = logging.getLogger(__name__)
@ -32,6 +33,7 @@ class BlockchainSyncFilter(SessionBase):
chain_sync_id = Column(Integer, ForeignKey('chain_sync.id')) chain_sync_id = Column(Integer, ForeignKey('chain_sync.id'))
flags_start = Column(LargeBinary) flags_start = Column(LargeBinary)
flags = Column(LargeBinary) flags = Column(LargeBinary)
flags_lock = Column(Integer)
digest = Column(String(64)) digest = Column(String(64))
count = Column(Integer) count = Column(Integer)
@ -47,10 +49,20 @@ class BlockchainSyncFilter(SessionBase):
flags = flags.to_bytes(bytecount, 'big') flags = flags.to_bytes(bytecount, 'big')
self.flags_start = flags self.flags_start = flags
self.flags = flags self.flags = flags
self.flags_lock = 0
self.chain_sync_id = chain_sync.id self.chain_sync_id = chain_sync.id
@staticmethod
def load(sync_id, session=None):
q = session.query(BlockchainSyncFilter)
q = q.filter(BlockchainSyncFilter.chain_sync_id==sync_id)
o = q.first()
if o.is_locked():
raise LockError('locked state for flag {} of sync id {} must be manually resolved'.format(o.flags_lock))
def add(self, name): def add(self, name):
"""Add a new filter to the syncer record. """Add a new filter to the syncer record.
@ -106,9 +118,16 @@ class BlockchainSyncFilter(SessionBase):
return (n, self.count, self.digest) return (n, self.count, self.digest)
def is_locked(self):
return self.flags_lock > 0
def clear(self): def clear(self):
"""Set current filter flag value to zero. """Set current filter flag value to zero.
""" """
if self.is_locked():
raise LockError('flag clear attempted when lock set at {}'.format(self.flags_lock))
self.flags = bytearray(len(self.flags)) self.flags = bytearray(len(self.flags))
@ -120,9 +139,14 @@ class BlockchainSyncFilter(SessionBase):
:raises IndexError: Invalid flag index :raises IndexError: Invalid flag index
:raises AttributeError: Flag at index already set :raises AttributeError: Flag at index already set
""" """
if self.is_locked():
raise LockError('flag set attempted when lock set at {}'.format(self.flags_lock))
if n > self.count: if n > self.count:
raise IndexError('bit flag out of range') raise IndexError('bit flag out of range')
self.flags_lock = n
b = 1 << (n % 8) b = 1 << (n % 8)
i = int(n / 8) i = int(n / 8)
byte_idx = len(self.flags)-1-i byte_idx = len(self.flags)-1-i
@ -131,3 +155,10 @@ class BlockchainSyncFilter(SessionBase):
flags = bytearray(self.flags) flags = bytearray(self.flags)
flags[byte_idx] |= b flags[byte_idx] |= b
self.flags = flags self.flags = flags
def release(self, check_bit=0):
if check_bit > 0:
if self.flags_lock > 0 and self.flags_lock != check_bit:
raise LockError('release attemped on explicit bit {}, but bit {} was locked'.format(check_bit, self.flags_lock))
self.flags_lock = 0

View File

@ -21,6 +21,12 @@ class BackendError(Exception):
pass pass
class LockError(Exception):
"""Base exception for attempting to manipulate a locked property
"""
pass
#class AbortTx(Exception): #class AbortTx(Exception):
# """ # """
# """ # """

View File

@ -36,6 +36,7 @@ class SyncFilter:
def __apply_one(self, fltr, idx, conn, block, tx, session): def __apply_one(self, fltr, idx, conn, block, tx, session):
self.backend.begin_filter(idx)
fltr.filter(conn, block, tx, session) fltr.filter(conn, block, tx, session)
self.backend.complete_filter(idx) self.backend.complete_filter(idx)

View File

@ -1,4 +1,4 @@
confini>=0.3.6rc3,<0.5.0 confini>=0.3.6rc3,<0.5.0
semver==2.13.0 semver==2.13.0
hexathon~=0.0.1a8 hexathon~=0.0.1a8
chainlib>=0.0.9a2,<=0.1.0 chainlib>=0.0.9a11,<=0.1.0

View File

@ -1,6 +1,6 @@
[metadata] [metadata]
name = chainsyncer name = chainsyncer
version = 0.0.6a3 version = 0.0.7a1
description = Generic blockchain syncer driver description = Generic blockchain syncer driver
author = Louis Holbrook author = Louis Holbrook
author_email = dev@holbrook.no author_email = dev@holbrook.no

View File

@ -1,4 +1,4 @@
chainlib-eth~=0.0.9a4 chainlib-eth~=0.0.9a14
psycopg2==2.8.6 psycopg2==2.8.6
SQLAlchemy==1.3.20 SQLAlchemy==1.3.20
alembic==1.4.2 alembic==1.4.2

View File

@ -9,6 +9,7 @@ from chainlib.chain import ChainSpec
from chainsyncer.db.models.base import SessionBase from chainsyncer.db.models.base import SessionBase
from chainsyncer.db.models.filter import BlockchainSyncFilter from chainsyncer.db.models.filter import BlockchainSyncFilter
from chainsyncer.backend.sql import SQLBackend from chainsyncer.backend.sql import SQLBackend
from chainsyncer.error import LockError
# testutil imports # testutil imports
from tests.chainsyncer_base import TestBase from tests.chainsyncer_base import TestBase
@ -31,6 +32,35 @@ class TestDatabase(TestBase):
self.assertIsNone(sync_id) self.assertIsNone(sync_id)
def test_backend_filter_lock(self):
s = SQLBackend.live(self.chain_spec, 42)
s.connect()
filter_id = s.db_object_filter.id
s.disconnect()
session = SessionBase.create_session()
o = session.query(BlockchainSyncFilter).get(filter_id)
self.assertEqual(len(o.flags), 0)
session.close()
s.register_filter(str(0))
s.register_filter(str(1))
s.connect()
filter_id = s.db_object_filter.id
s.disconnect()
session = SessionBase.create_session()
o = session.query(BlockchainSyncFilter).get(filter_id)
o.set(1)
with self.assertRaises(LockError):
o.set(2)
o.release()
o.set(2)
def test_backend_filter(self): def test_backend_filter(self):
s = SQLBackend.live(self.chain_spec, 42) s = SQLBackend.live(self.chain_spec, 42)
@ -59,6 +89,7 @@ class TestDatabase(TestBase):
for i in range(9): for i in range(9):
o.set(i) o.set(i)
o.release()
(f, c, d) = o.cursor() (f, c, d) = o.cursor()
self.assertEqual(f, t) self.assertEqual(f, t)
@ -144,8 +175,8 @@ class TestDatabase(TestBase):
s.register_filter('baz') s.register_filter('baz')
s.set(43, 13) s.set(43, 13)
s.complete_filter(0) s.begin_filter(0)
s.complete_filter(2) s.begin_filter(2)
s = SQLBackend.resume(self.chain_spec, 666) s = SQLBackend.resume(self.chain_spec, 666)
(pair, flags) = s[0].get() (pair, flags) = s[0].get()

View File

@ -14,6 +14,7 @@ from chainsyncer.backend.file import (
FileBackend, FileBackend,
data_dir_for, data_dir_for,
) )
from chainsyncer.error import LockError
# test imports # test imports
from tests.chainsyncer_base import TestBase from tests.chainsyncer_base import TestBase
@ -36,10 +37,10 @@ class NaughtyCountExceptionFilter:
def filter(self, conn, block, tx, db_session=None): def filter(self, conn, block, tx, db_session=None):
self.c += 1
if self.c == self.croak: if self.c == self.croak:
self.croak = -1 self.croak = -1
raise RuntimeError('foo') raise RuntimeError('foo')
self.c += 1
def __str__(self): def __str__(self):
@ -75,6 +76,7 @@ class TestInterrupt(TestBase):
[6, 5, 2], [6, 5, 2],
[6, 4, 3], [6, 4, 3],
] ]
self.track_complete = True
def assert_filter_interrupt(self, vector, chain_interface): def assert_filter_interrupt(self, vector, chain_interface):
@ -100,10 +102,16 @@ class TestInterrupt(TestBase):
try: try:
syncer.loop(0.1, self.conn) syncer.loop(0.1, self.conn)
except RuntimeError: except RuntimeError:
self.croaked = 2
logg.info('caught croak') logg.info('caught croak')
pass pass
(pair, fltr) = self.backend.get() (pair, fltr) = self.backend.get()
self.assertGreater(fltr, 0) self.assertGreater(fltr, 0)
try:
syncer.loop(0.1, self.conn)
except LockError:
self.backend.complete_filter(2)
syncer.loop(0.1, self.conn) syncer.loop(0.1, self.conn)
for fltr in filters: for fltr in filters:
@ -112,11 +120,13 @@ class TestInterrupt(TestBase):
def test_filter_interrupt_memory(self): def test_filter_interrupt_memory(self):
self.track_complete = True
for vector in self.vectors: for vector in self.vectors:
self.backend = MemBackend(self.chain_spec, None, target_block=len(vector)) self.backend = MemBackend(self.chain_spec, None, target_block=len(vector))
self.assert_filter_interrupt(vector, self.interface) self.assert_filter_interrupt(vector, self.interface)
#TODO: implement flag lock in file backend
@unittest.expectedFailure
def test_filter_interrupt_file(self): def test_filter_interrupt_file(self):
#for vector in self.vectors: #for vector in self.vectors:
vector = self.vectors.pop() vector = self.vectors.pop()
@ -127,12 +137,11 @@ class TestInterrupt(TestBase):
def test_filter_interrupt_sql(self): def test_filter_interrupt_sql(self):
self.track_complete = True
for vector in self.vectors: for vector in self.vectors:
self.backend = SQLBackend.initial(self.chain_spec, len(vector)) self.backend = SQLBackend.initial(self.chain_spec, len(vector))
self.assert_filter_interrupt(vector, self.interface) self.assert_filter_interrupt(vector, self.interface)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()