From 3ffb3b08aa99b34bdbe78c74196ea655c31fb9e0 Mon Sep 17 00:00:00 2001 From: lash Date: Mon, 31 Jan 2022 12:10:04 +0000 Subject: [PATCH] Add generic persistence wrapper --- shep/persist.py | 48 ++++++++++++++++++++++++++ shep/state.py | 16 ++++++--- tests/test_item.py | 1 - tests/test_store.py | 84 +++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 144 insertions(+), 5 deletions(-) create mode 100644 shep/persist.py create mode 100644 tests/test_store.py diff --git a/shep/persist.py b/shep/persist.py new file mode 100644 index 0000000..fd5818b --- /dev/null +++ b/shep/persist.py @@ -0,0 +1,48 @@ +# local imports +from .state import State + + +class PersistedState(State): + + def __init__(self, factory, bits, logger=None): + super(PersistedState, self).__init__(bits, logger=logger) + self.__store_factory = factory + self.__stores = {} + + + def __ensure_store(self, k): + if self.__stores.get(k) == None: + self.__stores[k] = self.__store_factory(k) + + + def put(self, item, state=None): + k = self.name(state) + self.__ensure_store(k) + self.__stores[k].add(item) + + super(PersistedState, self).put(item, state=state) + + + def move(self, item, to_state): + k_to = self.name(to_state) + + from_state = self.state(item) + k_from = self.name(from_state) + + self.__ensure_store(k_to) + self.__ensure_store(k_from) + + self.__stores[k_to].add(item) + self.__stores[k_from].remove(item) + + super(PersistedState, self).move(item, to_state) + + + def purge(self, item): + state = self.state(item) + k = self.name(state) + + self.__ensure_store(k) + + self.__stores[k].remove(item) + super(PersistedState, self).purge(item) diff --git a/shep/state.py b/shep/state.py index a07aa62..82df2a9 100644 --- a/shep/state.py +++ b/shep/state.py @@ -9,13 +9,13 @@ from shep.error import ( class State: - def __init__(self, bits, logger=None, store_factory=None): + def __init__(self, bits, logger=None): self.__bits = bits self.__limit = (1 << bits) - 1 self.__c = 0 - self.__reverse = {} - self.NEW = 0 + + self.__reverse = {0: self.NEW} self.__items = {self.NEW: []} self.__items_reverse = {} @@ -124,6 +124,15 @@ class State: return l + def name(self, v): + if v == None: + return self.NEW + k = self.__reverse.get(v) + if k == None: + raise StateInvalid(v) + return k + + def match(self, v, pure=False): alias = None if not pure: @@ -175,7 +184,6 @@ class State: current_state_list.pop(idx) - def purge(self, item): current_state = self.__items_reverse.get(item) if current_state == None: diff --git a/tests/test_item.py b/tests/test_item.py index 418976b..52f6345 100644 --- a/tests/test_item.py +++ b/tests/test_item.py @@ -74,6 +74,5 @@ class TestStateItems(unittest.TestCase): self.states.state(item) - if __name__ == '__main__': unittest.main() diff --git a/tests/test_store.py b/tests/test_store.py new file mode 100644 index 0000000..a14aacd --- /dev/null +++ b/tests/test_store.py @@ -0,0 +1,84 @@ +# standard imports +import unittest + +# local imports +from shep.persist import PersistedState +from shep.error import ( + StateExists, + StateItemExists, + StateInvalid, + StateItemNotFound, + ) + +class MockStore: + + def __init__(self): + self.v = {} + self.for_state = 0 + + + def add(self, k): + self.v[k] = True + + + def remove(self, k): + del self.v[k] + + +class TestStateItems(unittest.TestCase): + + def setUp(self): + self.mockstore = MockStore() + + def mockstore_factory(v): + self.mockstore.for_state = v + return self.mockstore + + self.states = PersistedState(mockstore_factory, 4) + self.states.add('foo') + self.states.add('bar') + self.states.add('baz') + self.states.alias('xyzzy', self.states.BAZ | self.states.BAR) + self.states.alias('plugh', self.states.FOO | self.states.BAR) + + + def test_persist_new(self): + item = b'foo' + self.states.put(item) + self.assertTrue(self.mockstore.v.get(item)) + + + def test_persist_move(self): + item = b'foo' + self.states.put(item, self.states.FOO) + self.states.move(item, self.states.XYZZY) + self.assertEqual(self.mockstore.for_state, self.states.name(self.states.XYZZY)) + + + def test_persist_move(self): + item = b'foo' + self.states.put(item, self.states.FOO) + self.states.move(item, self.states.XYZZY) + self.assertEqual(self.mockstore.for_state, self.states.name(self.states.XYZZY)) + # TODO: cant check the add because remove happens after remove, need better mock + self.assertIsNone(self.mockstore.v.get(item)) + + + def test_persist_purge(self): + item = b'foo' + self.states.put(item, self.states.FOO) + self.states.purge(item) + self.assertEqual(self.mockstore.for_state, self.states.name(self.states.FOO)) + self.assertIsNone(self.mockstore.v.get(item)) + + + def test_persist_move_new(self): + item = b'foo' + self.states.put(item) + self.states.move(item, self.states.BAZ) + self.assertEqual(self.mockstore.for_state, self.states.name(self.states.BAZ)) + self.assertIsNone(self.mockstore.v.get(item)) + + +if __name__ == '__main__': + unittest.main()