Compare commits

..

No commits in common. "master" and "0.1.0-dev" have entirely different histories.

20 changed files with 88 additions and 1534 deletions

View File

@ -1,38 +1,3 @@
- 0.2.10
* Add count active states method
* Enable complete replace of NEW state on state instantiation
- 0.2.9
* Enable access to is_pure method
* Numeric option for elements return value
- 0.2.8
* Remove sync on persist set
- 0.2.7
* Handle missing files in fs readdir list
- 0.2.6
* Ensure atomic state lock in fs
- 0.2.5
* Correct handling of persistent sync when no not-state filter has been set
- 0.2.4
* Add optional concurrency lock for persistence store, implemented for file store
- 0.2.3
* Add noop-store, for convenience for code using persist constructor but will only use memory state
- 0.2.2
* Fix composite state factory load regex
- 0.2.1
* Add rocksdb backend
- 0.2.0
* Add redis backend
* UTC timestamp for modification time in core state
- 0.1.1
* Optional, pluggable verifier to protect state transition
* Change method for atomic simultaneous set and unset
* Optionally allow undefined composite states
* Dynamic bits
* Optional binary contents
* Sync all if no state passed as argument
* Mask method for client-side state manipulation
- 0.1.0
* Release version bump
- 0.0.19:
* Enable alias with comma separated values
- 0.0.18

View File

@ -1,64 +0,0 @@
from shep.state import State
# we don't like "NEW" as the default label for a new item in the queue, so we change it to BACKLOG
State.set_default_state('backlog')
# define all the valid states
st = State(5)
st.add('pending')
st.add('blocked')
st.add('doing')
st.add('review')
st.add('finished')
# define a couple of states that give a bit more context to progress; something is blocked before starting development or something is blocked during development...
st.alias('startblock', st.BLOCKED, st.PENDING)
st.alias('doingblock', st.BLOCKED, st.DOING)
# create the foo key which will forever languish in backlog
k = 'foo'
st.put(k)
foo_state = st.state(k)
foo_state_name = st.name(foo_state)
foo_contents_r = st.get('foo')
print('{} {} {}'.format(k, foo_state_name, foo_contents_r))
# Create bar->baz and advance it from backlog to pending
k = 'bar'
bar_contents = 'baz'
st.put(k, contents=bar_contents)
st.next(k)
bar_state = st.state(k)
bar_state_name = st.name(bar_state)
bar_contents_r = st.get('bar')
print('{} {} {}'.format(k, bar_state_name, bar_contents_r))
# Create inky->pinky and move to doing then doing-blocked
k = 'inky'
inky_contents = 'pinky'
st.put(k, contents=inky_contents)
inky_state = st.state(k)
st.move(k, st.DOING)
st.set(k, st.BLOCKED)
inky_state = st.state(k)
inky_state_name = st.name(inky_state)
inky_contents_r = st.get('inky')
print('{} {} {}'.format(k, inky_state_name, bar_contents_r))
# then replace the content
# note that replace could potentially mean some VCS below
inky_new_contents = 'blinky'
st.replace(k, inky_new_contents)
inky_contents_r = st.get('inky')
print('{} {} {}'.format(k, inky_state_name, inky_contents_r))
# so now move to review
st.move(k, st.REVIEW)
inky_state = st.state(k)
inky_state_name = st.name(inky_state)
print('{} {} {}'.format(k, inky_state_name, inky_contents_r))

View File

@ -1,6 +1,6 @@
[metadata]
name = shep
version = 0.2.10
version = 0.1.0rc1
description = Multi-state key stores using bit masks
author = Louis Holbrook
author_email = dev@holbrook.no
@ -22,7 +22,7 @@ licence_files =
[options]
include_package_data = True
python_requires = >= 3.7
python_requires = >= 3.6
packages =
shep
shep.store

View File

@ -1,8 +1,3 @@
from setuptools import setup
setup(
extras_require={
'redis': 'redis==3.5.3',
'rocksdb': 'lbry-rocksdb==0.8.2',
},
)
setup()

View File

@ -26,15 +26,3 @@ class StateCorruptionError(RuntimeError):
"""An irrecoverable discrepancy between persisted state and memory state has occurred.
"""
pass
class StateTransitionInvalid(Exception):
"""Raised if state transition verification fails
"""
pass
class StateLockedKey(Exception):
"""Attempt to write to a state key that is being written to by another client
"""
pass

View File

@ -1,12 +1,6 @@
# standard imports
import datetime
# local imports
from .state import State
from .error import (
StateItemExists,
StateLockedKey,
)
from .error import StateItemExists
class PersistedState(State):
@ -20,8 +14,8 @@ class PersistedState(State):
:type logger: object
"""
def __init__(self, factory, bits, logger=None, verifier=None, check_alias=True, event_callback=None, default_state=None):
super(PersistedState, self).__init__(bits, logger=logger, verifier=verifier, check_alias=check_alias, event_callback=event_callback, default_state=default_state)
def __init__(self, factory, bits, logger=None):
super(PersistedState, self).__init__(bits, logger=logger)
self.__store_factory = factory
self.__stores = {}
@ -37,15 +31,12 @@ class PersistedState(State):
See shep.state.State.put
"""
k = self.to_name(state)
to_state = super(PersistedState, self).put(key, state=state, contents=contents)
k = self.name(to_state)
self.__ensure_store(k)
self.__stores[k].put(key, contents)
super(PersistedState, self).put(key, state=state, contents=contents)
self.register_modify(key)
self.__stores[k].add(key, contents)
def set(self, key, or_state):
@ -60,16 +51,9 @@ class PersistedState(State):
k_to = self.name(to_state)
self.__ensure_store(k_to)
contents = None
try:
contents = self.__stores[k_from].get(key)
self.__stores[k_to].put(key, contents)
self.__stores[k_from].remove(key)
except StateLockedKey as e:
super(PersistedState, self).unset(key, or_state, allow_base=True)
raise e
#self.sync(to_state)
contents = self.__stores[k_from].get(key)
self.__stores[k_to].add(key, contents)
self.__stores[k_from].remove(key)
return to_state
@ -88,34 +72,12 @@ class PersistedState(State):
self.__ensure_store(k_to)
contents = self.__stores[k_from].get(key)
self.__stores[k_to].put(key, contents)
self.__stores[k_to].add(key, contents)
self.__stores[k_from].remove(key)
return to_state
def change(self, key, bits_set, bits_unset):
"""Persist a new state for a key or key/content.
See shep.state.State.unset
"""
from_state = self.state(key)
k_from = self.name(from_state)
to_state = super(PersistedState, self).change(key, bits_set, bits_unset)
k_to = self.name(to_state)
self.__ensure_store(k_to)
contents = self.__stores[k_from].get(key)
self.__stores[k_to].put(key, contents)
self.__stores[k_from].remove(key)
self.register_modify(key)
return to_state
def move(self, key, to_state):
"""Persist a new state for a key or key/content.
@ -134,17 +96,13 @@ class PersistedState(State):
self.__ensure_store(k_to)
contents = self.__stores[k_from].get(key)
self.__stores[k_to].put(key, contents)
self.__stores[k_to].add(key, contents)
self.__stores[k_from].remove(key)
self.register_modify(key)
self.sync(to_state)
return to_state
def sync(self, state=None, not_state=None):
def sync(self, state):
"""Reload resources for a single state in memory from the persisted state store.
:param state: State to load
@ -152,33 +110,16 @@ class PersistedState(State):
:raises StateItemExists: A content key is already recorded with a different state in memory than in persisted store.
# :todo: if sync state is none, sync all
"""
k = self.name(state)
states_numeric = []
if state == None:
states_numeric = list(self.all(numeric=True))
else:
states_numeric = [state]
states = []
for state in states_numeric:
if not_state != None:
if state & not_state == 0:
states.append(self.name(state))
else:
states.append(self.name(state))
self.__ensure_store(k)
ks = []
for k in states:
ks.append(k)
for k in ks:
for o in self.__stores[k].list():
self.__ensure_store(k)
for o in self.__stores[k].list():
state = self.from_name(k)
try:
super(PersistedState, self).put(o[0], state=state, contents=o[1])
except StateItemExists as e:
pass
try:
super(PersistedState, self).put(o[0], state=state, contents=o[1])
except StateItemExists:
pass
def list(self, state):
@ -190,6 +131,7 @@ class PersistedState(State):
"""
k = self.name(state)
self.__ensure_store(k)
#return self.__stores[k].list(state)
return super(PersistedState, self).list(state)
@ -226,14 +168,7 @@ class PersistedState(State):
See shep.state.State.replace
"""
state = self.state(key)
k = self.name(state)
r = self.__stores[k].replace(key, contents)
super(PersistedState, self).replace(key, contents)
return r
def modified(self, key):
state = self.state(key)
k = self.name(state)
return self.__stores[k].modified(key)
return self.__stores[k].replace(key, contents)

View File

@ -1,20 +1,12 @@
# standard imports
import re
import datetime
# local imports
from shep.error import (
StateExists,
StateInvalid,
StateItemExists,
StateItemNotFound,
StateTransitionInvalid,
StateCorruptionError,
)
re_name = r'^[a-zA-Z_\.]+$'
class State:
"""State is an in-memory bitmasked state store for key-value pairs, or even just keys alone.
@ -27,39 +19,20 @@ class State:
:param logger: Standard library logging instance to output to
:type logger: logging.Logger
"""
base_state_name = 'NEW'
def __init__(self, bits, logger=None, verifier=None, check_alias=True, event_callback=None, default_state=None):
self.__initial_bits = bits
def __init__(self, bits, logger=None):
self.__bits = bits
self.__limit = (1 << bits) - 1
self.__c = 0
self.NEW = 0
if default_state == None:
default_state = self.base_state_name
setattr(self, default_state, 0)
self.__reverse = {0: getattr(self, default_state)}
self.__keys = {getattr(self, default_state): []}
self.__reverse = {0: self.NEW}
self.__keys = {self.NEW: []}
self.__keys_reverse = {}
if default_state != self.base_state_name:
self.__keys_reverse[default_state] = 0
self.__contents = {}
self.modified_last = {}
self.verifier = verifier
self.check_alias = check_alias
self.event_callback = event_callback
@classmethod
def set_default_state(cls, state_name):
cls.base_state_name = state_name.upper()
# return true if v is a single-bit state
def is_pure(self, v):
def __is_pure(self, v):
if v == 0:
return True
c = 1
@ -72,8 +45,8 @@ class State:
# validates a state name and return its canonical representation
def __check_name_valid(self, k):
if not re.match(re_name, k):
raise ValueError('only alpha and underscore')
if not k.isalpha():
raise ValueError('only alpha')
return k.upper()
@ -98,11 +71,7 @@ class State:
# enforces state value within bit limit of instantiation
def __check_limit(self, v, pure=True):
if pure:
if self.__initial_bits == 0:
self.__bits += 1
self.__limit = (1 << self.__bits) - 1
def __check_limit(self, v):
if v > self.__limit:
raise OverflowError(v)
return v
@ -145,20 +114,8 @@ class State:
def __add_state_list(self, state, item):
if self.__keys.get(state) == None:
self.__keys[state] = []
if not self.is_pure(state) or state == 0:
self.__keys[state].append(item)
c = 1
for i in range(self.__bits):
part = c & state
if part > 0:
if self.__keys.get(part) == None:
self.__keys[part] = []
self.__keys[part].append(item)
c <<= 1
self.__keys[state].append(item)
self.__keys_reverse[item] = state
if self.__reverse.get(state) == None and not self.check_alias:
s = self.elements(state)
self.__alias(s, state)
def __state_list_index(self, item, state_list):
@ -189,23 +146,7 @@ class State:
k = self.__check_name(k)
v = self.__check_value(v)
self.__set(k, v)
def to_name(self, k):
if k == None:
k = 0
return self.name(k)
def __alias(self, k, *args):
v = 0
for a in args:
a = self.__check_value_cursor(a)
v = self.__check_limit(v | a, pure=False)
if self.is_pure(v):
raise ValueError('use add to add pure values')
return self.__set(k, v)
def alias(self, k, *args):
"""Add an alias for a combination of states in the store.
@ -220,64 +161,32 @@ class State:
:raises ValueError: Attempt to use bit value as alias
"""
k = self.__check_name(k)
return self.__alias(k, *args)
v = 0
for a in args:
a = self.__check_value_cursor(a)
v = self.__check_limit(v | a)
if self.__is_pure(v):
raise ValueError('use add to add pure values')
self.__set(k, v)
def all(self, pure=False, numeric=False):
"""Return list of all unique atomic and alias state strings.
def all(self):
"""Return list of all unique atomic and alias states.
:rtype: list of ints
:return: states
"""
l = []
for k in dir(self):
state = None
if k[0] == '_':
continue
if k.upper() != k:
continue
if pure:
state = self.from_name(k)
if not self.is_pure(state):
continue
if numeric:
if state == None:
state = self.from_name(k)
l.append(state)
else:
l.append(k)
l.append(k)
l.sort()
return l
def elements(self, v, numeric=False, as_string=True):
r = []
if v == None or v == 0:
return self.base_state_name
c = 1
for i in range(self.__bits):
if v & c > 0:
if numeric:
r.append(c)
else:
r.append(self.name(c))
c <<= 1
if numeric or not as_string:
return r
return '_' + '.'.join(r)
def from_elements(self, k):
r = 0
if k[0] != '_':
raise ValueError('elements string must start with underscore (_), got {}'.format(k))
for v in k[1:].split('.'):
r |= self.from_name(v)
return r
def name(self, v):
"""Retrieve that string representation of the state attribute represented by the given state integer value.
@ -287,14 +196,11 @@ class State:
:rtype: str
:return: State name
"""
if v == None or v == 0:
return 'NEW'
k = self.__reverse.get(v)
if k == None:
if self.check_alias:
raise StateInvalid(v)
else:
k = self.elements(v)
elif v == None or v == 0:
return self.base_state_name
raise StateInvalid(v)
return k
@ -346,13 +252,13 @@ class State:
def put(self, key, state=None, contents=None):
"""Add a key to an existing state.
If no state it specified, the default state attribute State.base_state_name will be used.
If no state it specified, the default state attribute "NEW" will be used.
Contents may be supplied as value to pair with the given key. Contents may be changed later by calling the `replace` method.
:param key: Content key to add
:type key: str
:param state: Initial state for the put. If not given, initial state will be State.base_state_name
:param state: Initial state for the put. If not given, initial state will be NEW
:type state: int
:param contents: Contents to associate with key. A valie of None should be recognized as an undefined value as opposed to a zero-length value throughout any backend
:type contents: str
@ -362,21 +268,14 @@ class State:
:return: Resulting state that key is put under (should match the input state)
"""
if state == None:
state = getattr(self, self.base_state_name)
elif self.__reverse.get(state) == None and self.check_alias:
state = self.NEW
elif self.__reverse.get(state) == None:
raise StateInvalid(state)
self.__check_key(key)
if self.event_callback != None:
old_state = self.__keys_reverse.get(key)
self.event_callback(key, None, self.name(state))
self.__add_state_list(state, key)
if contents != None:
self.__contents[key] = contents
self.register_modify(key)
return state
@ -397,7 +296,7 @@ class State:
raise StateItemNotFound(key)
new_state = self.__reverse.get(to_state)
if new_state == None and self.check_alias:
if new_state == None:
raise StateInvalid(to_state)
return self.__move(key, current_state, to_state)
@ -415,20 +314,8 @@ class State:
if current_state_list == None:
raise StateCorruptionError(to_state)
if self.verifier != None:
r = self.verifier(self, from_state, to_state)
if r != None:
raise StateTransitionInvalid(r)
current_state_list.pop(idx)
if self.event_callback != None:
old_state = self.__keys_reverse.get(key)
self.event_callback(key, self.name(old_state), self.name(to_state))
self.__add_state_list(to_state, key)
self.register_modify(key)
current_state_list.pop(idx)
return to_state
@ -446,7 +333,7 @@ class State:
:rtype: int
:returns: Resulting state
"""
if not self.is_pure(or_state):
if not self.__is_pure(or_state):
raise ValueError('can only apply using single bit states')
current_state = self.__keys_reverse.get(key)
@ -455,28 +342,28 @@ class State:
to_state = current_state | or_state
new_state = self.__reverse.get(to_state)
if new_state == None and self.check_alias:
if new_state == None:
raise StateInvalid('resulting to state is unknown: {}'.format(to_state))
return self.__move(key, current_state, to_state)
def unset(self, key, not_state, allow_base=False):
def unset(self, key, not_state):
"""Unset a single bit, moving to a pure or alias state.
The resulting state cannot be State.base_state_name (0).
The resulting state cannot be NEW (0).
:param key: Content key to modify state for
:type key: str
:param or_state: Atomic stat to add
:type or_state: int
:raises ValueError: State is not a single bit state, or attempts to revert to State.base_state_name
:raises ValueError: State is not a single bit state, or attempts to revert to NEW
:raises StateItemNotFound: Content key is not registered
:raises StateInvalid: Resulting state after addition of atomic state is unknown
:rtype: int
:returns: Resulting state
"""
if not self.is_pure(not_state):
if not self.__is_pure(not_state):
raise ValueError('can only apply using single bit states')
current_state = self.__keys_reverse.get(key)
@ -487,30 +374,8 @@ class State:
if to_state == current_state:
raise ValueError('invalid change for state {}: {}'.format(key, not_state))
if to_state == getattr(self, self.base_state_name) and not allow_base:
raise ValueError('State {} for {} cannot be reverted to {}'.format(current_state, key, self.base_state_name))
new_state = self.__reverse.get(to_state)
if new_state == None:
raise StateInvalid('resulting to state is unknown: {}'.format(to_state))
return self.__move(key, current_state, to_state)
def change(self, key, sets, unsets):
current_state = self.__keys_reverse.get(key)
if current_state == None:
raise StateItemNotFound(key)
to_state = current_state | sets
to_state &= ~unsets & self.__limit
if sets == 0:
to_state = current_state & (~unsets)
if to_state == current_state:
raise ValueError('invalid change by unsets for state {}: {}'.format(key, unsets))
if to_state == getattr(self, self.base_state_name):
raise ValueError('State {} for {} cannot be reverted to {}'.format(current_state, key, self.base_state_name))
if to_state == self.NEW:
raise ValueError('State {} for {} cannot be reverted to NEW'.format(current_state, key))
new_state = self.__reverse.get(to_state)
if new_state == None:
@ -559,7 +424,7 @@ class State:
return []
def sync(self, state=None):
def sync(self, state):
"""Noop method for interface implementation providing sync to backend.
:param state: State to sync.
@ -592,14 +457,14 @@ class State:
state = self.__keys_reverse.get(key)
if state == None:
raise StateItemNotFound(key)
if not self.is_pure(state):
if not self.__is_pure(state):
raise StateInvalid('cannot run next on an alias state')
if state == 0:
state = 1
else:
state <<= 1
if state > self.__limit:
if state > self.__c:
raise StateInvalid('unknown state {}'.format(state))
return state
@ -631,43 +496,3 @@ class State:
"""
self.state(key)
self.__contents[key] = contents
def modified(self, key):
return self.modified_last[key]
def register_modify(self, key):
self.modified_last[key] = datetime.datetime.utcnow().timestamp()
def mask(self, key, states=0):
statemask = self.__limit + 1
statemask |= states
statemask = ~statemask
statemask &= self.__limit
return statemask
def purge(self, key):
state = self.state(key)
state_name = self.name(state)
v = self.__keys.get(state)
v.remove(key)
del self.__keys_reverse[key]
try:
del self.__contents[key]
except KeyError:
pass
try:
del self.modified_last[key]
except KeyError:
pass
def count(self):
return self.__c

View File

@ -1,19 +0,0 @@
re_processedname = r'^_?[A-Z\._]*$'
class StoreFactory:
def __del__(self):
self.close()
def add(self, k):
raise NotImplementedError()
def close(self):
pass
def ls(self):
raise NotImplementedError()

View File

@ -1,13 +1,5 @@
# standard imports
import os
import re
# local imports
from .base import (
re_processedname,
StoreFactory,
)
from shep.error import StateLockedKey
class SimpleFileStore:
@ -16,43 +8,12 @@ class SimpleFileStore:
:param path: Filesystem base path for all state directory
:type path: str
"""
def __init__(self, path, binary=False, lock_path=None):
def __init__(self, path):
self.__path = path
os.makedirs(self.__path, exist_ok=True)
if binary:
self.__m = ['rb', 'wb']
else:
self.__m = ['r', 'w']
self.__lock_path = lock_path
if self.__lock_path != None:
os.makedirs(lock_path, exist_ok=True)
def __lock(self, k):
if self.__lock_path == None:
return
fp = os.path.join(self.__lock_path, k)
f = None
try:
f = open(fp, 'x')
except FileExistsError:
pass
if f == None:
raise StateLockedKey(k)
f.close()
def __unlock(self, k):
if self.__lock_path == None:
return
fp = os.path.join(self.__lock_path, k)
try:
os.unlink(fp)
except FileNotFoundError:
pass
def put(self, k, contents=None):
def add(self, k, contents=None):
"""Add a new key and optional contents
:param k: Content key to add
@ -60,18 +21,13 @@ class SimpleFileStore:
:param contents: Optional contents to assign for content key
:type contents: any
"""
self.__lock(k)
fp = os.path.join(self.__path, k)
if contents == None:
if self.__m[1] == 'wb':
contents = b''
else:
contents = ''
contents = ''
f = open(fp, self.__m[1])
f = open(fp, 'w')
f.write(contents)
f.close()
self.__unlock(k)
def remove(self, k):
@ -81,10 +37,8 @@ class SimpleFileStore:
:type k: str
:raises FileNotFoundError: Content key does not exist in the state
"""
self.__lock(k)
fp = os.path.join(self.__path, k)
os.unlink(fp)
self.__unlock(k)
def get(self, k):
@ -96,12 +50,10 @@ class SimpleFileStore:
:rtype: any
:return: Contents
"""
self.__lock(k)
fp = os.path.join(self.__path, k)
f = open(fp, self.__m[0])
f = open(fp, 'r')
r = f.read()
f.close()
self.__unlock(k)
return r
@ -111,21 +63,15 @@ class SimpleFileStore:
:rtype: list of str
:return: Content keys in state
"""
self.__lock('.list')
files = []
for p in os.listdir(self.__path):
fp = os.path.join(self.__path, p)
f = None
try:
f = open(fp, self.__m[0])
except FileNotFoundError:
continue
f = open(fp, 'r')
r = f.read()
f.close()
if len(r) == 0:
r = None
files.append((p, r,))
self.__unlock('.list')
return files
@ -150,37 +96,21 @@ class SimpleFileStore:
:param contents: Contents
:type contents: any
"""
self.__lock(k)
fp = os.path.join(self.__path, k)
os.stat(fp)
f = open(fp, self.__m[1])
f = open(fp, 'w')
r = f.write(contents)
f.close()
self.__unlock(k)
def modified(self, k):
self.__lock(k)
path = self.path(k)
st = os.stat(path)
self.__unlock(k)
return st.st_ctime
def register_modify(self, k):
pass
class SimpleFileStoreFactory(StoreFactory):
class SimpleFileStoreFactory:
"""Provide a method to instantiate SimpleFileStore instances that provide persistence for individual states.
:param path: Filesystem path as base path for states
:type path: str
"""
def __init__(self, path, binary=False, use_lock=False):
def __init__(self, path):
self.__path = path
self.__binary = binary
self.__use_lock = use_lock
def add(self, k):
@ -191,22 +121,6 @@ class SimpleFileStoreFactory(StoreFactory):
:rtype: SimpleFileStore
:return: A filesystem persistence instance with the given identifier as subdirectory
"""
lock_path = None
if self.__use_lock:
lock_path = os.path.join(self.__path, '.lock')
k = str(k)
store_path = os.path.join(self.__path, k)
return SimpleFileStore(store_path, binary=self.__binary, lock_path=lock_path)
def ls(self):
r = []
for v in os.listdir(self.__path):
if re.match(re_processedname, v):
r.append(v)
return r
def close(self):
pass
return SimpleFileStore(store_path)

View File

@ -1,44 +0,0 @@
# local imports
from .base import StoreFactory
class NoopStore:
def put(self, k, contents=None):
pass
def remove(self, k):
pass
def get(self, k):
pass
def list(self):
return []
def path(self):
return None
def replace(self, k, contents):
pass
def modified(self, k):
pass
def register_modify(self, k):
pass
class NoopStoreFactory(StoreFactory):
def add(self, k):
return NoopStore()
def ls(self):
return []

View File

@ -1,117 +0,0 @@
# standard imports
import datetime
# external imports
import redis
# local imports
from .base import StoreFactory
class RedisStore:
def __init__(self, path, redis, binary=False):
self.redis = redis
self.__path = path
self.__binary = binary
def __to_path(self, k):
return '.'.join([self.__path, k])
def __from_path(self, s):
(left, right) = s.split(b'.', maxsplit=1)
return right
def __to_result(self, v):
if self.__binary:
return v
return v.decode('utf-8')
def put(self, k, contents=b''):
if contents == None:
contents = b''
k = self.__to_path(k)
self.redis.set(k, contents)
def remove(self, k):
k = self.__to_path(k)
self.redis.delete(k)
def get(self, k):
k = self.__to_path(k)
v = self.redis.get(k)
return self.__to_result(v)
def list(self):
(cursor, matches) = self.redis.scan(match=self.__path + '.*')
r = []
for s in matches:
k = self.__from_path(s)
v = self.redis.get(k)
r.append((k, v,))
return r
def path(self):
return None
def replace(self, k, contents):
if contents == None:
contents = b''
k = self.__to_path(k)
v = self.redis.get(k)
if v == None:
raise FileNotFoundError(k)
self.redis.set(k, contents)
def modified(self, k):
k = self.__to_path(k)
k = '_mod' + k
v = self.redis.get(k)
return int(v)
def register_modify(self, k):
k = self.__to_path(k)
k = '_mod' + k
ts = datetime.datetime.utcnow().timestamp()
self.redis.set(k)
class RedisStoreFactory(StoreFactory):
def __init__(self, host='localhost', port=6379, db=2, binary=False):
self.redis = redis.Redis(host=host, port=port, db=db)
self.__binary = binary
def add(self, k):
k = str(k)
return RedisStore(k, self.redis, binary=self.__binary)
def close(self):
self.redis.close()
def ls(self):
r = []
(c, ks) = self.redis.scan(match='*')
for k in ks:
v = k.rsplit(b'.', maxsplit=1)
if v != k:
v = v[0].decode('utf-8')
if v not in r:
r.append(v)
return r

View File

@ -1,147 +0,0 @@
# standard imports
import datetime
import os
# external imports
import rocksdb
# local imports
from .base import StoreFactory
class RocksDbStore:
def __init__(self, path, db, binary=False):
self.db = db
self.__path = path
self.__binary = binary
def __to_key(self, k):
return k.encode('utf-8')
def __to_contents(self, v):
if isinstance(v, bytes):
return v
return v.encode('utf-8')
def __to_path(self, k):
return '.'.join([self.__path, k])
def __from_path(self, s):
(left, right) = s.split('.', maxsplit=1)
return right
def __to_result(self, v):
if self.__binary:
return v
return v.decode('utf-8')
def put(self, k, contents=b''):
if contents == None:
contents = b''
else:
contents = self.__to_contents(contents)
k = self.__to_path(k)
k = self.__to_key(k)
self.db.put(k, contents)
def remove(self, k):
k = self.__to_path(k)
k = self.__to_key(k)
self.db.delete(k)
def get(self, k):
k = self.__to_path(k)
k = self.__to_key(k)
v = self.db.get(k)
return self.__to_result(v)
def list(self):
it = self.db.iteritems()
kb_start = self.__to_key(self.__path)
it.seek(kb_start)
r = []
l = len(self.__path)
for (kb, v) in it:
k = kb.decode('utf-8')
if len(k) < l or k[:l] != self.__path:
break
k = self.__from_path(k)
v = self.db.get(kb)
r.append((k, v,))
return r
def path(self):
return None
def replace(self, k, contents):
if contents == None:
contents = b''
else:
contents = self.__to_contents(contents)
k = self.__to_path(k)
k = self.__to_key(k)
v = self.db.get(k)
if v == None:
raise FileNotFoundError(k)
self.db.put(k, contents)
def modified(self, k):
k = self.__to_path(k)
k = '_mod' + k
v = self.db.get(k)
return int(v)
def register_modify(self, k):
k = self.__to_path(k)
k = '_mod' + k
ts = datetime.datetime.utcnow().timestamp()
self.db.set(k)
class RocksDbStoreFactory(StoreFactory):
def __init__(self, path, binary=False):
try:
os.stat(path)
except FileNotFoundError:
os.makedirs(path)
self.db = rocksdb.DB(path, rocksdb.Options(create_if_missing=True))
self.__binary = binary
def add(self, k):
k = str(k)
return RocksDbStore(k, self.db, binary=self.__binary)
def close(self):
self.db.close()
def ls(self):
it = self.db.iterkeys()
it.seek_to_first()
r = []
for k in it:
v = k.rsplit(b'.', maxsplit=1)
if v != k:
v = v[0].decode('utf-8')
if v not in r:
r.append(v)
return r

View File

@ -1,2 +0,0 @@
def default_checker(statestore, old, new):
return None

View File

@ -2,7 +2,6 @@
import unittest
import tempfile
import os
import shutil
# local imports
from shep.persist import PersistedState
@ -11,25 +10,20 @@ from shep.error import (
StateExists,
StateInvalid,
StateItemExists,
StateLockedKey,
)
class TestFileStore(unittest.TestCase):
class TestStateReport(unittest.TestCase):
def setUp(self):
self.d = tempfile.mkdtemp()
self.factory = SimpleFileStoreFactory(self.d)
self.states = PersistedState(self.factory.add, 3)
self.states = PersistedState(self.factory.add, 4)
self.states.add('foo')
self.states.add('bar')
self.states.add('baz')
def tearDown(self):
shutil.rmtree(self.d)
def test_add(self):
self.states.put('abcd', state=self.states.FOO, contents='baz')
fp = os.path.join(self.d, 'FOO', 'abcd')
@ -79,41 +73,7 @@ class TestFileStore(unittest.TestCase):
with self.assertRaises(FileNotFoundError):
os.stat(fp)
def test_change(self):
self.states.alias('inky', self.states.FOO | self.states.BAR)
self.states.put('abcd', state=self.states.FOO, contents='foo')
self.states.change('abcd', self.states.BAR, 0)
fp = os.path.join(self.d, 'INKY', 'abcd')
f = open(fp, 'r')
v = f.read()
f.close()
fp = os.path.join(self.d, 'FOO', 'abcd')
with self.assertRaises(FileNotFoundError):
os.stat(fp)
fp = os.path.join(self.d, 'BAR', 'abcd')
with self.assertRaises(FileNotFoundError):
os.stat(fp)
self.states.change('abcd', 0, self.states.BAR)
fp = os.path.join(self.d, 'FOO', 'abcd')
f = open(fp, 'r')
v = f.read()
f.close()
fp = os.path.join(self.d, 'INKY', 'abcd')
with self.assertRaises(FileNotFoundError):
os.stat(fp)
fp = os.path.join(self.d, 'BAR', 'abcd')
with self.assertRaises(FileNotFoundError):
os.stat(fp)
def test_set(self):
self.states.alias('xyzzy', self.states.FOO | self.states.BAR)
self.states.put('abcd', state=self.states.FOO, contents='foo')
@ -148,7 +108,7 @@ class TestFileStore(unittest.TestCase):
os.stat(fp)
def test_sync_one(self):
def test_sync(self):
self.states.put('abcd', state=self.states.FOO, contents='foo')
self.states.put('xxx', state=self.states.FOO)
self.states.put('yyy', state=self.states.FOO)
@ -168,33 +128,6 @@ class TestFileStore(unittest.TestCase):
self.assertEqual(self.states.get('zzzz'), 'xyzzy')
def test_sync_all(self):
self.states.put('abcd', state=self.states.FOO)
self.states.put('xxx', state=self.states.BAR)
fp = os.path.join(self.d, 'FOO', 'abcd')
f = open(fp, 'w')
f.write('foofoo')
f.close()
fp = os.path.join(self.d, 'BAR', 'zzzz')
f = open(fp, 'w')
f.write('barbar')
f.close()
fp = os.path.join(self.d, 'BAR', 'yyyy')
f = open(fp, 'w')
f.close()
self.states.sync()
self.assertEqual(self.states.get('abcd'), None)
self.assertEqual(self.states.state('abcd'), self.states.FOO)
self.assertEqual(self.states.get('zzzz'), 'barbar')
self.assertEqual(self.states.state('zzzz'), self.states.BAR)
self.assertEqual(self.states.get('yyyy'), None)
self.assertEqual(self.states.state('yyyy'), self.states.BAR)
def test_path(self):
self.states.put('yyy', state=self.states.FOO)
@ -214,20 +147,14 @@ class TestFileStore(unittest.TestCase):
self.states.next('abcd')
self.assertEqual(self.states.state('abcd'), self.states.BAR)
self.states.next('abcd')
self.assertEqual(self.states.state('abcd'), self.states.BAZ)
with self.assertRaises(StateInvalid):
self.states.next('abcd')
v = self.states.state('abcd')
self.assertEqual(v, self.states.BAZ)
fp = os.path.join(self.d, 'FOO', 'abcd')
with self.assertRaises(FileNotFoundError):
os.stat(fp)
fp = os.path.join(self.d, 'BAZ', 'abcd')
fp = os.path.join(self.d, 'BAR', 'abcd')
os.stat(fp)
@ -243,67 +170,5 @@ class TestFileStore(unittest.TestCase):
self.assertEqual(r, 'foo')
def test_factory_ls(self):
self.states.put('abcd')
self.states.put('xxxx', state=self.states.BAZ)
r = self.factory.ls()
self.assertEqual(len(r), 2)
self.states.put('yyyy', state=self.states.BAZ)
r = self.factory.ls()
self.assertEqual(len(r), 2)
self.states.put('zzzz', state=self.states.BAR)
r = self.factory.ls()
self.assertEqual(len(r), 3)
def test_lock(self):
factory = SimpleFileStoreFactory(self.d, use_lock=True)
states = PersistedState(factory.add, 3)
states.add('foo')
states.add('bar')
states.add('baz')
states.alias('xyzzy', states.FOO | states.BAR)
states.alias('plugh', states.FOO | states.BAR | states.BAZ)
states.put('abcd')
lock_path = os.path.join(self.d, '.lock')
os.stat(lock_path)
fp = os.path.join(self.d, '.lock', 'xxxx')
f = open(fp, 'w')
f.close()
with self.assertRaises(StateLockedKey):
states.put('xxxx')
os.unlink(fp)
states.put('xxxx')
states.set('xxxx', states.FOO)
states.set('xxxx', states.BAR)
states.replace('xxxx', contents='zzzz')
fp = os.path.join(self.d, '.lock', 'xxxx')
f = open(fp, 'w')
f.close()
with self.assertRaises(StateLockedKey):
states.set('xxxx', states.BAZ)
v = states.state('xxxx')
self.assertEqual(v, states.XYZZY)
with self.assertRaises(StateLockedKey):
states.unset('xxxx', states.FOO)
with self.assertRaises(StateLockedKey):
states.replace('xxxx', contents='yyyy')
v = states.get('xxxx')
self.assertEqual(v, 'zzzz')
if __name__ == '__main__':
unittest.main()

View File

@ -1,78 +0,0 @@
# standard imports
import unittest
import os
import logging
import sys
import importlib
import tempfile
# local imports
from shep.persist import PersistedState
from shep.store.noop import NoopStoreFactory
from shep.error import (
StateExists,
StateInvalid,
StateItemExists,
StateItemNotFound,
)
logging.basicConfig(level=logging.DEBUG)
logg = logging.getLogger()
class TestNoopStore(unittest.TestCase):
def setUp(self):
self.factory = NoopStoreFactory()
self.states = PersistedState(self.factory.add, 3)
self.states.add('foo')
self.states.add('bar')
self.states.add('baz')
def test_add(self):
self.states.put('abcd', state=self.states.FOO, contents='baz')
v = self.states.get('abcd')
self.assertEqual(v, 'baz')
v = self.states.state('abcd')
self.assertEqual(v, self.states.FOO)
def test_next(self):
self.states.put('abcd')
self.states.next('abcd')
self.assertEqual(self.states.state('abcd'), self.states.FOO)
self.states.next('abcd')
self.assertEqual(self.states.state('abcd'), self.states.BAR)
self.states.next('abcd')
self.assertEqual(self.states.state('abcd'), self.states.BAZ)
with self.assertRaises(StateInvalid):
self.states.next('abcd')
v = self.states.state('abcd')
self.assertEqual(v, self.states.BAZ)
def test_replace(self):
with self.assertRaises(StateItemNotFound):
self.states.replace('abcd', contents='foo')
self.states.put('abcd', state=self.states.FOO, contents='baz')
self.states.replace('abcd', contents='bar')
v = self.states.get('abcd')
self.assertEqual(v, 'bar')
def test_factory_ls(self):
self.states.put('abcd')
self.states.put('xxxx', state=self.states.BAZ)
r = self.factory.ls()
self.assertEqual(len(r), 0)
if __name__ == '__main__':
unittest.main()

View File

@ -1,112 +0,0 @@
# standard imports
import unittest
import os
import logging
import sys
import importlib
# local imports
from shep.persist import PersistedState
from shep.error import (
StateExists,
StateInvalid,
StateItemExists,
StateItemNotFound,
)
logging.basicConfig(level=logging.DEBUG)
logg = logging.getLogger()
class TestRedisStore(unittest.TestCase):
def setUp(self):
from shep.store.redis import RedisStoreFactory
self.factory = RedisStoreFactory()
self.factory.redis.flushall()
self.states = PersistedState(self.factory.add, 3)
self.states.add('foo')
self.states.add('bar')
self.states.add('baz')
def test_add(self):
self.states.put('abcd', state=self.states.FOO, contents='baz')
v = self.states.get('abcd')
self.assertEqual(v, 'baz')
v = self.states.state('abcd')
self.assertEqual(v, self.states.FOO)
def test_next(self):
self.states.put('abcd')
self.states.next('abcd')
self.assertEqual(self.states.state('abcd'), self.states.FOO)
self.states.next('abcd')
self.assertEqual(self.states.state('abcd'), self.states.BAR)
self.states.next('abcd')
self.assertEqual(self.states.state('abcd'), self.states.BAZ)
with self.assertRaises(StateInvalid):
self.states.next('abcd')
v = self.states.state('abcd')
self.assertEqual(v, self.states.BAZ)
def test_replace(self):
with self.assertRaises(StateItemNotFound):
self.states.replace('abcd', contents='foo')
self.states.put('abcd', state=self.states.FOO, contents='baz')
self.states.replace('abcd', contents='bar')
v = self.states.get('abcd')
self.assertEqual(v, 'bar')
def test_factory_ls(self):
r = self.factory.ls()
self.assertEqual(len(r), 0)
self.states.put('abcd')
self.states.put('xxxx', state=self.states.BAZ)
r = self.factory.ls()
self.assertEqual(len(r), 2)
self.states.put('yyyy', state=self.states.BAZ)
r = self.factory.ls()
self.assertEqual(len(r), 2)
self.states.put('zzzz', state=self.states.BAR)
r = self.factory.ls()
self.assertEqual(len(r), 3)
if __name__ == '__main__':
noredis = False
redis = None
try:
redis = importlib.import_module('redis')
except ModuleNotFoundError:
logg.critical('redis module not available, skipping tests.')
sys.exit(0)
host = os.environ.get('REDIS_HOST', 'localhost')
port = os.environ.get('REDIS_PORT', 6379)
port = int(port)
db = os.environ.get('REDIS_DB', 0)
db = int(db)
r = redis.Redis(host=host, port=port, db=db)
try:
r.get('foo')
except redis.exceptions.ConnectionError:
logg.critical('could not connect to redis, skipping tests.')
sys.exit(0)
except redis.exceptions.InvalidResponse as e:
logg.critical('is that really redis running on {}:{}? Got unexpected response: {}'.format(host, port, e))
sys.exit(0)
unittest.main()

View File

@ -1,100 +0,0 @@
# standard imports
import unittest
import os
import logging
import sys
import importlib
import tempfile
import shutil
# local imports
from shep.persist import PersistedState
from shep.error import (
StateExists,
StateInvalid,
StateItemExists,
StateItemNotFound,
)
logging.basicConfig(level=logging.DEBUG)
logg = logging.getLogger()
class TestRedisStore(unittest.TestCase):
def setUp(self):
from shep.store.rocksdb import RocksDbStoreFactory
self.d = tempfile.mkdtemp()
self.factory = RocksDbStoreFactory(self.d)
self.states = PersistedState(self.factory.add, 3)
self.states.add('foo')
self.states.add('bar')
self.states.add('baz')
def tearDown(self):
shutil.rmtree(self.d)
def test_add(self):
self.states.put('abcd', state=self.states.FOO, contents='baz')
v = self.states.get('abcd')
self.assertEqual(v, 'baz')
v = self.states.state('abcd')
self.assertEqual(v, self.states.FOO)
def test_next(self):
self.states.put('abcd')
self.states.next('abcd')
self.assertEqual(self.states.state('abcd'), self.states.FOO)
self.states.next('abcd')
self.assertEqual(self.states.state('abcd'), self.states.BAR)
self.states.next('abcd')
self.assertEqual(self.states.state('abcd'), self.states.BAZ)
with self.assertRaises(StateInvalid):
self.states.next('abcd')
v = self.states.state('abcd')
self.assertEqual(v, self.states.BAZ)
def test_replace(self):
with self.assertRaises(StateItemNotFound):
self.states.replace('abcd', contents='foo')
self.states.put('abcd', state=self.states.FOO, contents='baz')
self.states.replace('abcd', contents='bar')
v = self.states.get('abcd')
self.assertEqual(v, 'bar')
def test_factory_ls(self):
self.states.put('abcd')
self.states.put('xxxx', state=self.states.BAZ)
r = self.factory.ls()
self.assertEqual(len(r), 2)
self.states.put('yyyy', state=self.states.BAZ)
r = self.factory.ls()
self.assertEqual(len(r), 2)
self.states.put('zzzz', state=self.states.BAR)
r = self.factory.ls()
self.assertEqual(len(r), 3)
if __name__ == '__main__':
norocksdb = False
rocksdb = None
try:
importlib.import_module('rocksdb')
except ModuleNotFoundError:
logg.critical('rocksdb module not available, skipping tests.')
sys.exit(0)
unittest.main()

View File

@ -1,33 +1,13 @@
# standard imports
import unittest
import logging
# local imports
from shep import State
from shep.error import (
StateExists,
StateInvalid,
StateItemNotFound,
)
logging.basicConfig(level=logging.DEBUG)
logg = logging.getLogger()
class MockCallback:
def __init__(self):
self.items = {}
self.items_from = {}
def add(self, k, v_from, v_to):
if self.items.get(k) == None:
self.items[k] = []
self.items_from[k] = []
self.items[k].append(v_to)
self.items_from[k].append(v_from)
class TestState(unittest.TestCase):
@ -38,6 +18,7 @@ class TestState(unittest.TestCase):
for k in [
'f0o',
'f oo',
'f_oo',
]:
with self.assertRaises(ValueError):
states.add(k)
@ -52,12 +33,11 @@ class TestState(unittest.TestCase):
def test_limit(self):
states = State(3)
states = State(2)
states.add('foo')
states.add('bar')
states.add('baz')
with self.assertRaises(OverflowError):
states.add('gaz')
states.add('baz')
def test_dup(self):
@ -102,33 +82,10 @@ class TestState(unittest.TestCase):
states.add('bar')
with self.assertRaises(StateInvalid):
states.alias('baz', 5)
def test_alias_invalid(self):
states = State(3)
states.add('foo')
states.add('bar')
states.put('abcd')
states.set('abcd', states.FOO)
with self.assertRaises(StateInvalid):
states.set('abcd', states.BAR)
def test_alias_invalid_ignore(self):
states = State(3, check_alias=False)
states.add('foo')
states.add('bar')
states.add('baz')
states.put('abcd')
states.set('abcd', states.FOO)
states.set('abcd', states.BAZ)
v = states.state('abcd')
s = states.name(v)
self.assertEqual(s, '_FOO.BAZ')
def test_peek(self):
states = State(2)
states = State(3)
states.add('foo')
states.add('bar')
@ -141,7 +98,7 @@ class TestState(unittest.TestCase):
states.move('abcd', states.BAR)
with self.assertRaises(StateInvalid):
states.peek('abcd')
self.assertEqual(states.peek('abcd'))
def test_from_name(self):
@ -150,177 +107,5 @@ class TestState(unittest.TestCase):
self.assertEqual(states.from_name('foo'), states.FOO)
def test_change(self):
states = State(3)
states.add('foo')
states.add('bar')
states.add('baz')
states.alias('inky', states.FOO | states.BAR)
states.alias('pinky', states.FOO | states.BAZ)
states.put('abcd')
states.next('abcd')
states.set('abcd', states.BAR)
states.change('abcd', states.BAZ, states.BAR)
self.assertEqual(states.state('abcd'), states.PINKY)
def test_change_onezero(self):
states = State(3)
states.add('foo')
states.add('bar')
states.add('baz')
states.alias('inky', states.FOO | states.BAR)
states.alias('pinky', states.FOO | states.BAZ)
states.put('abcd')
states.next('abcd')
states.change('abcd', states.BAR, 0)
self.assertEqual(states.state('abcd'), states.INKY)
states.change('abcd', 0, states.BAR)
self.assertEqual(states.state('abcd'), states.FOO)
def test_change_dates(self):
states = State(3)
states.add('foo')
states.put('abcd')
states.put('bcde')
a = states.modified('abcd')
b = states.modified('bcde')
self.assertGreater(b, a)
states.set('abcd', states.FOO)
a = states.modified('abcd')
b = states.modified('bcde')
self.assertGreater(a, b)
def test_event_callback(self):
cb = MockCallback()
states = State(3, event_callback=cb.add)
states.add('foo')
states.add('bar')
states.add('baz')
states.alias('xyzzy', states.FOO | states.BAR)
states.put('abcd')
states.set('abcd', states.FOO)
states.set('abcd', states.BAR)
states.change('abcd', states.BAZ, states.XYZZY)
events = cb.items['abcd']
self.assertEqual(len(events), 4)
self.assertEqual(states.from_name(events[0]), states.NEW)
self.assertEqual(states.from_name(events[1]), states.FOO)
self.assertEqual(states.from_name(events[2]), states.XYZZY)
self.assertEqual(states.from_name(events[3]), states.BAZ)
def test_dynamic(self):
states = State(0)
states.add('foo')
states.add('bar')
states.alias('baz', states.FOO | states.BAR)
def test_mask(self):
states = State(3)
states.add('foo')
states.add('bar')
states.add('baz')
states.alias('all', states.FOO | states.BAR | states.BAZ)
mask = states.mask('xyzzy', states.FOO | states.BAZ)
self.assertEqual(mask, states.BAR)
def test_mask_dynamic(self):
states = State(0)
states.add('foo')
states.add('bar')
states.add('baz')
states.alias('all', states.FOO | states.BAR | states.BAZ)
mask = states.mask('xyzzy', states.FOO | states.BAZ)
self.assertEqual(mask, states.BAR)
def test_mask_zero(self):
states = State(0)
states.add('foo')
states.add('bar')
states.add('baz')
states.alias('all', states.FOO | states.BAR | states.BAZ)
mask = states.mask('xyzzy')
self.assertEqual(mask, states.ALL)
def test_remove(self):
states = State(1)
states.add('foo')
states.put('xyzzy', contents='plugh')
v = states.get('xyzzy')
self.assertEqual(v, 'plugh')
states.next('xyzzy')
v = states.state('xyzzy')
self.assertEqual(states.FOO, v)
states.purge('xyzzy')
with self.assertRaises(StateItemNotFound):
states.state('xyzzy')
def test_elements(self):
states = State(2)
states.add('foo')
states.add('bar')
states.alias('baz', states.FOO, states.BAR)
v = states.elements(states.BAZ)
self.assertIn('FOO', v)
self.assertIn('BAR', v)
self.assertIsInstance(v, str)
v = states.elements(states.BAZ, numeric=True)
self.assertIn(states.FOO, v)
self.assertIn(states.BAR, v)
v = states.elements(states.BAZ, as_string=False)
self.assertIn('FOO', v)
self.assertIn('BAR', v)
self.assertNotIsInstance(v, str)
self.assertIsInstance(v, list)
def test_count(self):
states = State(3)
states.add('foo')
states.add('bar')
self.assertEqual(states.count(), 2)
states.add('baz')
self.assertEqual(states.count(), 3)
def test_pure(self):
states = State(2)
states.add('foo')
states.add('bar')
states.alias('baz', states.FOO, states.BAR)
v = states.is_pure(states.BAZ)
self.assertFalse(v)
v = states.is_pure(states.FOO)
self.assertTrue(v)
def test_default(self):
states = State(2, default_state='FOO')
with self.assertRaises(StateItemNotFound):
states.state('NEW')
getattr(states, 'FOO')
states.state('FOO')
if __name__ == '__main__':
unittest.main()

View File

@ -21,7 +21,7 @@ class MockStore:
self.for_state = 0
def put(self, k, contents=None):
def add(self, k, contents=None):
self.v[k] = contents
@ -33,10 +33,6 @@ class MockStore:
return self.v[k]
def list(self):
return list(self.v.keys())
class TestStateItems(unittest.TestCase):
def setUp(self):

View File

@ -1,31 +0,0 @@
# standard imports
import unittest
# local imports
from shep import State
from shep.error import (
StateTransitionInvalid,
)
def mock_verify(state, from_state, to_state):
if from_state == state.FOO:
if to_state == state.BAR:
return 'bar cannot follow foo'
class TestState(unittest.TestCase):
def test_verify(self):
states = State(2, verifier=mock_verify)
states.add('foo')
states.add('bar')
states.put('xyzzy')
states.next('xyzzy')
with self.assertRaises(StateTransitionInvalid):
states.next('xyzzy')
if __name__ == '__main__':
unittest.main()