Complete base and alias inserts

This commit is contained in:
lash 2022-01-31 08:38:14 +00:00
parent 320f56b3c1
commit 78a3df73bb
Signed by: lash
GPG Key ID: 21D2E7BB88C2A746
3 changed files with 108 additions and 11 deletions

View File

@ -1,2 +1,6 @@
class StateExists(Exception): class StateExists(Exception):
pass pass
class StateInvalid(Exception):
pass

View File

@ -3,7 +3,10 @@ import enum
import logging import logging
# local imports # local imports
from .error import StateExists from schiz.error import (
StateExists,
StateInvalid,
)
logg = logging.getLogger(__name__) logg = logging.getLogger(__name__)
@ -12,27 +15,84 @@ class State:
def __init__(self, bits): def __init__(self, bits):
self.__bits = bits self.__bits = bits
self.__limit = (1 << bits) - 1
self.__c = 0 self.__c = 0
self.__reverse = {} self.__reverse = {}
def _persist(self): def __store(self):
pass pass
def __is_pure(self, v):
c = 1
for i in range(self.__bits):
if c & v > 0:
break
c <<= 1
return c == v
def add(self, name):
if self.__c == self.__bits:
raise OverflowError(self.__c + 1)
v = 1 << self.__c
k = name.upper()
def __check_name(self, k):
k = k.upper()
try: try:
getattr(self, k) getattr(self, k)
raise StateExists(k) raise StateExists(k)
except AttributeError: except AttributeError:
pass pass
return k
def __check_cover(self, v):
z = 0
c = 1
for i in range(self.__bits):
if c & v > 0:
if self.__reverse.get(c) == None:
raise StateInvalid(v)
c <<= 1
return c == v
def __check_value(self, v):
v = int(v)
if self.__reverse.get(v):
raise StateValueExists(v)
if v > self.__limit:
raise OverflowError(v)
return v
def __check(self, k, v):
k = self.__check_name(k)
v = self.__check_value(v)
return (k, v,)
def __set(self, k, v):
setattr(self, k, v) setattr(self, k, v)
self.__reverse[v] = k
self.__c += 1 self.__c += 1
def add(self, k):
v = 1 << self.__c
(k, v) = self.__check(k, v)
self.__set(k, v)
def alias(self, k, v):
(k, v) = self.__check(k, v)
if self.__is_pure(v):
raise ValueError('use add to add pure values')
self.__check_cover(v)
self.__set(k, v)
# def all(self):
# l = []
# for k in dir(self):
# if k[0] == '_':
# continue
# if k.upper() != k:
# continue
# l.append(k)

View File

@ -3,7 +3,10 @@ import unittest
# local imports # local imports
from schiz import State from schiz import State
from schiz.error import StateExists from schiz.error import (
StateExists,
StateInvalid,
)
class TestState(unittest.TestCase): class TestState(unittest.TestCase):
@ -31,5 +34,35 @@ class TestState(unittest.TestCase):
states.add('foo') states.add('foo')
def test_alias(self):
states = State(2)
states.add('foo')
states.add('bar')
states.alias('baz', states.FOO | states.BAR)
self.assertEqual(states.BAZ, 3)
def test_alias_limit(self):
states = State(2)
states.add('foo')
states.add('bar')
states.alias('baz', states.FOO | states.BAR)
def test_alias_nopure(self):
states = State(3)
with self.assertRaises(ValueError):
states.alias('foo', 4)
def test_alias_cover(self):
states = State(3)
states.add('foo')
states.add('bar')
with self.assertRaises(StateInvalid):
states.alias('baz', 5)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()