diff --git a/shep/persist.py b/shep/persist.py index 796ca52..08e3d5d 100644 --- a/shep/persist.py +++ b/shep/persist.py @@ -15,12 +15,12 @@ class PersistedState(State): self.__stores[k] = self.__store_factory(k) - def put(self, key, contents=None, state=None): + def put(self, key, contents=None, state=None, force=False): k = self.name(state) self.__ensure_store(k) - self.__stores[k].add(key, contents) + self.__stores[k].add(key, contents, force=force) - super(PersistedState, self).put(key, state=state, contents=contents) + super(PersistedState, self).put(key, state=state, contents=contents, force=force) def move(self, key, to_state): @@ -49,8 +49,8 @@ class PersistedState(State): super(PersistedState, self).purge(key) - def get(self, key): + def get(self, key=None): state = self.state(key) k = self.name(state) - self.__stores[k].get(k) + return self.__stores[k].get(key) diff --git a/shep/state.py b/shep/state.py index 3204bdb..a2d6557 100644 --- a/shep/state.py +++ b/shep/state.py @@ -152,15 +152,20 @@ class State: return (alias, r,) - - def put(self, key, state=None, contents=None): + + def put(self, key, state=None, contents=None, force=False): if state == None: state = self.NEW elif self.__reverse.get(state) == None: raise StateInvalid(state) - self.__check_key(key) + try: + self.__check_key(key) + except StateItemExists as e: + if not force: + raise(e) self.__add_state_list(state, key) - self.__contents[key] = contents + if contents != None: + self.__contents[key] = contents def move(self, key, to_state): @@ -245,4 +250,4 @@ class State: def get(self, key): - return self.__contents[key] + return self.__contents.get(key) diff --git a/shep/store/file.py b/shep/store/file.py index bbde70c..7e9adb0 100644 --- a/shep/store/file.py +++ b/shep/store/file.py @@ -9,20 +9,37 @@ class SimpleFileStore: os.makedirs(self.path, exist_ok=True) - def add(self, v, contents): - if contents == None: - contents = '' - fp = os.path.join(self.path, v) + def add(self, k, contents=None, force=False): + fp = os.path.join(self.path, k) + have_file = False try: os.stat(fp) - raise FileExistsError(fp) + have_file = True except FileNotFoundError: pass + + if have_file: + if not force: + raise FileExistsError(fp) + if contents == None: + raise FileExistsError('will not overwrite empty content on existing file {}. Use rm then add instead'.format(fp)) + elif contents == None: + contents = '' + + print('wriging {}'.format(fp)) f = open(fp, 'w') f.write(contents) f.close() + def get(self, k): + fp = os.path.join(self.path, k) + f = open(fp, 'r') + r = f.read() + f.close() + return r + + class SimpleFileStoreFactory: def __init__(self, path): diff --git a/tests/test_file.py b/tests/test_file.py index 072e27f..fa94b21 100644 --- a/tests/test_file.py +++ b/tests/test_file.py @@ -37,6 +37,18 @@ class TestStateReport(unittest.TestCase): with self.assertRaises(FileExistsError): self.states.put('abcd', state=self.states.FOO) + with self.assertRaises(FileExistsError): + self.states.put('abcd', state=self.states.FOO, force=True) + + self.states.put('abcd', contents='foo', state=self.states.FOO, force=True) + self.assertEqual(self.states.get('abcd'), 'foo') + + with self.assertRaises(FileExistsError): + self.states.put('abcd', state=self.states.FOO, force=True) + + self.states.put('abcd', contents='bar', state=self.states.FOO, force=True) + self.assertEqual(self.states.get('abcd'), 'bar') + if __name__ == '__main__': unittest.main() diff --git a/tests/test_item.py b/tests/test_item.py index 35bcef1..9acbb5e 100644 --- a/tests/test_item.py +++ b/tests/test_item.py @@ -76,7 +76,7 @@ class TestStateItems(unittest.TestCase): def test_item_get(self): item = b'foo' - self.states.put(item, self.states.BAZ, contents='bar') + self.states.put(item, state=self.states.BAZ, contents='bar') self.assertEqual(self.states.state(item), self.states.BAZ) v = self.states.get(item) self.assertEqual(v, 'bar') @@ -84,26 +84,26 @@ class TestStateItems(unittest.TestCase): def test_item_set(self): item = b'foo' - self.states.put(item, self.states.FOO) + self.states.put(item, state=self.states.FOO) self.states.set(item, self.states.BAR) self.assertEqual(self.states.state(item), self.states.PLUGH) def test_item_set_invalid(self): item = b'foo' - self.states.put(item, self.states.FOO) + self.states.put(item, state=self.states.FOO) with self.assertRaises(StateInvalid): self.states.set(item, self.states.BAZ) item = b'bar' - self.states.put(item, self.states.BAR) + self.states.put(item, state=self.states.BAR) with self.assertRaises(ValueError): self.states.set(item, self.states.XYZZY) def test_item_set_invalid(self): item = b'foo' - self.states.put(item, self.states.XYZZY) + self.states.put(item, state=self.states.XYZZY) self.states.unset(item, self.states.BAZ) self.assertEqual(self.states.state(item), self.states.BAR) @@ -116,5 +116,22 @@ class TestStateItems(unittest.TestCase): self.states.unset(item, self.states.FOO) # bit not set + def test_item_force(self): + item = b'foo' + self.states.put(item, state=self.states.XYZZY) + + contents = 'xyzzy' + self.states.put(item, state=self.states.XYZZY, contents=contents, force=True) + self.assertEqual(self.states.get(item), 'xyzzy') + + contents = None + self.states.put(item, state=self.states.XYZZY, contents=contents, force=True) + self.assertEqual(self.states.get(item), 'xyzzy') + + contents = 'plugh' + self.states.put(item, state=self.states.XYZZY, contents=contents, force=True) + self.assertEqual(self.states.get(item), 'plugh') + + if __name__ == '__main__': unittest.main() diff --git a/tests/test_store.py b/tests/test_store.py index 9eb992b..e1aad88 100644 --- a/tests/test_store.py +++ b/tests/test_store.py @@ -21,7 +21,7 @@ class MockStore: self.for_state = 0 - def add(self, k, contents): + def add(self, k, contents=None, force=False): self.v[k] = contents