Compare commits

...

12 Commits

Author SHA1 Message Date
lash
f7d31e4e81
Update deps 2025-01-21 13:49:12 +00:00
lash
90ecec1798
rehabilitate storage test 2025-01-20 12:31:20 +00:00
lash
874edb3da6
Apply session on the menuhandler store returns 2025-01-20 12:09:09 +00:00
lash
60ff1b0ab3
Set up multiple conns in config 2025-01-19 18:26:39 +00:00
lash
9b3dad579b
Set up multiple conns in config 2025-01-19 15:04:49 +00:00
lash
348fff8936
Allow multiple db connections in menuservice 2025-01-19 15:00:31 +00:00
lash
c5bb1c80a5
Update deps, connbusy fix db postgres 2025-01-19 11:08:07 +00:00
lash
b8a377befb
Implement context for get and put provider 2025-01-19 10:38:41 +00:00
lash
c9b92191f3
Add missing contexts in request handler mocks 2025-01-19 09:40:10 +00:00
lash
ddd8d7cac0
implement missing context 2025-01-19 09:35:09 +00:00
lash
37973a6c9b
Add finish context to mockengine 2025-01-19 09:08:03 +00:00
lash
975720919c
Implement tx enabled db vise 2025-01-19 09:04:37 +00:00
16 changed files with 464 additions and 147 deletions

View File

@ -3,17 +3,21 @@ package config
import (
"strings"
"git.defalsify.org/vise.git/logging"
"git.grassecon.net/grassrootseconomics/visedriver/env"
"git.grassecon.net/grassrootseconomics/visedriver/storage"
)
var (
logg = logging.NewVanilla().WithDomain("visedriver-config")
defaultLanguage = "eng"
languages []string
)
var (
DbConn string
DefaultLanguage string
dbConn string
dbConnMissing bool
stateDbConn string
resourceDbConn string
userDbConn string
Languages []string
)
@ -35,13 +39,63 @@ func setLanguage() error {
return nil
}
func setConn() error {
DbConn = env.GetEnv("DB_CONN", "")
dbConn = env.GetEnv("DB_CONN", "?")
stateDbConn = env.GetEnv("DB_CONN_STATE", dbConn)
resourceDbConn = env.GetEnv("DB_CONN_RESOURCE", dbConn)
userDbConn = env.GetEnv("DB_CONN_USER", dbConn)
return nil
}
func ApplyConn(connStr *string, stateConnStr *string, resourceConnStr *string, userConnStr *string) {
if connStr != nil {
dbConn = *connStr
}
if stateConnStr != nil {
stateDbConn = *stateConnStr
}
if resourceConnStr != nil {
resourceDbConn = *resourceConnStr
}
if userConnStr != nil {
userDbConn = *userConnStr
}
if dbConn == "?" {
dbConn = ""
}
if stateDbConn == "?" {
stateDbConn = dbConn
}
if resourceDbConn == "?" {
resourceDbConn = dbConn
}
if userDbConn == "?" {
userDbConn = dbConn
}
}
func GetConns() (storage.Conns, error) {
o := storage.NewConns()
c, err := storage.ToConnData(stateDbConn)
if err != nil {
return o, err
}
o.Set(c, storage.STORETYPE_STATE)
c, err = storage.ToConnData(resourceDbConn)
if err != nil {
return o, err
}
o.Set(c, storage.STORETYPE_RESOURCE)
c, err = storage.ToConnData(userDbConn)
if err != nil {
return o, err
}
o.Set(c, storage.STORETYPE_USER)
return o, nil
}
// LoadConfig initializes the configuration values after environment variables are loaded.
func LoadConfig() error {
err := setConn()

2
go.mod
View File

@ -3,7 +3,7 @@ module git.grassecon.net/grassrootseconomics/visedriver
go 1.23.0
require (
git.defalsify.org/vise.git v0.2.3-0.20250114225117-3b5fc85b650b
git.defalsify.org/vise.git v0.2.3-0.20250120121301-10739fb4a8c9
github.com/jackc/pgx/v5 v5.7.1
github.com/joho/godotenv v1.5.1
)

4
go.sum
View File

@ -1,5 +1,5 @@
git.defalsify.org/vise.git v0.2.3-0.20250114225117-3b5fc85b650b h1:rwWXMtNSn7aqhb4p1oVZkCA1vC7pVdohwW61QQM8fUs=
git.defalsify.org/vise.git v0.2.3-0.20250114225117-3b5fc85b650b/go.mod h1:jyBMe1qTYUz3mmuoC9JQ/TvFeW0vTanCUcPu3H8p4Ck=
git.defalsify.org/vise.git v0.2.3-0.20250120121301-10739fb4a8c9 h1:sPcqXQcywxA8W3W+9qQncLPmsrgqTIlec7vmD4/7vyA=
git.defalsify.org/vise.git v0.2.3-0.20250120121301-10739fb4a8c9/go.mod h1:jyBMe1qTYUz3mmuoC9JQ/TvFeW0vTanCUcPu3H8p4Ck=
github.com/barbashov/iso639-3 v0.0.0-20211020172741-1f4ffb2d8d1c h1:H9Nm+I7Cg/YVPpEV1RzU3Wq2pjamPc/UtHDgItcb7lE=
github.com/barbashov/iso639-3 v0.0.0-20211020172741-1f4ffb2d8d1c/go.mod h1:rGod7o6KPeJ+hyBpHfhi4v7blx9sf+QsHsA7KAsdN6U=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=

View File

@ -1,6 +1,8 @@
package request
import (
"context"
"git.defalsify.org/vise.git/db"
"git.defalsify.org/vise.git/engine"
"git.defalsify.org/vise.git/persist"
@ -29,8 +31,8 @@ func NewBaseRequestHandler(cfg engine.Config, rs resource.Resource, stateDb db.D
}
}
func (f *BaseRequestHandler) Shutdown() {
err := f.provider.Close()
func (f *BaseRequestHandler) Shutdown(ctx context.Context) {
err := f.provider.Close(ctx)
if err != nil {
logg.Errorf("handler shutdown error", "err", err)
}
@ -49,7 +51,7 @@ func(f *BaseRequestHandler) Process(rqs RequestSession) (RequestSession, error)
logg.InfoCtxf(rqs.Ctx, "new request", "data", rqs)
rqs.Storage, err = f.provider.Get(rqs.Config.SessionId)
rqs.Storage, err = f.provider.Get(rqs.Ctx, rqs.Config.SessionId)
if err != nil {
logg.ErrorCtxf(rqs.Ctx, "", "storage get error", err)
return rqs, errors.ErrStorage
@ -63,7 +65,7 @@ func(f *BaseRequestHandler) Process(rqs RequestSession) (RequestSession, error)
eni := f.GetEngine(rqs.Config, f.rs, rqs.Storage.Persister)
en, ok := eni.(*engine.DefaultEngine)
if !ok {
perr := f.provider.Put(rqs.Config.SessionId, rqs.Storage)
perr := f.provider.Put(rqs.Ctx, rqs.Config.SessionId, rqs.Storage)
rqs.Storage = nil
if perr != nil {
logg.ErrorCtxf(rqs.Ctx, "", "storage put error", perr)
@ -78,7 +80,7 @@ func(f *BaseRequestHandler) Process(rqs RequestSession) (RequestSession, error)
r, err = rqs.Engine.Exec(rqs.Ctx, rqs.Input)
if err != nil {
perr := f.provider.Put(rqs.Config.SessionId, rqs.Storage)
perr := f.provider.Put(rqs.Ctx, rqs.Config.SessionId, rqs.Storage)
rqs.Storage = nil
if perr != nil {
logg.ErrorCtxf(rqs.Ctx, "", "storage put error", perr)
@ -96,9 +98,9 @@ func(f *BaseRequestHandler) Output(rqs RequestSession) (RequestSession, error)
return rqs, err
}
func(f *BaseRequestHandler) Reset(rqs RequestSession) (RequestSession, error) {
defer f.provider.Put(rqs.Config.SessionId, rqs.Storage)
return rqs, rqs.Engine.Finish()
func(f *BaseRequestHandler) Reset(ctx context.Context, rqs RequestSession) (RequestSession, error) {
defer f.provider.Put(ctx, rqs.Config.SessionId, rqs.Storage)
return rqs, rqs.Engine.Finish(ctx)
}
func (f *BaseRequestHandler) GetConfig() engine.Config {

View File

@ -80,7 +80,7 @@ func (hh *HTTPRequestHandler) ServeHTTP(w http.ResponseWriter, req *http.Request
w.WriteHeader(200)
w.Header().Set("Content-Type", "text/plain")
rqs, err = hh.Output(rqs)
rqs, perr = hh.Reset(rqs)
rqs, perr = hh.Reset(rqs.Ctx, rqs)
if err != nil {
hh.WriteError(w, 500, err)
return

View File

@ -88,7 +88,7 @@ func TestRequestHandler_ServeHTTP(t *testing.T) {
OutputFunc: func(rs request.RequestSession) (request.RequestSession, error) {
return rs, tt.outputErr
},
ResetFunc: func(rs request.RequestSession) (request.RequestSession, error) {
ResetFunc: func(ctx context.Context, rs request.RequestSession) (request.RequestSession, error) {
return rs, tt.resetErr
},
GetRequestParserFunc: func() request.RequestParser {

View File

@ -27,16 +27,16 @@ type RequestSession struct {
// TODO: seems like can remove this.
type RequestParser interface {
GetSessionId(ctx context.Context, rq any) (string, error)
GetInput(rq any) ([]byte, error)
GetSessionId(context.Context, any) (string, error)
GetInput(any) ([]byte, error)
}
type RequestHandler interface {
GetConfig() engine.Config
GetRequestParser() RequestParser
GetEngine(cfg engine.Config, rs resource.Resource, pe *persist.Persister) engine.Engine
Process(rs RequestSession) (RequestSession, error)
Output(rs RequestSession) (RequestSession, error)
Reset(rs RequestSession) (RequestSession, error)
Shutdown()
GetEngine(engine.Config, resource.Resource, *persist.Persister) engine.Engine
Process(RequestSession) (RequestSession, error)
Output(RequestSession) (RequestSession, error)
Reset(context.Context, RequestSession) (RequestSession, error)
Shutdown(ctx context.Context)
}

73
storage/conn.go Normal file
View File

@ -0,0 +1,73 @@
package storage
import (
"fmt"
"net/url"
)
const (
DBTYPE_NONE = iota
DBTYPE_MEM
DBTYPE_FS
DBTYPE_GDBM
DBTYPE_POSTGRES
)
const (
STORETYPE_STATE = iota
STORETYPE_RESOURCE
STORETYPE_USER
_STORETYPE_MAX
)
type Conns map[int8]ConnData
func NewConns() Conns {
c := make(Conns)
return c
}
func (c Conns) Set(conn ConnData, typ int8) {
if typ < 0 || typ >= _STORETYPE_MAX {
panic(fmt.Errorf("invalid store type: %d", typ))
}
c[typ] = conn
}
func (c Conns) Have(conn *ConnData) int8 {
for i := range(_STORETYPE_MAX) {
ii := int8(i)
v, ok := c[ii]
if !ok {
continue
}
if v.String() == conn.String() {
return ii
}
}
return -1
}
type ConnData struct {
typ int
str string
domain string
}
func (cd *ConnData) DbType() int {
return cd.typ
}
func (cd *ConnData) String() string {
return cd.str
}
func (cd *ConnData) Domain() string {
return cd.domain
}
func (cd *ConnData) Path() string {
v, _ := url.Parse(cd.str)
v.RawQuery = ""
return v.String()
}

View File

@ -111,11 +111,11 @@ func(tdb *ThreadGdbmDb) Get(ctx context.Context, key []byte) ([]byte, error) {
return v, err
}
func(tdb *ThreadGdbmDb) Close() error {
func(tdb *ThreadGdbmDb) Close(ctx context.Context) error {
tdb.reserve()
close(dbC[tdb.connStr])
delete(dbC, tdb.connStr)
err := tdb.db.Close()
err := tdb.db.Close(ctx)
tdb.db = nil
return err
}
@ -125,3 +125,23 @@ func(tdb *ThreadGdbmDb) Dump(ctx context.Context, key []byte) (*db.Dumper, error
defer tdb.release()
return tdb.db.Dump(ctx, key)
}
func(tdb *ThreadGdbmDb) DecodeKey(ctx context.Context, key []byte) ([]byte, error) {
return tdb.db.DecodeKey(ctx, key)
}
func(tdb *ThreadGdbmDb) Abort(ctx context.Context) {
tdb.db.Abort(ctx)
}
func(tdb *ThreadGdbmDb) Start(ctx context.Context) error {
return tdb.db.Start(ctx)
}
func(tdb *ThreadGdbmDb) Stop(ctx context.Context) error {
return tdb.db.Stop(ctx)
}
func(tdb *ThreadGdbmDb) Connection() string {
return tdb.db.Connection()
}

View File

@ -4,40 +4,9 @@ import (
"fmt"
"net/url"
"path"
"path/filepath"
)
const (
DBTYPE_NONE = iota
DBTYPE_MEM
DBTYPE_FS
DBTYPE_GDBM
DBTYPE_POSTGRES
)
type ConnData struct {
typ int
str string
domain string
}
func (cd *ConnData) DbType() int {
return cd.typ
}
func (cd *ConnData) String() string {
return cd.str
}
func (cd *ConnData) Domain() string {
return cd.domain
}
func (cd *ConnData) Path() string {
v, _ := url.Parse(cd.str)
v.RawQuery = ""
return v.String()
}
func probePostgres(s string) (string, string, bool) {
domain := "public"
v, err := url.Parse(s)
@ -68,9 +37,19 @@ func probeGdbm(s string) (string, string, bool) {
}
func probeFs(s string) (string, string, bool) {
if !path.IsAbs(s) {
var err error
v, _ := url.Parse(s)
if v.Scheme != "" && v.Scheme != "file://" {
return "", "", false
}
if !path.IsAbs(s) {
s, err = filepath.Abs(s)
if err != nil {
panic(err)
}
}
s = path.Clean(s)
return s, "", true
}
@ -85,11 +64,13 @@ func probeMem(s string) (string, string, bool) {
func ToConnData(connStr string) (ConnData, error) {
var o ConnData
if connStr == "" {
v, domain, ok := probeMem(connStr)
if ok {
o.typ = DBTYPE_MEM
return o, nil
}
v, domain, ok := probePostgres(connStr)
v, domain, ok = probePostgres(connStr)
if ok {
o.typ = DBTYPE_POSTGRES
o.str = v
@ -111,11 +92,5 @@ func ToConnData(connStr string) (ConnData, error) {
return o, nil
}
v, _, ok = probeMem(connStr)
if ok {
o.typ = DBTYPE_MEM
return o, nil
}
return o, fmt.Errorf("invalid connection string: %s", connStr)
}

View File

@ -5,24 +5,53 @@ import (
)
func TestParseConnStr(t *testing.T) {
_, err := ToConnData("postgres://foo:bar@localhost:5432/baz")
v, err := ToConnData("postgres://foo:bar@localhost:5432/baz")
if err != nil {
t.Fatal(err)
}
_, err = ToConnData("/foo/bar")
if v.DbType() != DBTYPE_POSTGRES {
t.Fatalf("expected type %v, got %v", DBTYPE_POSTGRES, v.DbType())
}
v, err = ToConnData("gdbm:///foo/bar")
if err != nil {
t.Fatal(err)
}
_, err = ToConnData("/foo/bar/")
if v.DbType() != DBTYPE_GDBM {
t.Fatalf("expected type %v, got %v", DBTYPE_GDBM, v.DbType())
}
v, err = ToConnData("/foo/bar")
if err != nil {
t.Fatal(err)
}
_, err = ToConnData("foo/bar")
if v.DbType() != DBTYPE_FS {
t.Fatalf("expected type %v, got %v", DBTYPE_FS, v.DbType())
}
v, err = ToConnData("/foo/bar/")
if err != nil {
t.Fatal(err)
}
if v.DbType() != DBTYPE_FS {
t.Fatalf("expected type %v, got %v", DBTYPE_FS, v.DbType())
}
v, err = ToConnData("foo/bar")
if err != nil {
t.Fatal(err)
}
if v.DbType() != DBTYPE_FS {
t.Fatalf("expected type %v, got %v", DBTYPE_FS, v.DbType())
}
v, err = ToConnData("")
if err != nil {
t.Fatal(err)
}
if v.DbType() != DBTYPE_MEM {
t.Fatalf("expected type %v, got %v", DBTYPE_MEM, v.DbType())
}
v, err = ToConnData("http://foo/bar")
if err == nil {
t.Fatalf("expected error")
}
_, err = ToConnData("http://foo/bar")
if err == nil {
t.Fatalf("expected error")
if v.DbType() != DBTYPE_NONE {
t.Fatalf("expected type %v, got %v", DBTYPE_NONE, v.DbType())
}
}

View File

@ -1,6 +1,8 @@
package storage
import (
"context"
"git.defalsify.org/vise.git/db"
"git.defalsify.org/vise.git/persist"
)
@ -14,10 +16,14 @@ type Storage struct {
UserdataDb db.Db
}
func (s *Storage) Close(ctx context.Context) error {
return s.UserdataDb.Close(ctx)
}
type StorageProvider interface {
Get(sessionId string) (*Storage, error)
Put(sessionId string, storage *Storage) error
Close() error
Get(ctx context.Context, sessionId string) (*Storage, error)
Put(ctx context.Context, sessionId string, storage *Storage) error
Close(ctx context.Context) error
}
type SimpleStorageProvider struct {
@ -35,14 +41,16 @@ func NewSimpleStorageProvider(stateStore db.Db, userdataStore db.Db) StorageProv
}
}
func (p *SimpleStorageProvider) Get(sessionId string) (*Storage, error) {
func (p *SimpleStorageProvider) Get(ctx context.Context, sessionId string) (*Storage, error) {
p.Storage.UserdataDb.Start(ctx)
return p.Storage, nil
}
func (p *SimpleStorageProvider) Put(sessionId string, storage *Storage) error {
func (p *SimpleStorageProvider) Put(ctx context.Context, sessionId string, storage *Storage) error {
storage.UserdataDb.Stop(ctx)
return nil
}
func (p *SimpleStorageProvider) Close() error {
return p.Storage.UserdataDb.Close()
func (p *SimpleStorageProvider) Close(ctx context.Context) error {
return p.Storage.Close(ctx)
}

View File

@ -2,6 +2,7 @@ package storage
import (
"context"
"errors"
"fmt"
"os"
"path"
@ -9,6 +10,7 @@ import (
"github.com/jackc/pgx/v5/pgxpool"
"git.defalsify.org/vise.git/db"
fsdb "git.defalsify.org/vise.git/db/fs"
memdb "git.defalsify.org/vise.git/db/mem"
"git.defalsify.org/vise.git/db/postgres"
"git.defalsify.org/vise.git/lang"
"git.defalsify.org/vise.git/logging"
@ -27,69 +29,95 @@ type StorageService interface {
GetResource(ctx context.Context) (resource.Resource, error)
}
// TODO: Support individual backend for each store (conndata)
type MenuStorageService struct {
conn ConnData
resourceDir string
conns Conns
poResource resource.Resource
resourceStore db.Db
stateStore db.Db
userDataStore db.Db
store map[int8]db.Db
}
func NewMenuStorageService(conn ConnData, resourceDir string) *MenuStorageService {
func NewMenuStorageService(conn Conns) *MenuStorageService {
return &MenuStorageService{
conn: conn,
resourceDir: resourceDir,
conns: conn,
store: make(map[int8]db.Db),
}
}
func (ms *MenuStorageService) WithResourceDir(resourceDir string) *MenuStorageService {
ms.resourceDir = resourceDir
func (ms *MenuStorageService) WithDb(store db.Db, typ int8) *MenuStorageService {
var err error
if ms.store[typ] != nil {
panic(fmt.Errorf("db already set for typ: %d", typ))
}
ms.store[typ] = store
ms.conns[typ], err = ToConnData(store.Connection())
if err != nil {
panic(err)
}
return ms
}
// TODO: allow fsdb, memdb
func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.Db, section string, typ string) (db.Db, error) {
var newDb db.Db
func (ms *MenuStorageService) checkDb(ctx context.Context,typ int8) db.Db {
store := ms.store[typ]
if store != nil {
return store
}
connData := ms.conns[typ]
v := ms.conns.Have(&connData)
if v == -1 {
return nil
}
src := ms.store[v]
if src == nil {
return nil
}
ms.store[typ] = ms.store[v]
logg.DebugCtxf(ctx, "found existing db", "typ", typ, "srctyp", v, "store", ms.store[typ], "srcstore", ms.store[v])
return ms.store[typ]
}
func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, section string, typ int8) (db.Db, error) {
var err error
if existingDb != nil {
return existingDb, nil
newDb := ms.checkDb(ctx, typ)
if newDb != nil {
logg.InfoCtxf(ctx, "using existing db", "typ", typ, "db", newDb)
return newDb, nil
}
connStr := ms.conn.String()
dbTyp := ms.conn.DbType()
connData := ms.conns[typ]
connStr := connData.String()
dbTyp := connData.DbType()
if dbTyp == DBTYPE_POSTGRES {
// TODO: move to vise
err = ensureSchemaExists(ctx, ms.conn)
err = ensureSchemaExists(ctx, connData)
if err != nil {
return nil, err
}
newDb = postgres.NewPgDb().WithSchema(ms.conn.Domain())
newDb = postgres.NewPgDb().WithSchema(connData.Domain())
} else if dbTyp == DBTYPE_GDBM {
err = ms.ensureDbDir()
err = ms.ensureDbDir(connStr)
if err != nil {
return nil, err
}
connStr = path.Join(connStr, section)
newDb = gdbmstorage.NewThreadGdbmDb()
} else if dbTyp == DBTYPE_FS {
err = ms.ensureDbDir()
err = ms.ensureDbDir(connStr)
if err != nil {
return nil, err
}
newDb = fsdb.NewFsDb().WithBinary()
} else if dbTyp == DBTYPE_MEM {
logg.WarnCtxf(ctx, "using volatile storage (memdb)")
newDb = memdb.NewMemDb()
} else {
return nil, fmt.Errorf("unsupported connection string: '%s'\n", ms.conn.String())
return nil, fmt.Errorf("unsupported connection string: '%s'\n", connData.String())
}
logg.DebugCtxf(ctx, "connecting to db", "conn", connStr, "conndata", ms.conn, "typ", typ)
logg.InfoCtxf(ctx, "connecting to db", "conn", connData, "typ", typ)
err = newDb.Connect(ctx, connStr)
if err != nil {
return nil, err
}
ms.store[typ] = newDb
return newDb, nil
}
@ -133,11 +161,24 @@ func ensureSchemaExists(ctx context.Context, conn ConnData) error {
return nil
}
func applySession(ctx context.Context, store db.Db) error {
sessionId, ok := ctx.Value("SessionId").(string)
if !ok {
return fmt.Errorf("missing session to apply to store: %v", store)
}
store.SetSession(sessionId)
return nil
}
func (ms *MenuStorageService) GetPersister(ctx context.Context) (*persist.Persister, error) {
stateStore, err := ms.GetStateStore(ctx)
if err != nil {
return nil, err
}
err = applySession(ctx, stateStore)
if err != nil {
return nil, err
}
pr := persist.NewPersister(stateStore)
logg.TraceCtxf(ctx, "menu storage service", "persist", pr, "store", stateStore)
@ -145,26 +186,24 @@ func (ms *MenuStorageService) GetPersister(ctx context.Context) (*persist.Persis
}
func (ms *MenuStorageService) GetUserdataDb(ctx context.Context) (db.Db, error) {
if ms.userDataStore != nil {
return ms.userDataStore, nil
}
userDataStore, err := ms.getOrCreateDb(ctx, ms.userDataStore, "userdata.gdbm", "userdata")
userStore, err := ms.getOrCreateDb(ctx, "userdata.gdbm", STORETYPE_USER)
if err != nil {
return nil, err
}
ms.userDataStore = userDataStore
return ms.userDataStore, nil
err = applySession(ctx, userStore)
if err != nil {
return nil, err
}
return userStore, nil
}
func (ms *MenuStorageService) GetResource(ctx context.Context) (resource.Resource, error) {
ms.resourceStore = fsdb.NewFsDb()
err := ms.resourceStore.Connect(ctx, ms.resourceDir)
store, err := ms.getOrCreateDb(ctx, "resource.gdbm", STORETYPE_RESOURCE)
if err != nil {
return nil, err
}
rfs := resource.NewDbResource(ms.resourceStore)
rfs := resource.NewDbResource(store)
if ms.poResource != nil {
logg.InfoCtxf(ctx, "using poresource for menu and template")
rfs.WithMenuGetter(ms.poResource.GetMenu)
@ -174,33 +213,34 @@ func (ms *MenuStorageService) GetResource(ctx context.Context) (resource.Resourc
}
func (ms *MenuStorageService) GetStateStore(ctx context.Context) (db.Db, error) {
if ms.stateStore != nil {
return ms.stateStore, nil
}
stateStore, err := ms.getOrCreateDb(ctx, ms.stateStore, "state.gdbm", "state")
if err != nil {
return nil, err
}
ms.stateStore = stateStore
return ms.stateStore, nil
return ms.getOrCreateDb(ctx, "state.gdbm", STORETYPE_STATE)
}
func (ms *MenuStorageService) ensureDbDir() error {
err := os.MkdirAll(ms.conn.String(), 0700)
func (ms *MenuStorageService) ensureDbDir(path string) error {
err := os.MkdirAll(path, 0700)
if err != nil {
return fmt.Errorf("state dir create exited with error: %v\n", err)
return fmt.Errorf("store dir create exited with error: %v\n", err)
}
return nil
}
func (ms *MenuStorageService) Close() error {
errA := ms.stateStore.Close()
errB := ms.userDataStore.Close()
errC := ms.resourceStore.Close()
if errA != nil || errB != nil || errC != nil {
return fmt.Errorf("%v %v %v", errA, errB, errC)
// TODO: how to handle persister here?
func (ms *MenuStorageService) Close(ctx context.Context) error {
var errs []error
var haveErr bool
for i := range(_STORETYPE_MAX) {
err := ms.store[int8(i)].Close(ctx)
if err != nil {
haveErr = true
}
errs = append(errs, err)
}
if haveErr {
errStr := ""
for i, err := range(errs) {
errStr += fmt.Sprintf("(%d: %v)", i, err)
}
return errors.New(errStr)
}
return nil
}

View File

@ -0,0 +1,114 @@
package storage
import (
"context"
"os"
"testing"
fsdb "git.defalsify.org/vise.git/db/fs"
)
func TestMenuStorageServiceOneSet(t *testing.T) {
d, err := os.MkdirTemp("", "visedriver-menustorageservice")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(d)
conns := NewConns()
connData, err := ToConnData(d)
if err != nil {
t.Fatal(err)
}
conns.Set(connData, STORETYPE_STATE)
ctx := context.Background()
ms := NewMenuStorageService(conns)
_, err = ms.GetStateStore(ctx)
if err != nil {
t.Fatal(err)
}
_, err = ms.GetResource(ctx)
if err == nil {
t.Fatalf("expected error getting resource")
}
_, err = ms.GetUserdataDb(ctx)
if err == nil {
t.Fatalf("expected error getting userdata")
}
}
func TestMenuStorageServiceExplicit(t *testing.T) {
d, err := os.MkdirTemp("", "visedriver-menustorageservice")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(d)
conns := NewConns()
connData, err := ToConnData(d)
if err != nil {
t.Fatal(err)
}
conns.Set(connData, STORETYPE_STATE)
ctx := context.Background()
d, err = os.MkdirTemp("", "visedriver-menustorageservice")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(d)
store := fsdb.NewFsDb()
err = store.Connect(ctx, d)
if err != nil {
t.Fatal(err)
}
ms := NewMenuStorageService(conns)
ms = ms.WithDb(store, STORETYPE_RESOURCE)
_, err = ms.GetStateStore(ctx)
if err != nil {
t.Fatal(err)
}
_, err = ms.GetResource(ctx)
if err != nil {
t.Fatal(err)
}
_, err = ms.GetUserdataDb(ctx)
if err == nil {
t.Fatalf("expected error getting userdata")
}
}
func TestMenuStorageServiceReuse(t *testing.T) {
d, err := os.MkdirTemp("", "visedriver-menustorageservice")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(d)
conns := NewConns()
connData, err := ToConnData(d)
if err != nil {
t.Fatal(err)
}
conns.Set(connData, STORETYPE_STATE)
conns.Set(connData, STORETYPE_USER)
ctx := context.Background()
ctx = context.WithValue(ctx, "SessionId", "foo")
ms := NewMenuStorageService(conns)
stateStore, err := ms.GetStateStore(ctx)
if err != nil {
t.Fatal(err)
}
_, err = ms.GetResource(ctx)
if err == nil {
t.Fatalf("expected error getting resource")
}
userStore, err := ms.GetUserdataDb(ctx)
if err != nil {
t.Fatal(err)
}
if userStore != stateStore {
t.Fatalf("expected same store, but they are %p and %p", userStore, stateStore)
}
}

View File

@ -10,7 +10,7 @@ type MockEngine struct {
InitFunc func(context.Context) (bool, error)
ExecFunc func(context.Context, []byte) (bool, error)
FlushFunc func(context.Context, io.Writer) (int, error)
FinishFunc func() error
FinishFunc func(context.Context) error
}
func (m *MockEngine) Init(ctx context.Context) (bool, error) {
@ -25,6 +25,6 @@ func (m *MockEngine) Flush(ctx context.Context, w io.Writer) (int, error) {
return m.FlushFunc(ctx, w)
}
func (m *MockEngine) Finish() error {
return m.FinishFunc()
func (m *MockEngine) Finish(ctx context.Context) error {
return m.FinishFunc(ctx)
}

View File

@ -1,6 +1,8 @@
package httpmocks
import (
"context"
"git.defalsify.org/vise.git/engine"
"git.defalsify.org/vise.git/persist"
"git.defalsify.org/vise.git/resource"
@ -13,8 +15,8 @@ type MockRequestHandler struct {
GetConfigFunc func() engine.Config
GetEngineFunc func(cfg engine.Config, rs resource.Resource, pe *persist.Persister) engine.Engine
OutputFunc func(rs request.RequestSession) (request.RequestSession, error)
ResetFunc func(rs request.RequestSession) (request.RequestSession, error)
ShutdownFunc func()
ResetFunc func(ctx context.Context, rs request.RequestSession) (request.RequestSession, error)
ShutdownFunc func(ctx context.Context)
GetRequestParserFunc func() request.RequestParser
}
@ -34,12 +36,12 @@ func (m *MockRequestHandler) Output(rs request.RequestSession) (request.RequestS
return m.OutputFunc(rs)
}
func (m *MockRequestHandler) Reset(rs request.RequestSession) (request.RequestSession, error) {
return m.ResetFunc(rs)
func (m *MockRequestHandler) Reset(ctx context.Context, rs request.RequestSession) (request.RequestSession, error) {
return m.ResetFunc(ctx, rs)
}
func (m *MockRequestHandler) Shutdown() {
m.ShutdownFunc()
func (m *MockRequestHandler) Shutdown(ctx context.Context) {
m.ShutdownFunc(ctx)
}
func (m *MockRequestHandler) GetRequestParser() request.RequestParser {