diff --git a/go/state/state.go b/go/state/state.go index 7feb550..c8c9087 100644 --- a/go/state/state.go +++ b/go/state/state.go @@ -13,7 +13,8 @@ type State struct { CacheMap map[string]string ExecPath []string Arg *string - Idx uint16 + sizes map[string]uint16 + idx uint16 } func NewState(bitSize uint64) State { @@ -63,6 +64,7 @@ func(st *State) Down(input string) { m := make(map[string]string) st.Cache = append(st.Cache, m) st.CacheMap = make(map[string]string) + st.sizes = make(map[string]uint16) st.ExecPath = append(st.ExecPath, input) } @@ -84,10 +86,18 @@ func(st *State) Add(key string, value string, sizeHint uint16) error { log.Printf("add key %s value size %v", key, sz) st.Cache[len(st.Cache)-1][key] = value st.CacheUseSize += sz + st.sizes[key] = sizeHint return nil } func(st *State) Update(key string, value string) error { + sizeHint := st.sizes[key] + if st.sizes[key] > 0 { + l := uint16(len(value)) + if l > sizeHint { + return fmt.Errorf("update value length %v exceeds value size limit %v", l, sizeHint) + } + } checkFrame := st.frameOf(key) if checkFrame == -1 { return fmt.Errorf("key %v not defined", key) @@ -167,6 +177,16 @@ func(st *State) Check(key string) bool { return st.frameOf(key) == -1 } +func(st *State) Size() (uint32, uint32) { + var l int + var c uint16 + for k, v := range st.CacheMap { + l += len(v) + c += st.sizes[k] + } + return uint32(l), uint32(c) +} + // return 0-indexed frame number where key is defined. -1 if not defined func(st *State) frameOf(key string) int { log.Printf("--- %s", key) diff --git a/go/state/state_test.go b/go/state/state_test.go index 4183279..8697a9c 100644 --- a/go/state/state_test.go +++ b/go/state/state_test.go @@ -113,3 +113,36 @@ func TestStateLoadDup(t *testing.T) { t.Errorf("expected fail on duplicate load") } } + +func TestStateCurrentSize(t *testing.T) { + st := NewState(17) + st.Down("one") + err := st.Add("foo", "bar", 0) + if err != nil { + t.Error(err) + } + st.Down("two") + err = st.Add("bar", "xyzzy", 10) + if err != nil { + t.Error(err) + } + err = st.Map("bar") + if err != nil { + t.Error(err) + } + err = st.Add("baz", "inkypinkyblinkyclyde", 40) + if err != nil { + t.Error(err) + } + err = st.Map("baz") + if err != nil { + t.Error(err) + } + l, c := st.Size() + if l != 25 { + t.Errorf("expected actual length 25, got %v", l) + } + if c != 50 { + t.Errorf("expected actual length 50, got %v", c) + } +}