Add filter to live instantiator, add filter start state

This commit is contained in:
nolash 2021-02-22 10:57:07 +01:00
parent 82e2674555
commit 81dd5753e8
Signed by: lash
GPG Key ID: 21D2E7BB88C2A746
9 changed files with 257 additions and 59 deletions

View File

@ -7,6 +7,7 @@ from chainlib.chain import ChainSpec
# local imports # local imports
from chainsyncer.db.models.sync import BlockchainSync from chainsyncer.db.models.sync import BlockchainSync
from chainsyncer.db.models.filter import BlockchainSyncFilter
from chainsyncer.db.models.base import SessionBase from chainsyncer.db.models.base import SessionBase
logg = logging.getLogger() logg = logging.getLogger()
@ -23,6 +24,7 @@ class SyncerBackend:
def __init__(self, chain_spec, object_id): def __init__(self, chain_spec, object_id):
self.db_session = None self.db_session = None
self.db_object = None self.db_object = None
self.db_object_filter = None
self.chain_spec = chain_spec self.chain_spec = chain_spec
self.object_id = object_id self.object_id = object_id
self.connect() self.connect()
@ -34,9 +36,17 @@ class SyncerBackend:
""" """
if self.db_session == None: if self.db_session == None:
self.db_session = SessionBase.create_session() self.db_session = SessionBase.create_session()
q = self.db_session.query(BlockchainSync) q = self.db_session.query(BlockchainSync)
q = q.filter(BlockchainSync.id==self.object_id) q = q.filter(BlockchainSync.id==self.object_id)
self.db_object = q.first() self.db_object = q.first()
if self.db_object != None:
qtwo = self.db_session.query(BlockchainSyncFilter)
qtwo = qtwo.join(BlockchainSync)
qtwo = qtwo.filter(BlockchainSync.id==self.db_object.id)
self.db_object_filter = qtwo.first()
if self.db_object == None: if self.db_object == None:
raise ValueError('sync entry with id {} not found'.format(self.object_id)) raise ValueError('sync entry with id {} not found'.format(self.object_id))
@ -44,6 +54,8 @@ class SyncerBackend:
def disconnect(self): def disconnect(self):
"""Commits state of sync to backend. """Commits state of sync to backend.
""" """
if self.db_object_filter != None:
self.db_session.add(self.db_object_filter)
self.db_session.add(self.db_object) self.db_session.add(self.db_object)
self.db_session.commit() self.db_session.commit()
self.db_session.close() self.db_session.close()
@ -67,8 +79,9 @@ class SyncerBackend:
""" """
self.connect() self.connect()
pair = self.db_object.cursor() pair = self.db_object.cursor()
filter_state = self.db_object_filter.filter()
self.disconnect() self.disconnect()
return pair return (pair, filter_state,)
def set(self, block_height, tx_height): def set(self, block_height, tx_height):
@ -82,8 +95,9 @@ class SyncerBackend:
""" """
self.connect() self.connect()
pair = self.db_object.set(block_height, tx_height) pair = self.db_object.set(block_height, tx_height)
filter_state = self.db_object_filter.filter()
self.disconnect() self.disconnect()
return pair return (pair, filter_state,)
def start(self): def start(self):
@ -94,8 +108,9 @@ class SyncerBackend:
""" """
self.connect() self.connect()
pair = self.db_object.start() pair = self.db_object.start()
filter_state = self.db_object_filter.start()
self.disconnect() self.disconnect()
return pair return (pair, filter_state,)
def target(self): def target(self):
@ -106,12 +121,13 @@ class SyncerBackend:
""" """
self.connect() self.connect()
target = self.db_object.target() target = self.db_object.target()
filter_state = self.db_object_filter.target()
self.disconnect() self.disconnect()
return target return (target, filter_target,)
@staticmethod @staticmethod
def first(chain): def first(chain_spec):
"""Returns the model object of the most recent syncer in backend. """Returns the model object of the most recent syncer in backend.
:param chain: Chain spec of chain that syncer is running for. :param chain: Chain spec of chain that syncer is running for.
@ -119,7 +135,12 @@ class SyncerBackend:
:returns: Last syncer object :returns: Last syncer object
:rtype: cic_eth.db.models.BlockchainSync :rtype: cic_eth.db.models.BlockchainSync
""" """
return BlockchainSync.first(chain) #return BlockchainSync.first(str(chain_spec))
object_id = BlockchainSync.first(str(chain_spec))
if object_id == None:
return None
return SyncerBackend(chain_spec, object_id)
@staticmethod @staticmethod
@ -193,15 +214,30 @@ class SyncerBackend:
""" """
object_id = None object_id = None
session = SessionBase.create_session() session = SessionBase.create_session()
o = BlockchainSync(str(chain_spec), block_height, 0, None) o = BlockchainSync(str(chain_spec), block_height, 0, None)
session.add(o) session.add(o)
session.commit() session.flush()
object_id = o.id object_id = o.id
of = BlockchainSyncFilter(o)
session.add(of)
session.commit()
session.close() session.close()
return SyncerBackend(chain_spec, object_id) return SyncerBackend(chain_spec, object_id)
def register_filter(self, name):
self.connect()
if self.db_object_filter == None:
self.db_object_filter = BlockchainSyncFilter(self.db_object)
self.db_object_filter.add(name)
self.db_session.add(self.db_object_filter)
self.disconnect()
class MemBackend: class MemBackend:
def __init__(self, chain_spec, object_id): def __init__(self, chain_spec, object_id):
@ -209,6 +245,7 @@ class MemBackend:
self.chain_spec = chain_spec self.chain_spec = chain_spec
self.block_height = 0 self.block_height = 0
self.tx_height = 0 self.tx_height = 0
self.flags = 0
self.db_session = None self.db_session = None

View File

@ -1,8 +1,18 @@
# stanard imports
import logging
# third-party imports # third-party imports
from sqlalchemy import Column, Integer from sqlalchemy import Column, Integer
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import (
StaticPool,
QueuePool,
AssertionPool,
)
logg = logging.getLogger()
Model = declarative_base(name='Model') Model = declarative_base(name='Model')
@ -21,7 +31,11 @@ class SessionBase(Model):
transactional = True transactional = True
"""Whether the database backend supports query transactions. Should be explicitly set by initialization code""" """Whether the database backend supports query transactions. Should be explicitly set by initialization code"""
poolable = True poolable = True
"""Whether the database backend supports query transactions. Should be explicitly set by initialization code""" """Whether the database backend supports connection pools. Should be explicitly set by initialization code"""
procedural = True
"""Whether the database backend supports stored procedures"""
localsessions = {}
"""Contains dictionary of sessions initiated by db model components"""
@staticmethod @staticmethod
@ -40,7 +54,7 @@ class SessionBase(Model):
@staticmethod @staticmethod
def connect(dsn, debug=False): def connect(dsn, pool_size=8, debug=False):
"""Create new database connection engine and connect to database backend. """Create new database connection engine and connect to database backend.
:param dsn: DSN string defining connection. :param dsn: DSN string defining connection.
@ -48,14 +62,28 @@ class SessionBase(Model):
""" """
e = None e = None
if SessionBase.poolable: if SessionBase.poolable:
e = create_engine( poolclass = QueuePool
dsn, if pool_size > 1:
max_overflow=50, e = create_engine(
pool_pre_ping=True, dsn,
pool_size=20, max_overflow=pool_size*3,
pool_recycle=10, pool_pre_ping=True,
echo=debug, pool_size=pool_size,
) pool_recycle=60,
poolclass=poolclass,
echo=debug,
)
else:
if debug:
poolclass = AssertionPool
else:
poolclass = StaticPool
e = create_engine(
dsn,
poolclass=poolclass,
echo=debug,
)
else: else:
e = create_engine( e = create_engine(
dsn, dsn,
@ -71,3 +99,24 @@ class SessionBase(Model):
""" """
SessionBase.engine.dispose() SessionBase.engine.dispose()
SessionBase.engine = None SessionBase.engine = None
@staticmethod
def bind_session(session=None):
localsession = session
if localsession == None:
localsession = SessionBase.create_session()
localsession_key = str(id(localsession))
logg.debug('creating new session {}'.format(localsession_key))
SessionBase.localsessions[localsession_key] = localsession
return localsession
@staticmethod
def release_session(session=None):
session.flush()
session_key = str(id(session))
if SessionBase.localsessions.get(session_key) != None:
logg.debug('destroying session {}'.format(session_key))
session.commit()
session.close()

View File

@ -1,40 +1,79 @@
# standard imports # standard imports
import logging
import hashlib import hashlib
# third-party imports # external imports
from sqlalchemy import Column, String, Integer, BLOB from sqlalchemy import Column, String, Integer, BLOB, ForeignKey
from sqlalchemy.ext.hybrid import hybrid_property, hybrid_method from sqlalchemy.ext.hybrid import hybrid_property, hybrid_method
# local imports # local imports
from .base import SessionBase from .base import SessionBase
from .sync import BlockchainSync
zero_digest = bytearray(32)
zero_digest = '{:<064s'.format('0') logg = logging.getLogger(__name__)
class BlockchainSyncFilter(SessionBase): class BlockchainSyncFilter(SessionBase):
__tablename__ = 'chain_sync_filter' __tablename__ = 'chain_sync_filter'
chain_sync_id = Column(Integer, ForeignKey='chain_sync.id') chain_sync_id = Column(Integer, ForeignKey('chain_sync.id'))
flags_start = Column(BLOB)
flags = Column(BLOB) flags = Column(BLOB)
digest = Column(String) digest = Column(BLOB)
count = Column(Integer) count = Column(Integer)
@staticmethod
def set(self, names):
def __init__(self, chain_sync, count=0, flags=None, digest=zero_digest):
def __init__(self, names, chain_sync, digest=None):
if len(names) == 0:
digest = zero_digest
elif digest == None:
h = hashlib.new('sha256')
for n in names:
h.update(n.encode('utf-8') + b'\x00')
z = h.digest()
digest = z.hex()
self.digest = digest self.digest = digest
self.count = len(names) self.count = count
self.flags = bytearray((len(names) -1 ) / 8 + 1)
if flags == None:
flags = bytearray(0)
self.flags_start = flags
self.flags = flags
self.chain_sync_id = chain_sync.id self.chain_sync_id = chain_sync.id
def add(self, name):
h = hashlib.new('sha256')
h.update(self.digest)
h.update(name.encode('utf-8'))
z = h.digest()
old_byte_count = int((self.count - 1) / 8 + 1)
new_byte_count = int((self.count) / 8 + 1)
logg.debug('old new {} {}'.format(old_byte_count, new_byte_count))
if old_byte_count != new_byte_count:
self.flags = bytearray(1) + self.flags
self.count += 1
self.digest = z
def start(self):
return self.flags_start
def cursor(self):
return self.flags_current
def clear(self):
self.flags = 0
def target(self):
n = 0
for i in range(self.count):
n |= 2 << i
return n
def set(self, n):
if self.flags & n > 0:
SessionBase.release_session(session)
raise AttributeError('Filter bit already set')
r.flags |= n

View File

@ -41,19 +41,23 @@ class BlockchainSync(SessionBase):
:type chain: str :type chain: str
:param session: Session to use. If not specified, a separate session will be created for this method only. :param session: Session to use. If not specified, a separate session will be created for this method only.
:type session: SqlAlchemy Session :type session: SqlAlchemy Session
:returns: True if sync record found :returns: Database primary key id of sync record
:rtype: bool :rtype: number|None
""" """
local_session = False session = SessionBase.bind_session(session)
if session == None:
session = SessionBase.create_session()
local_session = True
q = session.query(BlockchainSync.id) q = session.query(BlockchainSync.id)
q = q.filter(BlockchainSync.blockchain==chain) q = q.filter(BlockchainSync.blockchain==chain)
o = q.first() o = q.first()
if local_session:
session.close() if o == None:
return o == None return None
sync_id = o.id
SessionBase.release_session(session)
return sync_id
@staticmethod @staticmethod
@ -165,4 +169,4 @@ class BlockchainSync(SessionBase):
self.tx_cursor = tx_start self.tx_cursor = tx_start
self.block_target = block_target self.block_target = block_target
self.date_created = datetime.datetime.utcnow() self.date_created = datetime.datetime.utcnow()
self.date_modified = datetime.datetime.utcnow() self.date_updated = datetime.datetime.utcnow()

View File

@ -9,6 +9,7 @@ from .error import BackendError
logg = logging.getLogger(__name__) logg = logging.getLogger(__name__)
class SyncFilter: class SyncFilter:
def __init__(self, backend, safe=True): def __init__(self, backend, safe=True):
@ -32,11 +33,15 @@ class SyncFilter:
except sqlalchemy.exc.TimeoutError as e: except sqlalchemy.exc.TimeoutError as e:
self.backend.disconnect() self.backend.disconnect()
raise BackendError('database connection fail: {}'.format(e)) raise BackendError('database connection fail: {}'.format(e))
i = 0
for f in self.filters: for f in self.filters:
i += 1
logg.debug('applying filter {}'.format(str(f))) logg.debug('applying filter {}'.format(str(f)))
f.filter(conn, block, tx, self.backend.db_session) f.filter(conn, block, tx, self.backend.db_session)
self.backend.set_filter()
self.backend.disconnect() self.backend.disconnect()
class NoopFilter: class NoopFilter:
def filter(self, conn, block, tx, db_session=None): def filter(self, conn, block, tx, db_session=None):

View File

@ -1,13 +1,11 @@
CREATE TABLE IF NOT EXISTS chain_sync ( CREATE TABLE IF NOT EXISTS chain_sync (
id serial primary key not null, id integer primary key autoincrement,
blockchain varchar not null, blockchain varchar not null,
block_start int not null default 0, block_start integer not null default 0,
tx_start int not null default 0, tx_start integer not null default 0,
block_cursor int not null default 0, block_cursor integer not null default 0,
tx_cursor int not null default 0, tx_cursor integer not null default 0,
flags bytea not null, block_target integer default null,
num_flags int not null,
block_target int default null,
date_created timestamp not null, date_created timestamp not null,
date_updated timestamp default null date_updated timestamp default null
); );

View File

@ -1,8 +1,9 @@
CREATE TABLE IF NOT EXISTS chain_sync_filter ( CREATE TABLE IF NOT EXISTS chain_sync_filter (
id serial primary key not null, id integer primary key autoincrement not null,
chain_sync_id int not null, chain_sync_id integer not null,
flags bytea default null, flags bytea default null,
count int not null default 0, flags_start bytea default null,
count integer not null default 0,
digest char(64) not null default '0000000000000000000000000000000000000000000000000000000000000000', digest char(64) not null default '0000000000000000000000000000000000000000000000000000000000000000',
CONSTRAINT fk_chain_sync CONSTRAINT fk_chain_sync
FOREIGN KEY(chain_sync_id) FOREIGN KEY(chain_sync_id)

View File

@ -1,13 +1,21 @@
# standard imports
import logging
import unittest import unittest
import tempfile import tempfile
import os import os
#import pysqlite #import pysqlite
# external imports
from chainlib.chain import ChainSpec
# local imports
from chainsyncer.db import dsn_from_config from chainsyncer.db import dsn_from_config
from chainsyncer.db.models.base import SessionBase from chainsyncer.db.models.base import SessionBase
script_dir = os.path.realpath(os.path.dirname(__file__)) script_dir = os.path.realpath(os.path.dirname(__file__))
logging.basicConfig(level=logging.DEBUG)
class TestBase(unittest.TestCase): class TestBase(unittest.TestCase):
@ -23,7 +31,7 @@ class TestBase(unittest.TestCase):
SessionBase.poolable = False SessionBase.poolable = False
SessionBase.transactional = False SessionBase.transactional = False
SessionBase.procedural = False SessionBase.procedural = False
SessionBase.connect(dsn, debug=True) SessionBase.connect(dsn, debug=False)
f = open(os.path.join(script_dir, '..', 'sql', 'sqlite', '1.sql'), 'r') f = open(os.path.join(script_dir, '..', 'sql', 'sqlite', '1.sql'), 'r')
sql = f.read() sql = f.read()
@ -39,6 +47,8 @@ class TestBase(unittest.TestCase):
conn = SessionBase.engine.connect() conn = SessionBase.engine.connect()
conn.execute(sql) conn.execute(sql)
self.chain_spec = ChainSpec('evm', 'foo', 42, 'bar')
def tearDown(self): def tearDown(self):
SessionBase.disconnect() SessionBase.disconnect()
os.unlink(self.db_path) os.unlink(self.db_path)

55
tests/test_database.py Normal file
View File

@ -0,0 +1,55 @@
# standard imports
import unittest
# external imports
from chainlib.chain import ChainSpec
# local imports
from chainsyncer.db.models.base import SessionBase
from chainsyncer.db.models.filter import BlockchainSyncFilter
from chainsyncer.backend import SyncerBackend
# testutil imports
from tests.base import TestBase
class TestDatabase(TestBase):
def test_backend_live(self):
s = SyncerBackend.live(self.chain_spec, 42)
self.assertEqual(s.object_id, 1)
backend = SyncerBackend.first(self.chain_spec)
#SyncerBackend(self.chain_spec, sync_id)
self.assertEqual(backend.object_id, 1)
bogus_chain_spec = ChainSpec('bogus', 'foo', 13, 'baz')
sync_id = SyncerBackend.first(bogus_chain_spec)
self.assertIsNone(sync_id)
def test_backend_filter(self):
s = SyncerBackend.live(self.chain_spec, 42)
s.connect()
filter_id = s.db_object_filter.id
s.disconnect()
session = SessionBase.create_session()
o = session.query(BlockchainSyncFilter).get(filter_id)
self.assertEqual(len(o.flags), 0)
session.close()
for i in range(9):
s.register_filter(str(i))
s.connect()
filter_id = s.db_object_filter.id
s.disconnect()
session = SessionBase.create_session()
o = session.query(BlockchainSyncFilter).get(filter_id)
self.assertEqual(len(o.flags), 2)
session.close()
if __name__ == '__main__':
unittest.main()