diff --git a/chainsyncer/backend.py b/chainsyncer/backend.py index 3380bfc..5953ff2 100644 --- a/chainsyncer/backend.py +++ b/chainsyncer/backend.py @@ -7,6 +7,7 @@ from chainlib.chain import ChainSpec # local imports from chainsyncer.db.models.sync import BlockchainSync +from chainsyncer.db.models.filter import BlockchainSyncFilter from chainsyncer.db.models.base import SessionBase logg = logging.getLogger() @@ -23,6 +24,7 @@ class SyncerBackend: def __init__(self, chain_spec, object_id): self.db_session = None self.db_object = None + self.db_object_filter = None self.chain_spec = chain_spec self.object_id = object_id self.connect() @@ -34,9 +36,17 @@ class SyncerBackend: """ if self.db_session == None: self.db_session = SessionBase.create_session() + q = self.db_session.query(BlockchainSync) q = q.filter(BlockchainSync.id==self.object_id) self.db_object = q.first() + + if self.db_object != None: + qtwo = self.db_session.query(BlockchainSyncFilter) + qtwo = qtwo.join(BlockchainSync) + qtwo = qtwo.filter(BlockchainSync.id==self.db_object.id) + self.db_object_filter = qtwo.first() + if self.db_object == None: raise ValueError('sync entry with id {} not found'.format(self.object_id)) @@ -44,6 +54,8 @@ class SyncerBackend: def disconnect(self): """Commits state of sync to backend. """ + if self.db_object_filter != None: + self.db_session.add(self.db_object_filter) self.db_session.add(self.db_object) self.db_session.commit() self.db_session.close() @@ -67,8 +79,9 @@ class SyncerBackend: """ self.connect() pair = self.db_object.cursor() + filter_state = self.db_object_filter.filter() self.disconnect() - return pair + return (pair, filter_state,) def set(self, block_height, tx_height): @@ -82,8 +95,9 @@ class SyncerBackend: """ self.connect() pair = self.db_object.set(block_height, tx_height) + filter_state = self.db_object_filter.filter() self.disconnect() - return pair + return (pair, filter_state,) def start(self): @@ -94,8 +108,9 @@ class SyncerBackend: """ self.connect() pair = self.db_object.start() + filter_state = self.db_object_filter.start() self.disconnect() - return pair + return (pair, filter_state,) def target(self): @@ -106,12 +121,13 @@ class SyncerBackend: """ self.connect() target = self.db_object.target() + filter_state = self.db_object_filter.target() self.disconnect() - return target + return (target, filter_target,) @staticmethod - def first(chain): + def first(chain_spec): """Returns the model object of the most recent syncer in backend. :param chain: Chain spec of chain that syncer is running for. @@ -119,7 +135,12 @@ class SyncerBackend: :returns: Last syncer object :rtype: cic_eth.db.models.BlockchainSync """ - return BlockchainSync.first(chain) + #return BlockchainSync.first(str(chain_spec)) + object_id = BlockchainSync.first(str(chain_spec)) + if object_id == None: + return None + return SyncerBackend(chain_spec, object_id) + @staticmethod @@ -193,15 +214,30 @@ class SyncerBackend: """ object_id = None session = SessionBase.create_session() + o = BlockchainSync(str(chain_spec), block_height, 0, None) session.add(o) - session.commit() + session.flush() object_id = o.id + + of = BlockchainSyncFilter(o) + session.add(of) + session.commit() + session.close() return SyncerBackend(chain_spec, object_id) + def register_filter(self, name): + self.connect() + if self.db_object_filter == None: + self.db_object_filter = BlockchainSyncFilter(self.db_object) + self.db_object_filter.add(name) + self.db_session.add(self.db_object_filter) + self.disconnect() + + class MemBackend: def __init__(self, chain_spec, object_id): @@ -209,6 +245,7 @@ class MemBackend: self.chain_spec = chain_spec self.block_height = 0 self.tx_height = 0 + self.flags = 0 self.db_session = None diff --git a/chainsyncer/db/models/base.py b/chainsyncer/db/models/base.py index 153906a..db570df 100644 --- a/chainsyncer/db/models/base.py +++ b/chainsyncer/db/models/base.py @@ -1,8 +1,18 @@ +# stanard imports +import logging + # third-party imports from sqlalchemy import Column, Integer from sqlalchemy.ext.declarative import declarative_base from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker +from sqlalchemy.pool import ( + StaticPool, + QueuePool, + AssertionPool, + ) + +logg = logging.getLogger() Model = declarative_base(name='Model') @@ -21,7 +31,11 @@ class SessionBase(Model): transactional = True """Whether the database backend supports query transactions. Should be explicitly set by initialization code""" poolable = True - """Whether the database backend supports query transactions. Should be explicitly set by initialization code""" + """Whether the database backend supports connection pools. Should be explicitly set by initialization code""" + procedural = True + """Whether the database backend supports stored procedures""" + localsessions = {} + """Contains dictionary of sessions initiated by db model components""" @staticmethod @@ -40,7 +54,7 @@ class SessionBase(Model): @staticmethod - def connect(dsn, debug=False): + def connect(dsn, pool_size=8, debug=False): """Create new database connection engine and connect to database backend. :param dsn: DSN string defining connection. @@ -48,14 +62,28 @@ class SessionBase(Model): """ e = None if SessionBase.poolable: - e = create_engine( - dsn, - max_overflow=50, - pool_pre_ping=True, - pool_size=20, - pool_recycle=10, - echo=debug, - ) + poolclass = QueuePool + if pool_size > 1: + e = create_engine( + dsn, + max_overflow=pool_size*3, + pool_pre_ping=True, + pool_size=pool_size, + pool_recycle=60, + poolclass=poolclass, + echo=debug, + ) + else: + if debug: + poolclass = AssertionPool + else: + poolclass = StaticPool + + e = create_engine( + dsn, + poolclass=poolclass, + echo=debug, + ) else: e = create_engine( dsn, @@ -71,3 +99,24 @@ class SessionBase(Model): """ SessionBase.engine.dispose() SessionBase.engine = None + + + @staticmethod + def bind_session(session=None): + localsession = session + if localsession == None: + localsession = SessionBase.create_session() + localsession_key = str(id(localsession)) + logg.debug('creating new session {}'.format(localsession_key)) + SessionBase.localsessions[localsession_key] = localsession + return localsession + + + @staticmethod + def release_session(session=None): + session.flush() + session_key = str(id(session)) + if SessionBase.localsessions.get(session_key) != None: + logg.debug('destroying session {}'.format(session_key)) + session.commit() + session.close() diff --git a/chainsyncer/db/models/filter.py b/chainsyncer/db/models/filter.py index 89bd564..656ed62 100644 --- a/chainsyncer/db/models/filter.py +++ b/chainsyncer/db/models/filter.py @@ -1,40 +1,79 @@ # standard imports +import logging import hashlib -# third-party imports -from sqlalchemy import Column, String, Integer, BLOB +# external imports +from sqlalchemy import Column, String, Integer, BLOB, ForeignKey from sqlalchemy.ext.hybrid import hybrid_property, hybrid_method # local imports from .base import SessionBase +from .sync import BlockchainSync - -zero_digest = '{:<064s'.format('0') +zero_digest = bytearray(32) +logg = logging.getLogger(__name__) class BlockchainSyncFilter(SessionBase): __tablename__ = 'chain_sync_filter' - chain_sync_id = Column(Integer, ForeignKey='chain_sync.id') + chain_sync_id = Column(Integer, ForeignKey('chain_sync.id')) + flags_start = Column(BLOB) flags = Column(BLOB) - digest = Column(String) + digest = Column(BLOB) count = Column(Integer) - @staticmethod - def set(self, names): - - def __init__(self, names, chain_sync, digest=None): - if len(names) == 0: - digest = zero_digest - elif digest == None: - h = hashlib.new('sha256') - for n in names: - h.update(n.encode('utf-8') + b'\x00') - z = h.digest() - digest = z.hex() + def __init__(self, chain_sync, count=0, flags=None, digest=zero_digest): self.digest = digest - self.count = len(names) - self.flags = bytearray((len(names) -1 ) / 8 + 1) + self.count = count + + if flags == None: + flags = bytearray(0) + self.flags_start = flags + self.flags = flags + self.chain_sync_id = chain_sync.id + + + def add(self, name): + h = hashlib.new('sha256') + h.update(self.digest) + h.update(name.encode('utf-8')) + z = h.digest() + + old_byte_count = int((self.count - 1) / 8 + 1) + new_byte_count = int((self.count) / 8 + 1) + + logg.debug('old new {} {}'.format(old_byte_count, new_byte_count)) + if old_byte_count != new_byte_count: + self.flags = bytearray(1) + self.flags + self.count += 1 + self.digest = z + + + def start(self): + return self.flags_start + + + def cursor(self): + return self.flags_current + + + def clear(self): + self.flags = 0 + + + def target(self): + n = 0 + for i in range(self.count): + n |= 2 << i + return n + + + def set(self, n): + if self.flags & n > 0: + SessionBase.release_session(session) + raise AttributeError('Filter bit already set') + r.flags |= n diff --git a/chainsyncer/db/models/sync.py b/chainsyncer/db/models/sync.py index 6dab033..4f2f156 100644 --- a/chainsyncer/db/models/sync.py +++ b/chainsyncer/db/models/sync.py @@ -41,19 +41,23 @@ class BlockchainSync(SessionBase): :type chain: str :param session: Session to use. If not specified, a separate session will be created for this method only. :type session: SqlAlchemy Session - :returns: True if sync record found - :rtype: bool + :returns: Database primary key id of sync record + :rtype: number|None """ - local_session = False - if session == None: - session = SessionBase.create_session() - local_session = True + session = SessionBase.bind_session(session) + q = session.query(BlockchainSync.id) q = q.filter(BlockchainSync.blockchain==chain) o = q.first() - if local_session: - session.close() - return o == None + + if o == None: + return None + + sync_id = o.id + + SessionBase.release_session(session) + + return sync_id @staticmethod @@ -165,4 +169,4 @@ class BlockchainSync(SessionBase): self.tx_cursor = tx_start self.block_target = block_target self.date_created = datetime.datetime.utcnow() - self.date_modified = datetime.datetime.utcnow() + self.date_updated = datetime.datetime.utcnow() diff --git a/chainsyncer/filter.py b/chainsyncer/filter.py index e488b8c..263bbe5 100644 --- a/chainsyncer/filter.py +++ b/chainsyncer/filter.py @@ -9,6 +9,7 @@ from .error import BackendError logg = logging.getLogger(__name__) + class SyncFilter: def __init__(self, backend, safe=True): @@ -32,11 +33,15 @@ class SyncFilter: except sqlalchemy.exc.TimeoutError as e: self.backend.disconnect() raise BackendError('database connection fail: {}'.format(e)) + i = 0 for f in self.filters: + i += 1 logg.debug('applying filter {}'.format(str(f))) f.filter(conn, block, tx, self.backend.db_session) + self.backend.set_filter() self.backend.disconnect() + class NoopFilter: def filter(self, conn, block, tx, db_session=None): diff --git a/sql/sqlite/1.sql b/sql/sqlite/1.sql index fa9f6fa..0f115d4 100644 --- a/sql/sqlite/1.sql +++ b/sql/sqlite/1.sql @@ -1,13 +1,11 @@ CREATE TABLE IF NOT EXISTS chain_sync ( - id serial primary key not null, + id integer primary key autoincrement, blockchain varchar not null, - block_start int not null default 0, - tx_start int not null default 0, - block_cursor int not null default 0, - tx_cursor int not null default 0, - flags bytea not null, - num_flags int not null, - block_target int default null, + block_start integer not null default 0, + tx_start integer not null default 0, + block_cursor integer not null default 0, + tx_cursor integer not null default 0, + block_target integer default null, date_created timestamp not null, date_updated timestamp default null ); diff --git a/sql/sqlite/2.sql b/sql/sqlite/2.sql index c43624e..9e42802 100644 --- a/sql/sqlite/2.sql +++ b/sql/sqlite/2.sql @@ -1,8 +1,9 @@ CREATE TABLE IF NOT EXISTS chain_sync_filter ( - id serial primary key not null, - chain_sync_id int not null, + id integer primary key autoincrement not null, + chain_sync_id integer not null, flags bytea default null, - count int not null default 0, + flags_start bytea default null, + count integer not null default 0, digest char(64) not null default '0000000000000000000000000000000000000000000000000000000000000000', CONSTRAINT fk_chain_sync FOREIGN KEY(chain_sync_id) diff --git a/tests/base.py b/tests/base.py index 6729eb9..2d0a99c 100644 --- a/tests/base.py +++ b/tests/base.py @@ -1,13 +1,21 @@ +# standard imports +import logging import unittest import tempfile import os #import pysqlite +# external imports +from chainlib.chain import ChainSpec + +# local imports from chainsyncer.db import dsn_from_config from chainsyncer.db.models.base import SessionBase script_dir = os.path.realpath(os.path.dirname(__file__)) +logging.basicConfig(level=logging.DEBUG) + class TestBase(unittest.TestCase): @@ -23,7 +31,7 @@ class TestBase(unittest.TestCase): SessionBase.poolable = False SessionBase.transactional = False SessionBase.procedural = False - SessionBase.connect(dsn, debug=True) + SessionBase.connect(dsn, debug=False) f = open(os.path.join(script_dir, '..', 'sql', 'sqlite', '1.sql'), 'r') sql = f.read() @@ -39,6 +47,8 @@ class TestBase(unittest.TestCase): conn = SessionBase.engine.connect() conn.execute(sql) + self.chain_spec = ChainSpec('evm', 'foo', 42, 'bar') + def tearDown(self): SessionBase.disconnect() os.unlink(self.db_path) diff --git a/tests/test_database.py b/tests/test_database.py new file mode 100644 index 0000000..22f919b --- /dev/null +++ b/tests/test_database.py @@ -0,0 +1,55 @@ +# standard imports +import unittest + +# external imports +from chainlib.chain import ChainSpec + +# local imports +from chainsyncer.db.models.base import SessionBase +from chainsyncer.db.models.filter import BlockchainSyncFilter +from chainsyncer.backend import SyncerBackend + +# testutil imports +from tests.base import TestBase + +class TestDatabase(TestBase): + + + def test_backend_live(self): + s = SyncerBackend.live(self.chain_spec, 42) + self.assertEqual(s.object_id, 1) + backend = SyncerBackend.first(self.chain_spec) + #SyncerBackend(self.chain_spec, sync_id) + self.assertEqual(backend.object_id, 1) + + bogus_chain_spec = ChainSpec('bogus', 'foo', 13, 'baz') + sync_id = SyncerBackend.first(bogus_chain_spec) + self.assertIsNone(sync_id) + + + def test_backend_filter(self): + s = SyncerBackend.live(self.chain_spec, 42) + + s.connect() + filter_id = s.db_object_filter.id + s.disconnect() + + session = SessionBase.create_session() + o = session.query(BlockchainSyncFilter).get(filter_id) + self.assertEqual(len(o.flags), 0) + session.close() + + for i in range(9): + s.register_filter(str(i)) + + s.connect() + filter_id = s.db_object_filter.id + s.disconnect() + + session = SessionBase.create_session() + o = session.query(BlockchainSyncFilter).get(filter_id) + self.assertEqual(len(o.flags), 2) + session.close() + +if __name__ == '__main__': + unittest.main()