Add filter to live instantiator, add filter start state
This commit is contained in:
parent
82e2674555
commit
81dd5753e8
@ -7,6 +7,7 @@ from chainlib.chain import ChainSpec
|
||||
|
||||
# local imports
|
||||
from chainsyncer.db.models.sync import BlockchainSync
|
||||
from chainsyncer.db.models.filter import BlockchainSyncFilter
|
||||
from chainsyncer.db.models.base import SessionBase
|
||||
|
||||
logg = logging.getLogger()
|
||||
@ -23,6 +24,7 @@ class SyncerBackend:
|
||||
def __init__(self, chain_spec, object_id):
|
||||
self.db_session = None
|
||||
self.db_object = None
|
||||
self.db_object_filter = None
|
||||
self.chain_spec = chain_spec
|
||||
self.object_id = object_id
|
||||
self.connect()
|
||||
@ -34,9 +36,17 @@ class SyncerBackend:
|
||||
"""
|
||||
if self.db_session == None:
|
||||
self.db_session = SessionBase.create_session()
|
||||
|
||||
q = self.db_session.query(BlockchainSync)
|
||||
q = q.filter(BlockchainSync.id==self.object_id)
|
||||
self.db_object = q.first()
|
||||
|
||||
if self.db_object != None:
|
||||
qtwo = self.db_session.query(BlockchainSyncFilter)
|
||||
qtwo = qtwo.join(BlockchainSync)
|
||||
qtwo = qtwo.filter(BlockchainSync.id==self.db_object.id)
|
||||
self.db_object_filter = qtwo.first()
|
||||
|
||||
if self.db_object == None:
|
||||
raise ValueError('sync entry with id {} not found'.format(self.object_id))
|
||||
|
||||
@ -44,6 +54,8 @@ class SyncerBackend:
|
||||
def disconnect(self):
|
||||
"""Commits state of sync to backend.
|
||||
"""
|
||||
if self.db_object_filter != None:
|
||||
self.db_session.add(self.db_object_filter)
|
||||
self.db_session.add(self.db_object)
|
||||
self.db_session.commit()
|
||||
self.db_session.close()
|
||||
@ -67,8 +79,9 @@ class SyncerBackend:
|
||||
"""
|
||||
self.connect()
|
||||
pair = self.db_object.cursor()
|
||||
filter_state = self.db_object_filter.filter()
|
||||
self.disconnect()
|
||||
return pair
|
||||
return (pair, filter_state,)
|
||||
|
||||
|
||||
def set(self, block_height, tx_height):
|
||||
@ -82,8 +95,9 @@ class SyncerBackend:
|
||||
"""
|
||||
self.connect()
|
||||
pair = self.db_object.set(block_height, tx_height)
|
||||
filter_state = self.db_object_filter.filter()
|
||||
self.disconnect()
|
||||
return pair
|
||||
return (pair, filter_state,)
|
||||
|
||||
|
||||
def start(self):
|
||||
@ -94,8 +108,9 @@ class SyncerBackend:
|
||||
"""
|
||||
self.connect()
|
||||
pair = self.db_object.start()
|
||||
filter_state = self.db_object_filter.start()
|
||||
self.disconnect()
|
||||
return pair
|
||||
return (pair, filter_state,)
|
||||
|
||||
|
||||
def target(self):
|
||||
@ -106,12 +121,13 @@ class SyncerBackend:
|
||||
"""
|
||||
self.connect()
|
||||
target = self.db_object.target()
|
||||
filter_state = self.db_object_filter.target()
|
||||
self.disconnect()
|
||||
return target
|
||||
return (target, filter_target,)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def first(chain):
|
||||
def first(chain_spec):
|
||||
"""Returns the model object of the most recent syncer in backend.
|
||||
|
||||
:param chain: Chain spec of chain that syncer is running for.
|
||||
@ -119,7 +135,12 @@ class SyncerBackend:
|
||||
:returns: Last syncer object
|
||||
:rtype: cic_eth.db.models.BlockchainSync
|
||||
"""
|
||||
return BlockchainSync.first(chain)
|
||||
#return BlockchainSync.first(str(chain_spec))
|
||||
object_id = BlockchainSync.first(str(chain_spec))
|
||||
if object_id == None:
|
||||
return None
|
||||
return SyncerBackend(chain_spec, object_id)
|
||||
|
||||
|
||||
|
||||
@staticmethod
|
||||
@ -193,15 +214,30 @@ class SyncerBackend:
|
||||
"""
|
||||
object_id = None
|
||||
session = SessionBase.create_session()
|
||||
|
||||
o = BlockchainSync(str(chain_spec), block_height, 0, None)
|
||||
session.add(o)
|
||||
session.commit()
|
||||
session.flush()
|
||||
object_id = o.id
|
||||
|
||||
of = BlockchainSyncFilter(o)
|
||||
session.add(of)
|
||||
session.commit()
|
||||
|
||||
session.close()
|
||||
|
||||
return SyncerBackend(chain_spec, object_id)
|
||||
|
||||
|
||||
def register_filter(self, name):
|
||||
self.connect()
|
||||
if self.db_object_filter == None:
|
||||
self.db_object_filter = BlockchainSyncFilter(self.db_object)
|
||||
self.db_object_filter.add(name)
|
||||
self.db_session.add(self.db_object_filter)
|
||||
self.disconnect()
|
||||
|
||||
|
||||
class MemBackend:
|
||||
|
||||
def __init__(self, chain_spec, object_id):
|
||||
@ -209,6 +245,7 @@ class MemBackend:
|
||||
self.chain_spec = chain_spec
|
||||
self.block_height = 0
|
||||
self.tx_height = 0
|
||||
self.flags = 0
|
||||
self.db_session = None
|
||||
|
||||
|
||||
|
@ -1,8 +1,18 @@
|
||||
# stanard imports
|
||||
import logging
|
||||
|
||||
# third-party imports
|
||||
from sqlalchemy import Column, Integer
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import (
|
||||
StaticPool,
|
||||
QueuePool,
|
||||
AssertionPool,
|
||||
)
|
||||
|
||||
logg = logging.getLogger()
|
||||
|
||||
Model = declarative_base(name='Model')
|
||||
|
||||
@ -21,7 +31,11 @@ class SessionBase(Model):
|
||||
transactional = True
|
||||
"""Whether the database backend supports query transactions. Should be explicitly set by initialization code"""
|
||||
poolable = True
|
||||
"""Whether the database backend supports query transactions. Should be explicitly set by initialization code"""
|
||||
"""Whether the database backend supports connection pools. Should be explicitly set by initialization code"""
|
||||
procedural = True
|
||||
"""Whether the database backend supports stored procedures"""
|
||||
localsessions = {}
|
||||
"""Contains dictionary of sessions initiated by db model components"""
|
||||
|
||||
|
||||
@staticmethod
|
||||
@ -40,7 +54,7 @@ class SessionBase(Model):
|
||||
|
||||
|
||||
@staticmethod
|
||||
def connect(dsn, debug=False):
|
||||
def connect(dsn, pool_size=8, debug=False):
|
||||
"""Create new database connection engine and connect to database backend.
|
||||
|
||||
:param dsn: DSN string defining connection.
|
||||
@ -48,12 +62,26 @@ class SessionBase(Model):
|
||||
"""
|
||||
e = None
|
||||
if SessionBase.poolable:
|
||||
poolclass = QueuePool
|
||||
if pool_size > 1:
|
||||
e = create_engine(
|
||||
dsn,
|
||||
max_overflow=50,
|
||||
max_overflow=pool_size*3,
|
||||
pool_pre_ping=True,
|
||||
pool_size=20,
|
||||
pool_recycle=10,
|
||||
pool_size=pool_size,
|
||||
pool_recycle=60,
|
||||
poolclass=poolclass,
|
||||
echo=debug,
|
||||
)
|
||||
else:
|
||||
if debug:
|
||||
poolclass = AssertionPool
|
||||
else:
|
||||
poolclass = StaticPool
|
||||
|
||||
e = create_engine(
|
||||
dsn,
|
||||
poolclass=poolclass,
|
||||
echo=debug,
|
||||
)
|
||||
else:
|
||||
@ -71,3 +99,24 @@ class SessionBase(Model):
|
||||
"""
|
||||
SessionBase.engine.dispose()
|
||||
SessionBase.engine = None
|
||||
|
||||
|
||||
@staticmethod
|
||||
def bind_session(session=None):
|
||||
localsession = session
|
||||
if localsession == None:
|
||||
localsession = SessionBase.create_session()
|
||||
localsession_key = str(id(localsession))
|
||||
logg.debug('creating new session {}'.format(localsession_key))
|
||||
SessionBase.localsessions[localsession_key] = localsession
|
||||
return localsession
|
||||
|
||||
|
||||
@staticmethod
|
||||
def release_session(session=None):
|
||||
session.flush()
|
||||
session_key = str(id(session))
|
||||
if SessionBase.localsessions.get(session_key) != None:
|
||||
logg.debug('destroying session {}'.format(session_key))
|
||||
session.commit()
|
||||
session.close()
|
||||
|
@ -1,40 +1,79 @@
|
||||
# standard imports
|
||||
import logging
|
||||
import hashlib
|
||||
|
||||
# third-party imports
|
||||
from sqlalchemy import Column, String, Integer, BLOB
|
||||
# external imports
|
||||
from sqlalchemy import Column, String, Integer, BLOB, ForeignKey
|
||||
from sqlalchemy.ext.hybrid import hybrid_property, hybrid_method
|
||||
|
||||
# local imports
|
||||
from .base import SessionBase
|
||||
from .sync import BlockchainSync
|
||||
|
||||
|
||||
zero_digest = '{:<064s'.format('0')
|
||||
zero_digest = bytearray(32)
|
||||
logg = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BlockchainSyncFilter(SessionBase):
|
||||
|
||||
__tablename__ = 'chain_sync_filter'
|
||||
|
||||
chain_sync_id = Column(Integer, ForeignKey='chain_sync.id')
|
||||
chain_sync_id = Column(Integer, ForeignKey('chain_sync.id'))
|
||||
flags_start = Column(BLOB)
|
||||
flags = Column(BLOB)
|
||||
digest = Column(String)
|
||||
digest = Column(BLOB)
|
||||
count = Column(Integer)
|
||||
|
||||
@staticmethod
|
||||
def set(self, names):
|
||||
|
||||
|
||||
def __init__(self, names, chain_sync, digest=None):
|
||||
if len(names) == 0:
|
||||
digest = zero_digest
|
||||
elif digest == None:
|
||||
h = hashlib.new('sha256')
|
||||
for n in names:
|
||||
h.update(n.encode('utf-8') + b'\x00')
|
||||
z = h.digest()
|
||||
digest = z.hex()
|
||||
def __init__(self, chain_sync, count=0, flags=None, digest=zero_digest):
|
||||
self.digest = digest
|
||||
self.count = len(names)
|
||||
self.flags = bytearray((len(names) -1 ) / 8 + 1)
|
||||
self.count = count
|
||||
|
||||
if flags == None:
|
||||
flags = bytearray(0)
|
||||
self.flags_start = flags
|
||||
self.flags = flags
|
||||
|
||||
self.chain_sync_id = chain_sync.id
|
||||
|
||||
|
||||
def add(self, name):
|
||||
h = hashlib.new('sha256')
|
||||
h.update(self.digest)
|
||||
h.update(name.encode('utf-8'))
|
||||
z = h.digest()
|
||||
|
||||
old_byte_count = int((self.count - 1) / 8 + 1)
|
||||
new_byte_count = int((self.count) / 8 + 1)
|
||||
|
||||
logg.debug('old new {} {}'.format(old_byte_count, new_byte_count))
|
||||
if old_byte_count != new_byte_count:
|
||||
self.flags = bytearray(1) + self.flags
|
||||
self.count += 1
|
||||
self.digest = z
|
||||
|
||||
|
||||
def start(self):
|
||||
return self.flags_start
|
||||
|
||||
|
||||
def cursor(self):
|
||||
return self.flags_current
|
||||
|
||||
|
||||
def clear(self):
|
||||
self.flags = 0
|
||||
|
||||
|
||||
def target(self):
|
||||
n = 0
|
||||
for i in range(self.count):
|
||||
n |= 2 << i
|
||||
return n
|
||||
|
||||
|
||||
def set(self, n):
|
||||
if self.flags & n > 0:
|
||||
SessionBase.release_session(session)
|
||||
raise AttributeError('Filter bit already set')
|
||||
r.flags |= n
|
||||
|
@ -41,19 +41,23 @@ class BlockchainSync(SessionBase):
|
||||
:type chain: str
|
||||
:param session: Session to use. If not specified, a separate session will be created for this method only.
|
||||
:type session: SqlAlchemy Session
|
||||
:returns: True if sync record found
|
||||
:rtype: bool
|
||||
:returns: Database primary key id of sync record
|
||||
:rtype: number|None
|
||||
"""
|
||||
local_session = False
|
||||
if session == None:
|
||||
session = SessionBase.create_session()
|
||||
local_session = True
|
||||
session = SessionBase.bind_session(session)
|
||||
|
||||
q = session.query(BlockchainSync.id)
|
||||
q = q.filter(BlockchainSync.blockchain==chain)
|
||||
o = q.first()
|
||||
if local_session:
|
||||
session.close()
|
||||
return o == None
|
||||
|
||||
if o == None:
|
||||
return None
|
||||
|
||||
sync_id = o.id
|
||||
|
||||
SessionBase.release_session(session)
|
||||
|
||||
return sync_id
|
||||
|
||||
|
||||
@staticmethod
|
||||
@ -165,4 +169,4 @@ class BlockchainSync(SessionBase):
|
||||
self.tx_cursor = tx_start
|
||||
self.block_target = block_target
|
||||
self.date_created = datetime.datetime.utcnow()
|
||||
self.date_modified = datetime.datetime.utcnow()
|
||||
self.date_updated = datetime.datetime.utcnow()
|
||||
|
@ -9,6 +9,7 @@ from .error import BackendError
|
||||
|
||||
logg = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SyncFilter:
|
||||
|
||||
def __init__(self, backend, safe=True):
|
||||
@ -32,11 +33,15 @@ class SyncFilter:
|
||||
except sqlalchemy.exc.TimeoutError as e:
|
||||
self.backend.disconnect()
|
||||
raise BackendError('database connection fail: {}'.format(e))
|
||||
i = 0
|
||||
for f in self.filters:
|
||||
i += 1
|
||||
logg.debug('applying filter {}'.format(str(f)))
|
||||
f.filter(conn, block, tx, self.backend.db_session)
|
||||
self.backend.set_filter()
|
||||
self.backend.disconnect()
|
||||
|
||||
|
||||
class NoopFilter:
|
||||
|
||||
def filter(self, conn, block, tx, db_session=None):
|
||||
|
@ -1,13 +1,11 @@
|
||||
CREATE TABLE IF NOT EXISTS chain_sync (
|
||||
id serial primary key not null,
|
||||
id integer primary key autoincrement,
|
||||
blockchain varchar not null,
|
||||
block_start int not null default 0,
|
||||
tx_start int not null default 0,
|
||||
block_cursor int not null default 0,
|
||||
tx_cursor int not null default 0,
|
||||
flags bytea not null,
|
||||
num_flags int not null,
|
||||
block_target int default null,
|
||||
block_start integer not null default 0,
|
||||
tx_start integer not null default 0,
|
||||
block_cursor integer not null default 0,
|
||||
tx_cursor integer not null default 0,
|
||||
block_target integer default null,
|
||||
date_created timestamp not null,
|
||||
date_updated timestamp default null
|
||||
);
|
||||
|
@ -1,8 +1,9 @@
|
||||
CREATE TABLE IF NOT EXISTS chain_sync_filter (
|
||||
id serial primary key not null,
|
||||
chain_sync_id int not null,
|
||||
id integer primary key autoincrement not null,
|
||||
chain_sync_id integer not null,
|
||||
flags bytea default null,
|
||||
count int not null default 0,
|
||||
flags_start bytea default null,
|
||||
count integer not null default 0,
|
||||
digest char(64) not null default '0000000000000000000000000000000000000000000000000000000000000000',
|
||||
CONSTRAINT fk_chain_sync
|
||||
FOREIGN KEY(chain_sync_id)
|
||||
|
@ -1,13 +1,21 @@
|
||||
# standard imports
|
||||
import logging
|
||||
import unittest
|
||||
import tempfile
|
||||
import os
|
||||
#import pysqlite
|
||||
|
||||
# external imports
|
||||
from chainlib.chain import ChainSpec
|
||||
|
||||
# local imports
|
||||
from chainsyncer.db import dsn_from_config
|
||||
from chainsyncer.db.models.base import SessionBase
|
||||
|
||||
script_dir = os.path.realpath(os.path.dirname(__file__))
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
|
||||
class TestBase(unittest.TestCase):
|
||||
|
||||
@ -23,7 +31,7 @@ class TestBase(unittest.TestCase):
|
||||
SessionBase.poolable = False
|
||||
SessionBase.transactional = False
|
||||
SessionBase.procedural = False
|
||||
SessionBase.connect(dsn, debug=True)
|
||||
SessionBase.connect(dsn, debug=False)
|
||||
|
||||
f = open(os.path.join(script_dir, '..', 'sql', 'sqlite', '1.sql'), 'r')
|
||||
sql = f.read()
|
||||
@ -39,6 +47,8 @@ class TestBase(unittest.TestCase):
|
||||
conn = SessionBase.engine.connect()
|
||||
conn.execute(sql)
|
||||
|
||||
self.chain_spec = ChainSpec('evm', 'foo', 42, 'bar')
|
||||
|
||||
def tearDown(self):
|
||||
SessionBase.disconnect()
|
||||
os.unlink(self.db_path)
|
||||
|
55
tests/test_database.py
Normal file
55
tests/test_database.py
Normal file
@ -0,0 +1,55 @@
|
||||
# standard imports
|
||||
import unittest
|
||||
|
||||
# external imports
|
||||
from chainlib.chain import ChainSpec
|
||||
|
||||
# local imports
|
||||
from chainsyncer.db.models.base import SessionBase
|
||||
from chainsyncer.db.models.filter import BlockchainSyncFilter
|
||||
from chainsyncer.backend import SyncerBackend
|
||||
|
||||
# testutil imports
|
||||
from tests.base import TestBase
|
||||
|
||||
class TestDatabase(TestBase):
|
||||
|
||||
|
||||
def test_backend_live(self):
|
||||
s = SyncerBackend.live(self.chain_spec, 42)
|
||||
self.assertEqual(s.object_id, 1)
|
||||
backend = SyncerBackend.first(self.chain_spec)
|
||||
#SyncerBackend(self.chain_spec, sync_id)
|
||||
self.assertEqual(backend.object_id, 1)
|
||||
|
||||
bogus_chain_spec = ChainSpec('bogus', 'foo', 13, 'baz')
|
||||
sync_id = SyncerBackend.first(bogus_chain_spec)
|
||||
self.assertIsNone(sync_id)
|
||||
|
||||
|
||||
def test_backend_filter(self):
|
||||
s = SyncerBackend.live(self.chain_spec, 42)
|
||||
|
||||
s.connect()
|
||||
filter_id = s.db_object_filter.id
|
||||
s.disconnect()
|
||||
|
||||
session = SessionBase.create_session()
|
||||
o = session.query(BlockchainSyncFilter).get(filter_id)
|
||||
self.assertEqual(len(o.flags), 0)
|
||||
session.close()
|
||||
|
||||
for i in range(9):
|
||||
s.register_filter(str(i))
|
||||
|
||||
s.connect()
|
||||
filter_id = s.db_object_filter.id
|
||||
s.disconnect()
|
||||
|
||||
session = SessionBase.create_session()
|
||||
o = session.query(BlockchainSyncFilter).get(filter_id)
|
||||
self.assertEqual(len(o.flags), 2)
|
||||
session.close()
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in New Issue
Block a user