From 798262f00f3af1cd59320ae489274f1cf999d099 Mon Sep 17 00:00:00 2001 From: lash Date: Wed, 16 Mar 2022 17:13:05 +0000 Subject: [PATCH] State change event emitter --- shep/state.py | 9 ++++++++- tests/test_state.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/shep/state.py b/shep/state.py index 19dceeb..e4ae551 100644 --- a/shep/state.py +++ b/shep/state.py @@ -30,7 +30,7 @@ class State: base_state_name = 'NEW' - def __init__(self, bits, logger=None, verifier=None, check_alias=True): + def __init__(self, bits, logger=None, verifier=None, check_alias=True, event_callback=None): self.__bits = bits self.__limit = (1 << bits) - 1 self.__c = 0 @@ -43,6 +43,7 @@ class State: self.modified_last = {} self.verifier = verifier self.check_alias = check_alias + self.event_callback = event_callback @classmethod @@ -320,6 +321,9 @@ class State: self.__contents[key] = contents self.register_modify(key) + + if self.event_callback != None: + self.event_callback(key, state) return state @@ -369,6 +373,9 @@ class State: self.register_modify(key) + if self.event_callback != None: + self.event_callback(key, to_state) + return to_state diff --git a/tests/test_state.py b/tests/test_state.py index b2f4693..d89470f 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -13,6 +13,18 @@ logging.basicConfig(level=logging.DEBUG) logg = logging.getLogger() +class MockCallback: + + def __init__(self): + self.items = {} + + + def add(self, k, v): + if self.items.get(k) == None: + self.items[k] = [] + self.items[k].append(v) + + class TestState(unittest.TestCase): def test_key_check(self): @@ -177,5 +189,24 @@ class TestState(unittest.TestCase): self.assertGreater(a, b) + def test_event_callback(self): + cb = MockCallback() + states = State(3, event_callback=cb.add) + states.add('foo') + states.add('bar') + states.add('baz') + states.alias('xyzzy', states.FOO | states.BAR) + states.put('abcd') + states.set('abcd', states.FOO) + states.set('abcd', states.BAR) + states.change('abcd', states.BAZ, states.XYZZY) + events = cb.items['abcd'] + self.assertEqual(len(events), 4) + self.assertEqual(events[0], states.NEW) + self.assertEqual(events[1], states.FOO) + self.assertEqual(events[2], states.XYZZY) + self.assertEqual(events[3], states.BAZ) + + if __name__ == '__main__': unittest.main()