diff --git a/CHANGELOG b/CHANGELOG index 7fa5ba1..6f82d60 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -5,6 +5,7 @@ * Dynamic bits * Optional binary contents * Sync all if no state passed as argument + * Mask method for client-side state manipulation - 0.1.0 * Release version bump - 0.0.19: diff --git a/shep/state.py b/shep/state.py index c4bbaf2..0ff02df 100644 --- a/shep/state.py +++ b/shep/state.py @@ -92,10 +92,11 @@ class State: # enforces state value within bit limit of instantiation - def __check_limit(self, v): - if self.__initial_bits == 0: - self.__bits += 1 - self.__limit = (1 << self.__bits) - 1 + def __check_limit(self, v, pure=True): + if pure: + if self.__initial_bits == 0: + self.__bits += 1 + self.__limit = (1 << self.__bits) - 1 if v > self.__limit: raise OverflowError(v) return v @@ -197,7 +198,7 @@ class State: v = 0 for a in args: a = self.__check_value_cursor(a) - v = self.__check_limit(v | a) + v = self.__check_limit(v | a, pure=False) if self.__is_pure(v): raise ValueError('use add to add pure values') self.__set(k, v) @@ -593,3 +594,11 @@ class State: def register_modify(self, key): self.modified_last[key] = datetime.datetime.now().timestamp() + + + def mask(self, key, states): + statemask = self.__limit + 1 + statemask |= states + statemask = ~statemask + statemask &= self.__limit + return statemask diff --git a/tests/test_state.py b/tests/test_state.py index cf24c37..eefe6ff 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -216,5 +216,26 @@ class TestState(unittest.TestCase): states.alias('baz', states.FOO | states.BAR) + + def test_mask(self): + states = State(3) + states.add('foo') + states.add('bar') + states.add('baz') + states.alias('all', states.FOO | states.BAR | states.BAZ) + mask = states.mask('xyzzy', states.FOO | states.BAZ) + self.assertEqual(mask, states.BAR) + + + def test_mask_dynamic(self): + states = State(0) + states.add('foo') + states.add('bar') + states.add('baz') + states.alias('all', states.FOO | states.BAR | states.BAZ) + mask = states.mask('xyzzy', states.FOO | states.BAZ) + self.assertEqual(mask, states.BAR) + + if __name__ == '__main__': unittest.main()