Merge branch 'lash/veirfy'

This commit is contained in:
lash 2022-04-09 16:11:25 +00:00
commit 9becb47751
11 changed files with 537 additions and 63 deletions

View File

@ -1,3 +1,13 @@
- 0.1.1
* Optional, pluggable verifier to protect state transition
* Change method for atomic simultaneous set and unset
* Optionally allow undefined composite states
* Dynamic bits
* Optional binary contents
* Sync all if no state passed as argument
* Mask method for client-side state manipulation
- 0.1.0
* Release version bump
- 0.0.19: - 0.0.19:
* Enable alias with comma separated values * Enable alias with comma separated values
- 0.0.18 - 0.0.18

64
example/kanban.py Normal file
View File

@ -0,0 +1,64 @@
from shep.state import State
# we don't like "NEW" as the default label for a new item in the queue, so we change it to BACKLOG
State.set_default_state('backlog')
# define all the valid states
st = State(5)
st.add('pending')
st.add('blocked')
st.add('doing')
st.add('review')
st.add('finished')
# define a couple of states that give a bit more context to progress; something is blocked before starting development or something is blocked during development...
st.alias('startblock', st.BLOCKED, st.PENDING)
st.alias('doingblock', st.BLOCKED, st.DOING)
# create the foo key which will forever languish in backlog
k = 'foo'
st.put(k)
foo_state = st.state(k)
foo_state_name = st.name(foo_state)
foo_contents_r = st.get('foo')
print('{} {} {}'.format(k, foo_state_name, foo_contents_r))
# Create bar->baz and advance it from backlog to pending
k = 'bar'
bar_contents = 'baz'
st.put(k, contents=bar_contents)
st.next(k)
bar_state = st.state(k)
bar_state_name = st.name(bar_state)
bar_contents_r = st.get('bar')
print('{} {} {}'.format(k, bar_state_name, bar_contents_r))
# Create inky->pinky and move to doing then doing-blocked
k = 'inky'
inky_contents = 'pinky'
st.put(k, contents=inky_contents)
inky_state = st.state(k)
st.move(k, st.DOING)
st.set(k, st.BLOCKED)
inky_state = st.state(k)
inky_state_name = st.name(inky_state)
inky_contents_r = st.get('inky')
print('{} {} {}'.format(k, inky_state_name, bar_contents_r))
# then replace the content
# note that replace could potentially mean some VCS below
inky_new_contents = 'blinky'
st.replace(k, inky_new_contents)
inky_contents_r = st.get('inky')
print('{} {} {}'.format(k, inky_state_name, inky_contents_r))
# so now move to review
st.move(k, st.REVIEW)
inky_state = st.state(k)
inky_state_name = st.name(inky_state)
print('{} {} {}'.format(k, inky_state_name, inky_contents_r))

View File

@ -1,6 +1,6 @@
[metadata] [metadata]
name = shep name = shep
version = 0.1.0 version = 0.1.1rc1
description = Multi-state key stores using bit masks description = Multi-state key stores using bit masks
author = Louis Holbrook author = Louis Holbrook
author_email = dev@holbrook.no author_email = dev@holbrook.no

View File

@ -26,3 +26,9 @@ class StateCorruptionError(RuntimeError):
"""An irrecoverable discrepancy between persisted state and memory state has occurred. """An irrecoverable discrepancy between persisted state and memory state has occurred.
""" """
pass pass
class StateTransitionInvalid(Exception):
"""Raised if state transition verification fails
"""
pass

View File

@ -1,3 +1,6 @@
# standard imports
import datetime
# local imports # local imports
from .state import State from .state import State
from .error import StateItemExists from .error import StateItemExists
@ -14,8 +17,8 @@ class PersistedState(State):
:type logger: object :type logger: object
""" """
def __init__(self, factory, bits, logger=None): def __init__(self, factory, bits, logger=None, verifier=None, check_alias=True, event_callback=None):
super(PersistedState, self).__init__(bits, logger=logger) super(PersistedState, self).__init__(bits, logger=logger, verifier=verifier, check_alias=check_alias, event_callback=event_callback)
self.__store_factory = factory self.__store_factory = factory
self.__stores = {} self.__stores = {}
@ -55,6 +58,8 @@ class PersistedState(State):
self.__stores[k_to].add(key, contents) self.__stores[k_to].add(key, contents)
self.__stores[k_from].remove(key) self.__stores[k_from].remove(key)
self.sync(to_state)
return to_state return to_state
@ -78,6 +83,28 @@ class PersistedState(State):
return to_state return to_state
def change(self, key, bits_set, bits_unset):
"""Persist a new state for a key or key/content.
See shep.state.State.unset
"""
from_state = self.state(key)
k_from = self.name(from_state)
to_state = super(PersistedState, self).change(key, bits_set, bits_unset)
k_to = self.name(to_state)
self.__ensure_store(k_to)
contents = self.__stores[k_from].get(key)
self.__stores[k_to].add(key, contents)
self.__stores[k_from].remove(key)
self.register_modify(key)
return to_state
def move(self, key, to_state): def move(self, key, to_state):
"""Persist a new state for a key or key/content. """Persist a new state for a key or key/content.
@ -99,10 +126,14 @@ class PersistedState(State):
self.__stores[k_to].add(key, contents) self.__stores[k_to].add(key, contents)
self.__stores[k_from].remove(key) self.__stores[k_from].remove(key)
self.register_modify(key)
self.sync(to_state)
return to_state return to_state
def sync(self, state): def sync(self, state=None):
"""Reload resources for a single state in memory from the persisted state store. """Reload resources for a single state in memory from the persisted state store.
:param state: State to load :param state: State to load
@ -110,15 +141,23 @@ class PersistedState(State):
:raises StateItemExists: A content key is already recorded with a different state in memory than in persisted store. :raises StateItemExists: A content key is already recorded with a different state in memory than in persisted store.
# :todo: if sync state is none, sync all # :todo: if sync state is none, sync all
""" """
k = self.name(state) states = []
if state == None:
states = list(self.all())
else:
states = [self.name(state)]
ks = []
for k in states:
ks.append(k)
for k in ks:
self.__ensure_store(k) self.__ensure_store(k)
for o in self.__stores[k].list(): for o in self.__stores[k].list():
self.__ensure_store(k) state = self.from_name(k)
try: try:
super(PersistedState, self).put(o[0], state=state, contents=o[1]) super(PersistedState, self).put(o[0], state=state, contents=o[1])
except StateItemExists: except StateItemExists as e:
pass pass
@ -131,7 +170,6 @@ class PersistedState(State):
""" """
k = self.name(state) k = self.name(state)
self.__ensure_store(k) self.__ensure_store(k)
#return self.__stores[k].list(state)
return super(PersistedState, self).list(state) return super(PersistedState, self).list(state)
@ -172,3 +210,9 @@ class PersistedState(State):
state = self.state(key) state = self.state(key)
k = self.name(state) k = self.name(state)
return self.__stores[k].replace(key, contents) return self.__stores[k].replace(key, contents)
def modified(self, key):
state = self.state(key)
k = self.name(state)
return self.__stores[k].modified(key)

View File

@ -1,12 +1,20 @@
# standard imports
import re
import datetime
# local imports # local imports
from shep.error import ( from shep.error import (
StateExists, StateExists,
StateInvalid, StateInvalid,
StateItemExists, StateItemExists,
StateItemNotFound, StateItemNotFound,
StateTransitionInvalid,
StateCorruptionError,
) )
re_name = r'^[a-zA-Z_\.]+$'
class State: class State:
"""State is an in-memory bitmasked state store for key-value pairs, or even just keys alone. """State is an in-memory bitmasked state store for key-value pairs, or even just keys alone.
@ -19,16 +27,29 @@ class State:
:param logger: Standard library logging instance to output to :param logger: Standard library logging instance to output to
:type logger: logging.Logger :type logger: logging.Logger
""" """
def __init__(self, bits, logger=None):
base_state_name = 'NEW'
def __init__(self, bits, logger=None, verifier=None, check_alias=True, event_callback=None):
self.__initial_bits = bits
self.__bits = bits self.__bits = bits
self.__limit = (1 << bits) - 1 self.__limit = (1 << bits) - 1
self.__c = 0 self.__c = 0
self.NEW = 0 setattr(self, self.base_state_name, 0)
self.__reverse = {0: self.NEW} self.__reverse = {0: getattr(self, self.base_state_name)}
self.__keys = {self.NEW: []} self.__keys = {getattr(self, self.base_state_name): []}
self.__keys_reverse = {} self.__keys_reverse = {}
self.__contents = {} self.__contents = {}
self.modified_last = {}
self.verifier = verifier
self.check_alias = check_alias
self.event_callback = event_callback
@classmethod
def set_default_state(cls, state_name):
cls.base_state_name = state_name.upper()
# return true if v is a single-bit state # return true if v is a single-bit state
@ -45,8 +66,8 @@ class State:
# validates a state name and return its canonical representation # validates a state name and return its canonical representation
def __check_name_valid(self, k): def __check_name_valid(self, k):
if not k.isalpha(): if not re.match(re_name, k):
raise ValueError('only alpha') raise ValueError('only alpha and underscore')
return k.upper() return k.upper()
@ -71,7 +92,11 @@ class State:
# enforces state value within bit limit of instantiation # enforces state value within bit limit of instantiation
def __check_limit(self, v): def __check_limit(self, v, pure=True):
if pure:
if self.__initial_bits == 0:
self.__bits += 1
self.__limit = (1 << self.__bits) - 1
if v > self.__limit: if v > self.__limit:
raise OverflowError(v) raise OverflowError(v)
return v return v
@ -114,8 +139,20 @@ class State:
def __add_state_list(self, state, item): def __add_state_list(self, state, item):
if self.__keys.get(state) == None: if self.__keys.get(state) == None:
self.__keys[state] = [] self.__keys[state] = []
if not self.__is_pure(state) or state == 0:
self.__keys[state].append(item) self.__keys[state].append(item)
c = 1
for i in range(self.__bits):
part = c & state
if part > 0:
if self.__keys.get(part) == None:
self.__keys[part] = []
self.__keys[part].append(item)
c <<= 1
self.__keys_reverse[item] = state self.__keys_reverse[item] = state
if self.__reverse.get(state) == None and not self.check_alias:
s = self.elements(state)
self.__alias(s, state)
def __state_list_index(self, item, state_list): def __state_list_index(self, item, state_list):
@ -148,6 +185,16 @@ class State:
self.__set(k, v) self.__set(k, v)
def __alias(self, k, *args):
v = 0
for a in args:
a = self.__check_value_cursor(a)
v = self.__check_limit(v | a, pure=False)
if self.__is_pure(v):
raise ValueError('use add to add pure values')
return self.__set(k, v)
def alias(self, k, *args): def alias(self, k, *args):
"""Add an alias for a combination of states in the store. """Add an alias for a combination of states in the store.
@ -161,16 +208,10 @@ class State:
:raises ValueError: Attempt to use bit value as alias :raises ValueError: Attempt to use bit value as alias
""" """
k = self.__check_name(k) k = self.__check_name(k)
v = 0 return self.__alias(k, *args)
for a in args:
a = self.__check_value_cursor(a)
v = self.__check_limit(v | a)
if self.__is_pure(v):
raise ValueError('use add to add pure values')
self.__set(k, v)
def all(self): def all(self, pure=False):
"""Return list of all unique atomic and alias states. """Return list of all unique atomic and alias states.
:rtype: list of ints :rtype: list of ints
@ -182,11 +223,36 @@ class State:
continue continue
if k.upper() != k: if k.upper() != k:
continue continue
if pure:
state = self.from_name(k)
if not self.__is_pure(state):
continue
l.append(k) l.append(k)
l.sort() l.sort()
return l return l
def elements(self, v):
r = []
if v == None or v == 0:
return self.base_state_name
c = 1
for i in range(self.__bits):
if v & c > 0:
r.append(self.name(c))
c <<= 1
return '_' + '.'.join(r)
def from_elements(self, k):
r = 0
if k[0] != '_':
raise ValueError('elements string must start with underscore (_), got {}'.format(k))
for v in k[1:].split('.'):
r |= self.from_name(v)
return r
def name(self, v): def name(self, v):
"""Retrieve that string representation of the state attribute represented by the given state integer value. """Retrieve that string representation of the state attribute represented by the given state integer value.
@ -196,11 +262,14 @@ class State:
:rtype: str :rtype: str
:return: State name :return: State name
""" """
if v == None or v == 0:
return 'NEW'
k = self.__reverse.get(v) k = self.__reverse.get(v)
if k == None: if k == None:
if self.check_alias:
raise StateInvalid(v) raise StateInvalid(v)
else:
k = self.elements(v)
elif v == None or v == 0:
return self.base_state_name
return k return k
@ -252,13 +321,13 @@ class State:
def put(self, key, state=None, contents=None): def put(self, key, state=None, contents=None):
"""Add a key to an existing state. """Add a key to an existing state.
If no state it specified, the default state attribute "NEW" will be used. If no state it specified, the default state attribute State.base_state_name will be used.
Contents may be supplied as value to pair with the given key. Contents may be changed later by calling the `replace` method. Contents may be supplied as value to pair with the given key. Contents may be changed later by calling the `replace` method.
:param key: Content key to add :param key: Content key to add
:type key: str :type key: str
:param state: Initial state for the put. If not given, initial state will be NEW :param state: Initial state for the put. If not given, initial state will be State.base_state_name
:type state: int :type state: int
:param contents: Contents to associate with key. A valie of None should be recognized as an undefined value as opposed to a zero-length value throughout any backend :param contents: Contents to associate with key. A valie of None should be recognized as an undefined value as opposed to a zero-length value throughout any backend
:type contents: str :type contents: str
@ -268,14 +337,21 @@ class State:
:return: Resulting state that key is put under (should match the input state) :return: Resulting state that key is put under (should match the input state)
""" """
if state == None: if state == None:
state = self.NEW state = getattr(self, self.base_state_name)
elif self.__reverse.get(state) == None: elif self.__reverse.get(state) == None and self.check_alias:
raise StateInvalid(state) raise StateInvalid(state)
self.__check_key(key) self.__check_key(key)
if self.event_callback != None:
old_state = self.__keys_reverse.get(key)
self.event_callback(key, None, self.name(state))
self.__add_state_list(state, key) self.__add_state_list(state, key)
if contents != None: if contents != None:
self.__contents[key] = contents self.__contents[key] = contents
self.register_modify(key)
return state return state
@ -296,7 +372,7 @@ class State:
raise StateItemNotFound(key) raise StateItemNotFound(key)
new_state = self.__reverse.get(to_state) new_state = self.__reverse.get(to_state)
if new_state == None: if new_state == None and self.check_alias:
raise StateInvalid(to_state) raise StateInvalid(to_state)
return self.__move(key, current_state, to_state) return self.__move(key, current_state, to_state)
@ -314,9 +390,21 @@ class State:
if current_state_list == None: if current_state_list == None:
raise StateCorruptionError(to_state) raise StateCorruptionError(to_state)
self.__add_state_list(to_state, key) if self.verifier != None:
r = self.verifier(self, from_state, to_state)
if r != None:
raise StateTransitionInvalid(r)
current_state_list.pop(idx) current_state_list.pop(idx)
if self.event_callback != None:
old_state = self.__keys_reverse.get(key)
self.event_callback(key, self.name(old_state), self.name(to_state))
self.__add_state_list(to_state, key)
self.register_modify(key)
return to_state return to_state
@ -342,7 +430,7 @@ class State:
to_state = current_state | or_state to_state = current_state | or_state
new_state = self.__reverse.get(to_state) new_state = self.__reverse.get(to_state)
if new_state == None: if new_state == None and self.check_alias:
raise StateInvalid('resulting to state is unknown: {}'.format(to_state)) raise StateInvalid('resulting to state is unknown: {}'.format(to_state))
return self.__move(key, current_state, to_state) return self.__move(key, current_state, to_state)
@ -351,13 +439,13 @@ class State:
def unset(self, key, not_state): def unset(self, key, not_state):
"""Unset a single bit, moving to a pure or alias state. """Unset a single bit, moving to a pure or alias state.
The resulting state cannot be NEW (0). The resulting state cannot be State.base_state_name (0).
:param key: Content key to modify state for :param key: Content key to modify state for
:type key: str :type key: str
:param or_state: Atomic stat to add :param or_state: Atomic stat to add
:type or_state: int :type or_state: int
:raises ValueError: State is not a single bit state, or attempts to revert to NEW :raises ValueError: State is not a single bit state, or attempts to revert to State.base_state_name
:raises StateItemNotFound: Content key is not registered :raises StateItemNotFound: Content key is not registered
:raises StateInvalid: Resulting state after addition of atomic state is unknown :raises StateInvalid: Resulting state after addition of atomic state is unknown
:rtype: int :rtype: int
@ -374,8 +462,30 @@ class State:
if to_state == current_state: if to_state == current_state:
raise ValueError('invalid change for state {}: {}'.format(key, not_state)) raise ValueError('invalid change for state {}: {}'.format(key, not_state))
if to_state == self.NEW: if to_state == getattr(self, self.base_state_name):
raise ValueError('State {} for {} cannot be reverted to NEW'.format(current_state, key)) raise ValueError('State {} for {} cannot be reverted to {}'.format(current_state, key, self.base_state_name))
new_state = self.__reverse.get(to_state)
if new_state == None:
raise StateInvalid('resulting to state is unknown: {}'.format(to_state))
return self.__move(key, current_state, to_state)
def change(self, key, sets, unsets):
current_state = self.__keys_reverse.get(key)
if current_state == None:
raise StateItemNotFound(key)
to_state = current_state | sets
to_state &= ~unsets & self.__limit
if sets == 0:
to_state = current_state & (~unsets)
if to_state == current_state:
raise ValueError('invalid change by unsets for state {}: {}'.format(key, unsets))
if to_state == getattr(self, self.base_state_name):
raise ValueError('State {} for {} cannot be reverted to {}'.format(current_state, key, self.base_state_name))
new_state = self.__reverse.get(to_state) new_state = self.__reverse.get(to_state)
if new_state == None: if new_state == None:
@ -424,7 +534,7 @@ class State:
return [] return []
def sync(self, state): def sync(self, state=None):
"""Noop method for interface implementation providing sync to backend. """Noop method for interface implementation providing sync to backend.
:param state: State to sync. :param state: State to sync.
@ -464,7 +574,7 @@ class State:
state = 1 state = 1
else: else:
state <<= 1 state <<= 1
if state > self.__c: if state > self.__limit:
raise StateInvalid('unknown state {}'.format(state)) raise StateInvalid('unknown state {}'.format(state))
return state return state
@ -496,3 +606,19 @@ class State:
""" """
self.state(key) self.state(key)
self.__contents[key] = contents self.__contents[key] = contents
def modified(self, key):
return self.modified_last[key]
def register_modify(self, key):
self.modified_last[key] = datetime.datetime.now().timestamp()
def mask(self, key, states=0):
statemask = self.__limit + 1
statemask |= states
statemask = ~statemask
statemask &= self.__limit
return statemask

View File

@ -8,9 +8,13 @@ class SimpleFileStore:
:param path: Filesystem base path for all state directory :param path: Filesystem base path for all state directory
:type path: str :type path: str
""" """
def __init__(self, path): def __init__(self, path, binary=False):
self.__path = path self.__path = path
os.makedirs(self.__path, exist_ok=True) os.makedirs(self.__path, exist_ok=True)
if binary:
self.__m = ['rb', 'wb']
else:
self.__m = ['r', 'w']
def add(self, k, contents=None): def add(self, k, contents=None):
@ -23,9 +27,12 @@ class SimpleFileStore:
""" """
fp = os.path.join(self.__path, k) fp = os.path.join(self.__path, k)
if contents == None: if contents == None:
if self.__m[1] == 'wb':
contents = b''
else:
contents = '' contents = ''
f = open(fp, 'w') f = open(fp, self.__m[1])
f.write(contents) f.write(contents)
f.close() f.close()
@ -51,7 +58,7 @@ class SimpleFileStore:
:return: Contents :return: Contents
""" """
fp = os.path.join(self.__path, k) fp = os.path.join(self.__path, k)
f = open(fp, 'r') f = open(fp, self.__m[0])
r = f.read() r = f.read()
f.close() f.close()
return r return r
@ -66,7 +73,7 @@ class SimpleFileStore:
files = [] files = []
for p in os.listdir(self.__path): for p in os.listdir(self.__path):
fp = os.path.join(self.__path, p) fp = os.path.join(self.__path, p)
f = open(fp, 'r') f = open(fp, self.__m[0])
r = f.read() r = f.read()
f.close() f.close()
if len(r) == 0: if len(r) == 0:
@ -98,19 +105,30 @@ class SimpleFileStore:
""" """
fp = os.path.join(self.__path, k) fp = os.path.join(self.__path, k)
os.stat(fp) os.stat(fp)
f = open(fp, 'w') f = open(fp, self.__m[1])
r = f.write(contents) r = f.write(contents)
f.close() f.close()
def modified(self, k):
path = self.path(k)
st = os.stat(path)
return st.st_ctime
def register_modify(self, k):
pass
class SimpleFileStoreFactory: class SimpleFileStoreFactory:
"""Provide a method to instantiate SimpleFileStore instances that provide persistence for individual states. """Provide a method to instantiate SimpleFileStore instances that provide persistence for individual states.
:param path: Filesystem path as base path for states :param path: Filesystem path as base path for states
:type path: str :type path: str
""" """
def __init__(self, path): def __init__(self, path, binary=False):
self.__path = path self.__path = path
self.__binary = binary
def add(self, k): def add(self, k):
@ -123,4 +141,4 @@ class SimpleFileStoreFactory:
""" """
k = str(k) k = str(k)
store_path = os.path.join(self.__path, k) store_path = os.path.join(self.__path, k)
return SimpleFileStore(store_path) return SimpleFileStore(store_path, binary=self.__binary)

2
shep/verify.py Normal file
View File

@ -0,0 +1,2 @@
def default_checker(statestore, old, new):
return None

View File

@ -74,6 +74,40 @@ class TestStateReport(unittest.TestCase):
os.stat(fp) os.stat(fp)
def test_change(self):
self.states.alias('inky', self.states.FOO | self.states.BAR)
self.states.put('abcd', state=self.states.FOO, contents='foo')
self.states.change('abcd', self.states.BAR, 0)
fp = os.path.join(self.d, 'INKY', 'abcd')
f = open(fp, 'r')
v = f.read()
f.close()
fp = os.path.join(self.d, 'FOO', 'abcd')
with self.assertRaises(FileNotFoundError):
os.stat(fp)
fp = os.path.join(self.d, 'BAR', 'abcd')
with self.assertRaises(FileNotFoundError):
os.stat(fp)
self.states.change('abcd', 0, self.states.BAR)
fp = os.path.join(self.d, 'FOO', 'abcd')
f = open(fp, 'r')
v = f.read()
f.close()
fp = os.path.join(self.d, 'INKY', 'abcd')
with self.assertRaises(FileNotFoundError):
os.stat(fp)
fp = os.path.join(self.d, 'BAR', 'abcd')
with self.assertRaises(FileNotFoundError):
os.stat(fp)
def test_set(self): def test_set(self):
self.states.alias('xyzzy', self.states.FOO | self.states.BAR) self.states.alias('xyzzy', self.states.FOO | self.states.BAR)
self.states.put('abcd', state=self.states.FOO, contents='foo') self.states.put('abcd', state=self.states.FOO, contents='foo')
@ -108,7 +142,7 @@ class TestStateReport(unittest.TestCase):
os.stat(fp) os.stat(fp)
def test_sync(self): def test_sync_one(self):
self.states.put('abcd', state=self.states.FOO, contents='foo') self.states.put('abcd', state=self.states.FOO, contents='foo')
self.states.put('xxx', state=self.states.FOO) self.states.put('xxx', state=self.states.FOO)
self.states.put('yyy', state=self.states.FOO) self.states.put('yyy', state=self.states.FOO)
@ -128,6 +162,25 @@ class TestStateReport(unittest.TestCase):
self.assertEqual(self.states.get('zzzz'), 'xyzzy') self.assertEqual(self.states.get('zzzz'), 'xyzzy')
def test_sync_all(self):
self.states.put('abcd', state=self.states.FOO)
self.states.put('xxx', state=self.states.BAR)
fp = os.path.join(self.d, 'FOO', 'abcd')
f = open(fp, 'w')
f.write('foofoo')
f.close()
fp = os.path.join(self.d, 'BAR', 'zzzz')
f = open(fp, 'w')
f.write('barbar')
f.close()
self.states.sync()
self.assertEqual(self.states.get('abcd'), None)
self.assertEqual(self.states.get('zzzz'), 'barbar')
def test_path(self): def test_path(self):
self.states.put('yyy', state=self.states.FOO) self.states.put('yyy', state=self.states.FOO)
@ -147,6 +200,9 @@ class TestStateReport(unittest.TestCase):
self.states.next('abcd') self.states.next('abcd')
self.assertEqual(self.states.state('abcd'), self.states.BAR) self.assertEqual(self.states.state('abcd'), self.states.BAR)
self.states.next('abcd')
self.assertEqual(self.states.state('abcd'), self.states.BAZ)
with self.assertRaises(StateInvalid): with self.assertRaises(StateInvalid):
self.states.next('abcd') self.states.next('abcd')
@ -154,7 +210,7 @@ class TestStateReport(unittest.TestCase):
with self.assertRaises(FileNotFoundError): with self.assertRaises(FileNotFoundError):
os.stat(fp) os.stat(fp)
fp = os.path.join(self.d, 'BAR', 'abcd') fp = os.path.join(self.d, 'BAZ', 'abcd')
os.stat(fp) os.stat(fp)

View File

@ -1,5 +1,6 @@
# standard imports # standard imports
import unittest import unittest
import logging
# local imports # local imports
from shep import State from shep import State
@ -8,6 +9,24 @@ from shep.error import (
StateInvalid, StateInvalid,
) )
logging.basicConfig(level=logging.DEBUG)
logg = logging.getLogger()
class MockCallback:
def __init__(self):
self.items = {}
self.items_from = {}
def add(self, k, v_from, v_to):
if self.items.get(k) == None:
self.items[k] = []
self.items_from[k] = []
self.items[k].append(v_to)
self.items_from[k].append(v_from)
class TestState(unittest.TestCase): class TestState(unittest.TestCase):
@ -18,7 +37,6 @@ class TestState(unittest.TestCase):
for k in [ for k in [
'f0o', 'f0o',
'f oo', 'f oo',
'f_oo',
]: ]:
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
states.add(k) states.add(k)
@ -33,11 +51,12 @@ class TestState(unittest.TestCase):
def test_limit(self): def test_limit(self):
states = State(2) states = State(3)
states.add('foo') states.add('foo')
states.add('bar') states.add('bar')
with self.assertRaises(OverflowError):
states.add('baz') states.add('baz')
with self.assertRaises(OverflowError):
states.add('gaz')
def test_dup(self): def test_dup(self):
@ -84,10 +103,33 @@ class TestState(unittest.TestCase):
states.alias('baz', 5) states.alias('baz', 5)
def test_peek(self): def test_alias_invalid(self):
states = State(3) states = State(3)
states.add('foo') states.add('foo')
states.add('bar') states.add('bar')
states.put('abcd')
states.set('abcd', states.FOO)
with self.assertRaises(StateInvalid):
states.set('abcd', states.BAR)
def test_alias_invalid_ignore(self):
states = State(3, check_alias=False)
states.add('foo')
states.add('bar')
states.add('baz')
states.put('abcd')
states.set('abcd', states.FOO)
states.set('abcd', states.BAZ)
v = states.state('abcd')
s = states.name(v)
self.assertEqual(s, '_FOO.BAZ')
def test_peek(self):
states = State(2)
states.add('foo')
states.add('bar')
states.put('abcd') states.put('abcd')
self.assertEqual(states.peek('abcd'), states.FOO) self.assertEqual(states.peek('abcd'), states.FOO)
@ -98,7 +140,7 @@ class TestState(unittest.TestCase):
states.move('abcd', states.BAR) states.move('abcd', states.BAR)
with self.assertRaises(StateInvalid): with self.assertRaises(StateInvalid):
self.assertEqual(states.peek('abcd')) states.peek('abcd')
def test_from_name(self): def test_from_name(self):
@ -107,5 +149,106 @@ class TestState(unittest.TestCase):
self.assertEqual(states.from_name('foo'), states.FOO) self.assertEqual(states.from_name('foo'), states.FOO)
def test_change(self):
states = State(3)
states.add('foo')
states.add('bar')
states.add('baz')
states.alias('inky', states.FOO | states.BAR)
states.alias('pinky', states.FOO | states.BAZ)
states.put('abcd')
states.next('abcd')
states.set('abcd', states.BAR)
states.change('abcd', states.BAZ, states.BAR)
self.assertEqual(states.state('abcd'), states.PINKY)
def test_change_onezero(self):
states = State(3)
states.add('foo')
states.add('bar')
states.add('baz')
states.alias('inky', states.FOO | states.BAR)
states.alias('pinky', states.FOO | states.BAZ)
states.put('abcd')
states.next('abcd')
states.change('abcd', states.BAR, 0)
self.assertEqual(states.state('abcd'), states.INKY)
states.change('abcd', 0, states.BAR)
self.assertEqual(states.state('abcd'), states.FOO)
def test_change_dates(self):
states = State(3)
states.add('foo')
states.put('abcd')
states.put('bcde')
a = states.modified('abcd')
b = states.modified('bcde')
self.assertGreater(b, a)
states.set('abcd', states.FOO)
a = states.modified('abcd')
b = states.modified('bcde')
self.assertGreater(a, b)
def test_event_callback(self):
cb = MockCallback()
states = State(3, event_callback=cb.add)
states.add('foo')
states.add('bar')
states.add('baz')
states.alias('xyzzy', states.FOO | states.BAR)
states.put('abcd')
states.set('abcd', states.FOO)
states.set('abcd', states.BAR)
states.change('abcd', states.BAZ, states.XYZZY)
events = cb.items['abcd']
self.assertEqual(len(events), 4)
self.assertEqual(states.from_name(events[0]), states.NEW)
self.assertEqual(states.from_name(events[1]), states.FOO)
self.assertEqual(states.from_name(events[2]), states.XYZZY)
self.assertEqual(states.from_name(events[3]), states.BAZ)
def test_dynamic(self):
states = State(0)
states.add('foo')
states.add('bar')
states.alias('baz', states.FOO | states.BAR)
def test_mask(self):
states = State(3)
states.add('foo')
states.add('bar')
states.add('baz')
states.alias('all', states.FOO | states.BAR | states.BAZ)
mask = states.mask('xyzzy', states.FOO | states.BAZ)
self.assertEqual(mask, states.BAR)
def test_mask_dynamic(self):
states = State(0)
states.add('foo')
states.add('bar')
states.add('baz')
states.alias('all', states.FOO | states.BAR | states.BAZ)
mask = states.mask('xyzzy', states.FOO | states.BAZ)
self.assertEqual(mask, states.BAR)
def test_mask_zero(self):
states = State(0)
states.add('foo')
states.add('bar')
states.add('baz')
states.alias('all', states.FOO | states.BAR | states.BAZ)
mask = states.mask('xyzzy')
self.assertEqual(mask, states.ALL)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -33,6 +33,10 @@ class MockStore:
return self.v[k] return self.v[k]
def list(self):
return list(self.v.keys())
class TestStateItems(unittest.TestCase): class TestStateItems(unittest.TestCase):
def setUp(self): def setUp(self):
@ -80,5 +84,6 @@ class TestStateItems(unittest.TestCase):
self.assertIsNone(self.mockstore.v.get(item)) self.assertIsNone(self.mockstore.v.get(item))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()