Add generic persistence wrapper

This commit is contained in:
lash 2022-01-31 12:10:04 +00:00
parent 0eaf032b89
commit 3ffb3b08aa
Signed by: lash
GPG Key ID: 21D2E7BB88C2A746
4 changed files with 144 additions and 5 deletions

48
shep/persist.py Normal file
View File

@ -0,0 +1,48 @@
# local imports
from .state import State
class PersistedState(State):
def __init__(self, factory, bits, logger=None):
super(PersistedState, self).__init__(bits, logger=logger)
self.__store_factory = factory
self.__stores = {}
def __ensure_store(self, k):
if self.__stores.get(k) == None:
self.__stores[k] = self.__store_factory(k)
def put(self, item, state=None):
k = self.name(state)
self.__ensure_store(k)
self.__stores[k].add(item)
super(PersistedState, self).put(item, state=state)
def move(self, item, to_state):
k_to = self.name(to_state)
from_state = self.state(item)
k_from = self.name(from_state)
self.__ensure_store(k_to)
self.__ensure_store(k_from)
self.__stores[k_to].add(item)
self.__stores[k_from].remove(item)
super(PersistedState, self).move(item, to_state)
def purge(self, item):
state = self.state(item)
k = self.name(state)
self.__ensure_store(k)
self.__stores[k].remove(item)
super(PersistedState, self).purge(item)

View File

@ -9,13 +9,13 @@ from shep.error import (
class State: class State:
def __init__(self, bits, logger=None, store_factory=None): def __init__(self, bits, logger=None):
self.__bits = bits self.__bits = bits
self.__limit = (1 << bits) - 1 self.__limit = (1 << bits) - 1
self.__c = 0 self.__c = 0
self.__reverse = {}
self.NEW = 0 self.NEW = 0
self.__reverse = {0: self.NEW}
self.__items = {self.NEW: []} self.__items = {self.NEW: []}
self.__items_reverse = {} self.__items_reverse = {}
@ -124,6 +124,15 @@ class State:
return l return l
def name(self, v):
if v == None:
return self.NEW
k = self.__reverse.get(v)
if k == None:
raise StateInvalid(v)
return k
def match(self, v, pure=False): def match(self, v, pure=False):
alias = None alias = None
if not pure: if not pure:
@ -175,7 +184,6 @@ class State:
current_state_list.pop(idx) current_state_list.pop(idx)
def purge(self, item): def purge(self, item):
current_state = self.__items_reverse.get(item) current_state = self.__items_reverse.get(item)
if current_state == None: if current_state == None:

View File

@ -74,6 +74,5 @@ class TestStateItems(unittest.TestCase):
self.states.state(item) self.states.state(item)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

84
tests/test_store.py Normal file
View File

@ -0,0 +1,84 @@
# standard imports
import unittest
# local imports
from shep.persist import PersistedState
from shep.error import (
StateExists,
StateItemExists,
StateInvalid,
StateItemNotFound,
)
class MockStore:
def __init__(self):
self.v = {}
self.for_state = 0
def add(self, k):
self.v[k] = True
def remove(self, k):
del self.v[k]
class TestStateItems(unittest.TestCase):
def setUp(self):
self.mockstore = MockStore()
def mockstore_factory(v):
self.mockstore.for_state = v
return self.mockstore
self.states = PersistedState(mockstore_factory, 4)
self.states.add('foo')
self.states.add('bar')
self.states.add('baz')
self.states.alias('xyzzy', self.states.BAZ | self.states.BAR)
self.states.alias('plugh', self.states.FOO | self.states.BAR)
def test_persist_new(self):
item = b'foo'
self.states.put(item)
self.assertTrue(self.mockstore.v.get(item))
def test_persist_move(self):
item = b'foo'
self.states.put(item, self.states.FOO)
self.states.move(item, self.states.XYZZY)
self.assertEqual(self.mockstore.for_state, self.states.name(self.states.XYZZY))
def test_persist_move(self):
item = b'foo'
self.states.put(item, self.states.FOO)
self.states.move(item, self.states.XYZZY)
self.assertEqual(self.mockstore.for_state, self.states.name(self.states.XYZZY))
# TODO: cant check the add because remove happens after remove, need better mock
self.assertIsNone(self.mockstore.v.get(item))
def test_persist_purge(self):
item = b'foo'
self.states.put(item, self.states.FOO)
self.states.purge(item)
self.assertEqual(self.mockstore.for_state, self.states.name(self.states.FOO))
self.assertIsNone(self.mockstore.v.get(item))
def test_persist_move_new(self):
item = b'foo'
self.states.put(item)
self.states.move(item, self.states.BAZ)
self.assertEqual(self.mockstore.for_state, self.states.name(self.states.BAZ))
self.assertIsNone(self.mockstore.v.get(item))
if __name__ == '__main__':
unittest.main()