Merge branch 'master' into copy-language-code

This commit is contained in:
Carlosokumu 2025-01-10 11:06:32 +03:00
commit 2b9c6d641e
Signed by untrusted user: carlos
GPG Key ID: 7BD6BC8160A5C953
22 changed files with 433 additions and 250 deletions

View File

@ -6,13 +6,9 @@ HOST=127.0.0.1
AT_ENDPOINT=/ussd/africastalking AT_ENDPOINT=/ussd/africastalking
#PostgreSQL #PostgreSQL
DB_HOST=localhost DB_CONN=postgres://postgres:strongpass@localhost:5432/urdt_ussd
DB_USER=postgres #DB_TIMEZONE=Africa/Nairobi
DB_PASSWORD=strongpass #DB_SCHEMA=vise
DB_NAME=urdt_ussd
DB_PORT=5432
DB_SSLMODE=disable
DB_TIMEZONE=Africa/Nairobi
#External API Calls #External API Calls
CUSTODIAL_URL_BASE=http://localhost:5003 CUSTODIAL_URL_BASE=http://localhost:5003

View File

@ -20,7 +20,6 @@ import (
"git.grassecon.net/urdt/ussd/initializers" "git.grassecon.net/urdt/ussd/initializers"
"git.grassecon.net/urdt/ussd/internal/handlers" "git.grassecon.net/urdt/ussd/internal/handlers"
"git.grassecon.net/urdt/ussd/internal/http/at" "git.grassecon.net/urdt/ussd/internal/http/at"
httpserver "git.grassecon.net/urdt/ussd/internal/http/at"
"git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/internal/storage"
"git.grassecon.net/urdt/ussd/remote" "git.grassecon.net/urdt/ussd/remote"
"git.grassecon.net/urdt/ussd/internal/args" "git.grassecon.net/urdt/ussd/internal/args"
@ -36,21 +35,23 @@ var (
func init() { func init() {
initializers.LoadEnvVariables() initializers.LoadEnvVariables()
} }
func main() { func main() {
config.LoadConfig() config.LoadConfig()
var dbDir string var connStr string
var resourceDir string var resourceDir string
var size uint var size uint
var database string var database string
var engineDebug bool var engineDebug bool
var host string var host string
var port uint var port uint
var err error
var gettextDir string var gettextDir string
var langs args.LangVar var langs args.LangVar
flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir") flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir")
flag.StringVar(&database, "db", "gdbm", "database to be used") flag.StringVar(&connStr, "c", "", "connection string")
flag.BoolVar(&engineDebug, "d", false, "use engine debug output") flag.BoolVar(&engineDebug, "d", false, "use engine debug output")
flag.UintVar(&size, "s", 160, "max size of output") flag.UintVar(&size, "s", 160, "max size of output")
flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host") flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host")
@ -59,7 +60,16 @@ func main() {
flag.Var(&langs, "language", "add symbol resolution for language") flag.Var(&langs, "language", "add symbol resolution for language")
flag.Parse() flag.Parse()
logg.Infof("start command", "build", build, "dbdir", dbDir, "resourcedir", resourceDir, "outputsize", size) if connStr != "" {
connStr = config.DbConn
}
connData, err := storage.ToConnData(connStr)
if err != nil {
fmt.Fprintf(os.Stderr, "connstr err: %v", err)
os.Exit(1)
}
logg.Infof("start command", "build", build, "conn", connData, "resourcedir", resourceDir, "outputsize", size)
ctx := context.Background() ctx := context.Background()
ctx = context.WithValue(ctx, "Database", database) ctx = context.WithValue(ctx, "Database", database)
@ -83,14 +93,13 @@ func main() {
cfg.EngineDebug = true cfg.EngineDebug = true
} }
menuStorageService := storage.NewMenuStorageService(dbDir, resourceDir) menuStorageService := storage.NewMenuStorageService(connData, resourceDir)
rs, err := menuStorageService.GetResource(ctx)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, err.Error()) fmt.Fprintf(os.Stderr, err.Error())
os.Exit(1) os.Exit(1)
} }
err = menuStorageService.EnsureDbDir() rs, err := menuStorageService.GetResource(ctx)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, err.Error()) fmt.Fprintf(os.Stderr, err.Error())
os.Exit(1) os.Exit(1)
@ -136,7 +145,7 @@ func main() {
rp := &at.ATRequestParser{} rp := &at.ATRequestParser{}
bsh := handlers.NewBaseSessionHandler(cfg, rs, stateStore, userdataStore, rp, hl) bsh := handlers.NewBaseSessionHandler(cfg, rs, stateStore, userdataStore, rp, hl)
sh := httpserver.NewATSessionHandler(bsh) sh := at.NewATSessionHandler(bsh)
mux := http.NewServeMux() mux := http.NewServeMux()
mux.Handle(initializers.GetEnv("AT_ENDPOINT", "/"), sh) mux.Handle(initializers.GetEnv("AT_ENDPOINT", "/"), sh)

View File

@ -48,20 +48,21 @@ func (p *asyncRequestParser) GetInput(r any) ([]byte, error) {
func main() { func main() {
config.LoadConfig() config.LoadConfig()
var connStr string
var sessionId string var sessionId string
var dbDir string
var resourceDir string var resourceDir string
var size uint var size uint
var database string var database string
var engineDebug bool var engineDebug bool
var host string var host string
var port uint var port uint
var err error
var gettextDir string var gettextDir string
var langs args.LangVar var langs args.LangVar
flag.StringVar(&sessionId, "session-id", "075xx2123", "session id") flag.StringVar(&sessionId, "session-id", "075xx2123", "session id")
flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir") flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir")
flag.StringVar(&database, "db", "gdbm", "database to be used") flag.StringVar(&connStr, "c", "", "connection string")
flag.BoolVar(&engineDebug, "d", false, "use engine debug output") flag.BoolVar(&engineDebug, "d", false, "use engine debug output")
flag.UintVar(&size, "s", 160, "max size of output") flag.UintVar(&size, "s", 160, "max size of output")
flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host") flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host")
@ -70,7 +71,16 @@ func main() {
flag.Var(&langs, "language", "add symbol resolution for language") flag.Var(&langs, "language", "add symbol resolution for language")
flag.Parse() flag.Parse()
logg.Infof("start command", "dbdir", dbDir, "resourcedir", resourceDir, "outputsize", size, "sessionId", sessionId) if connStr != "" {
connStr = config.DbConn
}
connData, err := storage.ToConnData(connStr)
if err != nil {
fmt.Fprintf(os.Stderr, "connstr err: %v", err)
os.Exit(1)
}
logg.Infof("start command", "conn", connData, "resourcedir", resourceDir, "outputsize", size, "sessionId", sessionId)
ctx := context.Background() ctx := context.Background()
ctx = context.WithValue(ctx, "Database", database) ctx = context.WithValue(ctx, "Database", database)
@ -95,18 +105,18 @@ func main() {
cfg.EngineDebug = true cfg.EngineDebug = true
} }
menuStorageService := storage.NewMenuStorageService(dbDir, resourceDir) menuStorageService := storage.NewMenuStorageService(connData, resourceDir)
if err != nil {
fmt.Fprintf(os.Stderr, err.Error())
os.Exit(1)
}
rs, err := menuStorageService.GetResource(ctx) rs, err := menuStorageService.GetResource(ctx)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, err.Error()) fmt.Fprintf(os.Stderr, err.Error())
os.Exit(1) os.Exit(1)
} }
err = menuStorageService.EnsureDbDir()
if err != nil {
fmt.Fprintf(os.Stderr, err.Error())
os.Exit(1)
}
userdataStore, err := menuStorageService.GetUserdataDb(ctx) userdataStore, err := menuStorageService.GetUserdataDb(ctx)
if err != nil { if err != nil {

View File

@ -38,18 +38,19 @@ func init() {
func main() { func main() {
config.LoadConfig() config.LoadConfig()
var dbDir string var connStr string
var resourceDir string var resourceDir string
var size uint var size uint
var database string var database string
var engineDebug bool var engineDebug bool
var host string var host string
var port uint var port uint
var err error
var gettextDir string var gettextDir string
var langs args.LangVar var langs args.LangVar
flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir") flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir")
flag.StringVar(&database, "db", "gdbm", "database to be used") flag.StringVar(&connStr, "c", "", "connection string")
flag.BoolVar(&engineDebug, "d", false, "use engine debug output") flag.BoolVar(&engineDebug, "d", false, "use engine debug output")
flag.UintVar(&size, "s", 160, "max size of output") flag.UintVar(&size, "s", 160, "max size of output")
flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host") flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host")
@ -58,7 +59,16 @@ func main() {
flag.Var(&langs, "language", "add symbol resolution for language") flag.Var(&langs, "language", "add symbol resolution for language")
flag.Parse() flag.Parse()
logg.Infof("start command", "dbdir", dbDir, "resourcedir", resourceDir, "outputsize", size) if connStr != "" {
connStr = config.DbConn
}
connData, err := storage.ToConnData(connStr)
if err != nil {
fmt.Fprintf(os.Stderr, "connstr err: %v", err)
os.Exit(1)
}
logg.Infof("start command", "conn", connData, "resourcedir", resourceDir, "outputsize", size)
ctx := context.Background() ctx := context.Background()
ctx = context.WithValue(ctx, "Database", database) ctx = context.WithValue(ctx, "Database", database)
@ -83,14 +93,9 @@ func main() {
cfg.EngineDebug = true cfg.EngineDebug = true
} }
menuStorageService := storage.NewMenuStorageService(dbDir, resourceDir) menuStorageService := storage.NewMenuStorageService(connData, resourceDir)
rs, err := menuStorageService.GetResource(ctx)
if err != nil {
fmt.Fprintf(os.Stderr, err.Error())
os.Exit(1)
}
err = menuStorageService.EnsureDbDir() rs, err := menuStorageService.GetResource(ctx)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, err.Error()) fmt.Fprintf(os.Stderr, err.Error())
os.Exit(1) os.Exit(1)

View File

@ -33,23 +33,35 @@ func init() {
func main() { func main() {
config.LoadConfig() config.LoadConfig()
var dbDir string var connStr string
var size uint var size uint
var sessionId string var sessionId string
var database string var database string
var engineDebug bool var engineDebug bool
var resourceDir string
var err error
var gettextDir string var gettextDir string
var langs args.LangVar var langs args.LangVar
flag.StringVar(&resourceDir, "resourcedir", scriptDir, "resource dir")
flag.StringVar(&sessionId, "session-id", "075xx2123", "session id") flag.StringVar(&sessionId, "session-id", "075xx2123", "session id")
flag.StringVar(&database, "db", "gdbm", "database to be used") flag.StringVar(&connStr, "c", "", "connection string")
flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
flag.BoolVar(&engineDebug, "d", false, "use engine debug output") flag.BoolVar(&engineDebug, "d", false, "use engine debug output")
flag.UintVar(&size, "s", 160, "max size of output") flag.UintVar(&size, "s", 160, "max size of output")
flag.StringVar(&gettextDir, "gettext", "", "use gettext translations from given directory") flag.StringVar(&gettextDir, "gettext", "", "use gettext translations from given directory")
flag.Var(&langs, "language", "add symbol resolution for language") flag.Var(&langs, "language", "add symbol resolution for language")
flag.Parse() flag.Parse()
logg.Infof("start command", "dbdir", dbDir, "outputsize", size) if connStr != "" {
connStr = config.DbConn
}
connData, err := storage.ToConnData(connStr)
if err != nil {
fmt.Fprintf(os.Stderr, "connstr err: %v", err)
os.Exit(1)
}
logg.Infof("start command", "conn", connData, "outputsize", size)
if len(langs.Langs()) == 0 { if len(langs.Langs()) == 0 {
langs.Set(config.DefaultLanguage) langs.Set(config.DefaultLanguage)
@ -76,18 +88,12 @@ func main() {
MenuSeparator: menuSeparator, MenuSeparator: menuSeparator,
} }
resourceDir := scriptDir menuStorageService := storage.NewMenuStorageService(connData, resourceDir)
menuStorageService := storage.NewMenuStorageService(dbDir, resourceDir)
if gettextDir != "" { if gettextDir != "" {
menuStorageService = menuStorageService.WithGettext(gettextDir, langs.Langs()) menuStorageService = menuStorageService.WithGettext(gettextDir, langs.Langs())
} }
err = menuStorageService.EnsureDbDir()
if err != nil {
fmt.Fprintf(os.Stderr, err.Error())
os.Exit(1)
}
rs, err := menuStorageService.GetResource(ctx) rs, err := menuStorageService.GetResource(ctx)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, err.Error()) fmt.Fprintf(os.Stderr, err.Error())

View File

@ -14,7 +14,10 @@ import (
"git.defalsify.org/vise.git/engine" "git.defalsify.org/vise.git/engine"
"git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/logging"
"git.grassecon.net/urdt/ussd/config"
"git.grassecon.net/urdt/ussd/initializers"
"git.grassecon.net/urdt/ussd/internal/ssh" "git.grassecon.net/urdt/ussd/internal/ssh"
"git.grassecon.net/urdt/ussd/internal/storage"
) )
var ( var (
@ -26,25 +29,49 @@ var (
build = "dev" build = "dev"
) )
func init() {
initializers.LoadEnvVariables()
}
func main() { func main() {
var dbDir string config.LoadConfig()
var connStr string
var authConnStr string
var resourceDir string var resourceDir string
var size uint var size uint
var engineDebug bool var engineDebug bool
var stateDebug bool var stateDebug bool
var host string var host string
var port uint var port uint
flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from") flag.StringVar(&connStr, "c", "", "connection string")
flag.StringVar(&authConnStr, "authdb", "", "auth connection string")
flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir") flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir")
flag.BoolVar(&engineDebug, "engine-debug", false, "use engine debug output") flag.BoolVar(&engineDebug, "d", false, "use engine debug output")
flag.BoolVar(&stateDebug, "state-debug", false, "use engine debug output")
flag.UintVar(&size, "s", 160, "max size of output") flag.UintVar(&size, "s", 160, "max size of output")
flag.StringVar(&host, "h", "127.0.0.1", "http host") flag.StringVar(&host, "h", "127.0.0.1", "socket host")
flag.UintVar(&port, "p", 7122, "http port") flag.UintVar(&port, "p", 7122, "socket port")
flag.Parse() flag.Parse()
if connStr == "" {
connStr = config.DbConn
}
if authConnStr == "" {
authConnStr = connStr
}
connData, err := storage.ToConnData(connStr)
if err != nil {
fmt.Fprintf(os.Stderr, "connstr err: %v", err)
os.Exit(1)
}
authConnData, err := storage.ToConnData(authConnStr)
if err != nil {
fmt.Fprintf(os.Stderr, "auth connstr err: %v", err)
os.Exit(1)
}
sshKeyFile := flag.Arg(0) sshKeyFile := flag.Arg(0)
_, err := os.Stat(sshKeyFile) _, err = os.Stat(sshKeyFile)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "cannot open ssh server private key file: %v\n", err) fmt.Fprintf(os.Stderr, "cannot open ssh server private key file: %v\n", err)
os.Exit(1) os.Exit(1)
@ -57,7 +84,7 @@ func main() {
logg.WarnCtxf(ctx, "!!!!! Do not expose to internet and only use with tunnel!") logg.WarnCtxf(ctx, "!!!!! Do not expose to internet and only use with tunnel!")
logg.WarnCtxf(ctx, "!!!!! (See ssh -L <...>)") logg.WarnCtxf(ctx, "!!!!! (See ssh -L <...>)")
logg.Infof("start command", "dbdir", dbDir, "resourcedir", resourceDir, "outputsize", size, "keyfile", sshKeyFile, "host", host, "port", port) logg.Infof("start command", "conn", connData, "authconn", authConnData, "resourcedir", resourceDir, "outputsize", size, "keyfile", sshKeyFile, "host", host, "port", port)
pfp := path.Join(scriptDir, "pp.csv") pfp := path.Join(scriptDir, "pp.csv")
@ -73,7 +100,7 @@ func main() {
cfg.EngineDebug = true cfg.EngineDebug = true
} }
authKeyStore, err := ssh.NewSshKeyStore(ctx, dbDir) authKeyStore, err := ssh.NewSshKeyStore(ctx, authConnData.String())
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "keystore file open error: %v", err) fmt.Fprintf(os.Stderr, "keystore file open error: %v", err)
os.Exit(1) os.Exit(1)
@ -92,10 +119,10 @@ func main() {
signal.Notify(cterm, os.Interrupt, syscall.SIGTERM) signal.Notify(cterm, os.Interrupt, syscall.SIGTERM)
runner := &ssh.SshRunner{ runner := &ssh.SshRunner{
Cfg: cfg, Cfg: cfg,
Debug: engineDebug, Debug: engineDebug,
FlagFile: pfp, FlagFile: pfp,
DbDir: dbDir, Conn: connData,
ResourceDir: resourceDir, ResourceDir: resourceDir,
SrvKeyFile: sshKeyFile, SrvKeyFile: sshKeyFile,
Host: host, Host: host,

View File

@ -23,17 +23,17 @@ type StorageServices interface {
GetPersister(ctx context.Context) (*persist.Persister, error) GetPersister(ctx context.Context) (*persist.Persister, error)
GetUserdataDb(ctx context.Context) (db.Db, error) GetUserdataDb(ctx context.Context) (db.Db, error)
GetResource(ctx context.Context) (resource.Resource, error) GetResource(ctx context.Context) (resource.Resource, error)
EnsureDbDir() error
} }
type StorageService struct { type StorageService struct {
svc *storage.MenuStorageService svc *storage.MenuStorageService
} }
func NewStorageService(dbDir string) *StorageService { func NewStorageService(conn storage.ConnData) (*StorageService, error) {
return &StorageService{ svc := &StorageService{
svc: storage.NewMenuStorageService(dbDir, ""), svc: storage.NewMenuStorageService(conn, ""),
} }
return svc, nil
} }
func(ss *StorageService) GetPersister(ctx context.Context) (*persist.Persister, error) { func(ss *StorageService) GetPersister(ctx context.Context) (*persist.Persister, error) {
@ -47,7 +47,3 @@ func(ss *StorageService) GetUserdataDb(ctx context.Context) (db.Db, error) {
func(ss *StorageService) GetResource(ctx context.Context) (resource.Resource, error) { func(ss *StorageService) GetResource(ctx context.Context) (resource.Resource, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func(ss *StorageService) EnsureDbDir() error {
return ss.svc.EnsureDbDir()
}

View File

@ -40,6 +40,7 @@ var (
VoucherTransfersURL string VoucherTransfersURL string
VoucherDataURL string VoucherDataURL string
CheckAliasURL string CheckAliasURL string
DbConn string
DefaultLanguage string DefaultLanguage string
Languages []string Languages []string
) )
@ -69,14 +70,20 @@ func setBase() error {
dataURLBase = initializers.GetEnv("DATA_URL_BASE", "http://localhost:5006") dataURLBase = initializers.GetEnv("DATA_URL_BASE", "http://localhost:5006")
BearerToken = initializers.GetEnv("BEARER_TOKEN", "") BearerToken = initializers.GetEnv("BEARER_TOKEN", "")
_, err = url.JoinPath(custodialURLBase, "/foo") _, err = url.Parse(custodialURLBase)
if err != nil { if err != nil {
return err return err
} }
_, err = url.JoinPath(dataURLBase, "/bar") _, err = url.Parse(dataURLBase)
if err != nil { if err != nil {
return err return err
} }
return nil
}
func setConn() error {
DbConn = initializers.GetEnv("DB_CONN", "")
return nil return nil
} }
@ -86,6 +93,10 @@ func LoadConfig() error {
if err != nil { if err != nil {
return err return err
} }
err = setConn()
if err != nil {
return err
}
err = setLanguage() err = setLanguage()
if err != nil { if err != nil {
return err return err

View File

@ -37,23 +37,34 @@ func formatItem(k []byte, v []byte) (string, error) {
func main() { func main() {
config.LoadConfig() config.LoadConfig()
var dbDir string var connStr string
var sessionId string var sessionId string
var database string var database string
var engineDebug bool var engineDebug bool
var err error
flag.StringVar(&sessionId, "session-id", "075xx2123", "session id") flag.StringVar(&sessionId, "session-id", "075xx2123", "session id")
flag.StringVar(&database, "db", "gdbm", "database to be used") flag.StringVar(&connStr, "c", ".state", "connection string")
flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
flag.BoolVar(&engineDebug, "d", false, "use engine debug output") flag.BoolVar(&engineDebug, "d", false, "use engine debug output")
flag.Parse() flag.Parse()
if connStr != "" {
connStr = config.DbConn
}
connData, err := storage.ToConnData(config.DbConn)
if err != nil {
fmt.Fprintf(os.Stderr, "connstr err: %v", err)
os.Exit(1)
}
logg.Infof("start command", "conn", connData)
ctx := context.Background() ctx := context.Background()
ctx = context.WithValue(ctx, "SessionId", sessionId) ctx = context.WithValue(ctx, "SessionId", sessionId)
ctx = context.WithValue(ctx, "Database", database) ctx = context.WithValue(ctx, "Database", database)
resourceDir := scriptDir resourceDir := scriptDir
menuStorageService := storage.NewMenuStorageService(dbDir, resourceDir) menuStorageService := storage.NewMenuStorageService(connData, resourceDir)
store, err := menuStorageService.GetUserdataDb(ctx) store, err := menuStorageService.GetUserdataDb(ctx)
if err != nil { if err != nil {

View File

@ -28,23 +28,34 @@ func init() {
func main() { func main() {
config.LoadConfig() config.LoadConfig()
var dbDir string var connStr string
var sessionId string var sessionId string
var database string var database string
var engineDebug bool var engineDebug bool
var err error
flag.StringVar(&sessionId, "session-id", "075xx2123", "session id") flag.StringVar(&sessionId, "session-id", "075xx2123", "session id")
flag.StringVar(&database, "db", "gdbm", "database to be used") flag.StringVar(&connStr, "c", "", "connection string")
flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
flag.BoolVar(&engineDebug, "d", false, "use engine debug output") flag.BoolVar(&engineDebug, "d", false, "use engine debug output")
flag.Parse() flag.Parse()
if connStr != "" {
connStr = config.DbConn
}
connData, err := storage.ToConnData(config.DbConn)
if err != nil {
fmt.Fprintf(os.Stderr, "connstr err: %v", err)
os.Exit(1)
}
logg.Infof("start command", "conn", connData)
ctx := context.Background() ctx := context.Background()
ctx = context.WithValue(ctx, "SessionId", sessionId) ctx = context.WithValue(ctx, "SessionId", sessionId)
ctx = context.WithValue(ctx, "Database", database) ctx = context.WithValue(ctx, "Database", database)
resourceDir := scriptDir resourceDir := scriptDir
menuStorageService := storage.NewMenuStorageService(dbDir, resourceDir) menuStorageService := storage.NewMenuStorageService(connData, resourceDir)
store, err := menuStorageService.GetUserdataDb(ctx) store, err := menuStorageService.GetUserdataDb(ctx)
if err != nil { if err != nil {

View File

@ -1,4 +1,4 @@
package ussd package application
import ( import (
"bytes" "bytes"

View File

@ -1,4 +1,4 @@
package ussd package application
import ( import (
"context" "context"

View File

@ -6,42 +6,42 @@ import (
"git.defalsify.org/vise.git/persist" "git.defalsify.org/vise.git/persist"
"git.defalsify.org/vise.git/resource" "git.defalsify.org/vise.git/resource"
"git.grassecon.net/urdt/ussd/internal/handlers/ussd" "git.grassecon.net/urdt/ussd/internal/handlers/application"
"git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/internal/storage"
) )
type BaseSessionHandler struct { type BaseSessionHandler struct {
cfgTemplate engine.Config cfgTemplate engine.Config
rp RequestParser rp RequestParser
rs resource.Resource rs resource.Resource
hn *ussd.Handlers hn *application.Handlers
provider storage.StorageProvider provider storage.StorageProvider
} }
func NewBaseSessionHandler(cfg engine.Config, rs resource.Resource, stateDb db.Db, userdataDb db.Db, rp RequestParser, hn *ussd.Handlers) *BaseSessionHandler { func NewBaseSessionHandler(cfg engine.Config, rs resource.Resource, stateDb db.Db, userdataDb db.Db, rp RequestParser, hn *application.Handlers) *BaseSessionHandler {
return &BaseSessionHandler{ return &BaseSessionHandler{
cfgTemplate: cfg, cfgTemplate: cfg,
rs: rs, rs: rs,
hn: hn, hn: hn,
rp: rp, rp: rp,
provider: storage.NewSimpleStorageProvider(stateDb, userdataDb), provider: storage.NewSimpleStorageProvider(stateDb, userdataDb),
} }
} }
func(f* BaseSessionHandler) Shutdown() { func (f *BaseSessionHandler) Shutdown() {
err := f.provider.Close() err := f.provider.Close()
if err != nil { if err != nil {
logg.Errorf("handler shutdown error", "err", err) logg.Errorf("handler shutdown error", "err", err)
} }
} }
func(f *BaseSessionHandler) GetEngine(cfg engine.Config, rs resource.Resource, pr *persist.Persister) engine.Engine { func (f *BaseSessionHandler) GetEngine(cfg engine.Config, rs resource.Resource, pr *persist.Persister) engine.Engine {
en := engine.NewEngine(cfg, rs) en := engine.NewEngine(cfg, rs)
en = en.WithPersister(pr) en = en.WithPersister(pr)
return en return en
} }
func(f *BaseSessionHandler) Process(rqs RequestSession) (RequestSession, error) { func (f *BaseSessionHandler) Process(rqs RequestSession) (RequestSession, error) {
var r bool var r bool
var err error var err error
var ok bool var ok bool
@ -88,21 +88,21 @@ func(f *BaseSessionHandler) Process(rqs RequestSession) (RequestSession, error)
return rqs, nil return rqs, nil
} }
func(f *BaseSessionHandler) Output(rqs RequestSession) (RequestSession, error) { func (f *BaseSessionHandler) Output(rqs RequestSession) (RequestSession, error) {
var err error var err error
_, err = rqs.Engine.Flush(rqs.Ctx, rqs.Writer) _, err = rqs.Engine.Flush(rqs.Ctx, rqs.Writer)
return rqs, err return rqs, err
} }
func(f *BaseSessionHandler) Reset(rqs RequestSession) (RequestSession, error) { func (f *BaseSessionHandler) Reset(rqs RequestSession) (RequestSession, error) {
defer f.provider.Put(rqs.Config.SessionId, rqs.Storage) defer f.provider.Put(rqs.Config.SessionId, rqs.Storage)
return rqs, rqs.Engine.Finish() return rqs, rqs.Engine.Finish()
} }
func(f *BaseSessionHandler) GetConfig() engine.Config { func (f *BaseSessionHandler) GetConfig() engine.Config {
return f.cfgTemplate return f.cfgTemplate
} }
func(f *BaseSessionHandler) GetRequestParser() RequestParser { func (f *BaseSessionHandler) GetRequestParser() RequestParser {
return f.rp return f.rp
} }

View File

@ -10,13 +10,13 @@ import (
"git.defalsify.org/vise.git/persist" "git.defalsify.org/vise.git/persist"
"git.defalsify.org/vise.git/resource" "git.defalsify.org/vise.git/resource"
"git.grassecon.net/urdt/ussd/internal/handlers/ussd" "git.grassecon.net/urdt/ussd/internal/handlers/application"
"git.grassecon.net/urdt/ussd/internal/utils" "git.grassecon.net/urdt/ussd/internal/utils"
"git.grassecon.net/urdt/ussd/remote" "git.grassecon.net/urdt/ussd/remote"
) )
type HandlerService interface { type HandlerService interface {
GetHandler() (*ussd.Handlers, error) GetHandler() (*application.Handlers, error)
} }
func getParser(fp string, debug bool) (*asm.FlagParser, error) { func getParser(fp string, debug bool) (*asm.FlagParser, error) {
@ -64,73 +64,73 @@ func (ls *LocalHandlerService) SetDataStore(db *db.Db) {
ls.UserdataStore = db ls.UserdataStore = db
} }
func (ls *LocalHandlerService) GetHandler(accountService remote.AccountServiceInterface) (*ussd.Handlers, error) { func (ls *LocalHandlerService) GetHandler(accountService remote.AccountServiceInterface) (*application.Handlers, error) {
replaceSeparatorFunc := func(input string) string { replaceSeparatorFunc := func(input string) string {
return strings.ReplaceAll(input, ":", ls.Cfg.MenuSeparator) return strings.ReplaceAll(input, ":", ls.Cfg.MenuSeparator)
} }
ussdHandlers, err := ussd.NewHandlers(ls.Parser, *ls.UserdataStore, ls.AdminStore, accountService, replaceSeparatorFunc) appHandlers, err := application.NewHandlers(ls.Parser, *ls.UserdataStore, ls.AdminStore, accountService, replaceSeparatorFunc)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ussdHandlers = ussdHandlers.WithPersister(ls.Pe) appHandlers = appHandlers.WithPersister(ls.Pe)
ls.DbRs.AddLocalFunc("set_language", ussdHandlers.SetLanguage) ls.DbRs.AddLocalFunc("set_language", appHandlers.SetLanguage)
ls.DbRs.AddLocalFunc("create_account", ussdHandlers.CreateAccount) ls.DbRs.AddLocalFunc("create_account", appHandlers.CreateAccount)
ls.DbRs.AddLocalFunc("save_temporary_pin", ussdHandlers.SaveTemporaryPin) ls.DbRs.AddLocalFunc("save_temporary_pin", appHandlers.SaveTemporaryPin)
ls.DbRs.AddLocalFunc("verify_create_pin", ussdHandlers.VerifyCreatePin) ls.DbRs.AddLocalFunc("verify_create_pin", appHandlers.VerifyCreatePin)
ls.DbRs.AddLocalFunc("check_identifier", ussdHandlers.CheckIdentifier) ls.DbRs.AddLocalFunc("check_identifier", appHandlers.CheckIdentifier)
ls.DbRs.AddLocalFunc("check_account_status", ussdHandlers.CheckAccountStatus) ls.DbRs.AddLocalFunc("check_account_status", appHandlers.CheckAccountStatus)
ls.DbRs.AddLocalFunc("authorize_account", ussdHandlers.Authorize) ls.DbRs.AddLocalFunc("authorize_account", appHandlers.Authorize)
ls.DbRs.AddLocalFunc("quit", ussdHandlers.Quit) ls.DbRs.AddLocalFunc("quit", appHandlers.Quit)
ls.DbRs.AddLocalFunc("check_balance", ussdHandlers.CheckBalance) ls.DbRs.AddLocalFunc("check_balance", appHandlers.CheckBalance)
ls.DbRs.AddLocalFunc("validate_recipient", ussdHandlers.ValidateRecipient) ls.DbRs.AddLocalFunc("validate_recipient", appHandlers.ValidateRecipient)
ls.DbRs.AddLocalFunc("transaction_reset", ussdHandlers.TransactionReset) ls.DbRs.AddLocalFunc("transaction_reset", appHandlers.TransactionReset)
ls.DbRs.AddLocalFunc("invite_valid_recipient", ussdHandlers.InviteValidRecipient) ls.DbRs.AddLocalFunc("invite_valid_recipient", appHandlers.InviteValidRecipient)
ls.DbRs.AddLocalFunc("max_amount", ussdHandlers.MaxAmount) ls.DbRs.AddLocalFunc("max_amount", appHandlers.MaxAmount)
ls.DbRs.AddLocalFunc("validate_amount", ussdHandlers.ValidateAmount) ls.DbRs.AddLocalFunc("validate_amount", appHandlers.ValidateAmount)
ls.DbRs.AddLocalFunc("reset_transaction_amount", ussdHandlers.ResetTransactionAmount) ls.DbRs.AddLocalFunc("reset_transaction_amount", appHandlers.ResetTransactionAmount)
ls.DbRs.AddLocalFunc("get_recipient", ussdHandlers.GetRecipient) ls.DbRs.AddLocalFunc("get_recipient", appHandlers.GetRecipient)
ls.DbRs.AddLocalFunc("get_sender", ussdHandlers.GetSender) ls.DbRs.AddLocalFunc("get_sender", appHandlers.GetSender)
ls.DbRs.AddLocalFunc("get_amount", ussdHandlers.GetAmount) ls.DbRs.AddLocalFunc("get_amount", appHandlers.GetAmount)
ls.DbRs.AddLocalFunc("reset_incorrect", ussdHandlers.ResetIncorrectPin) ls.DbRs.AddLocalFunc("reset_incorrect", appHandlers.ResetIncorrectPin)
ls.DbRs.AddLocalFunc("save_firstname", ussdHandlers.SaveFirstname) ls.DbRs.AddLocalFunc("save_firstname", appHandlers.SaveFirstname)
ls.DbRs.AddLocalFunc("save_familyname", ussdHandlers.SaveFamilyname) ls.DbRs.AddLocalFunc("save_familyname", appHandlers.SaveFamilyname)
ls.DbRs.AddLocalFunc("save_gender", ussdHandlers.SaveGender) ls.DbRs.AddLocalFunc("save_gender", appHandlers.SaveGender)
ls.DbRs.AddLocalFunc("save_location", ussdHandlers.SaveLocation) ls.DbRs.AddLocalFunc("save_location", appHandlers.SaveLocation)
ls.DbRs.AddLocalFunc("save_yob", ussdHandlers.SaveYob) ls.DbRs.AddLocalFunc("save_yob", appHandlers.SaveYob)
ls.DbRs.AddLocalFunc("save_offerings", ussdHandlers.SaveOfferings) ls.DbRs.AddLocalFunc("save_offerings", appHandlers.SaveOfferings)
ls.DbRs.AddLocalFunc("reset_account_authorized", ussdHandlers.ResetAccountAuthorized) ls.DbRs.AddLocalFunc("reset_account_authorized", appHandlers.ResetAccountAuthorized)
ls.DbRs.AddLocalFunc("reset_allow_update", ussdHandlers.ResetAllowUpdate) ls.DbRs.AddLocalFunc("reset_allow_update", appHandlers.ResetAllowUpdate)
ls.DbRs.AddLocalFunc("get_profile_info", ussdHandlers.GetProfileInfo) ls.DbRs.AddLocalFunc("get_profile_info", appHandlers.GetProfileInfo)
ls.DbRs.AddLocalFunc("verify_yob", ussdHandlers.VerifyYob) ls.DbRs.AddLocalFunc("verify_yob", appHandlers.VerifyYob)
ls.DbRs.AddLocalFunc("reset_incorrect_date_format", ussdHandlers.ResetIncorrectYob) ls.DbRs.AddLocalFunc("reset_incorrect_date_format", appHandlers.ResetIncorrectYob)
ls.DbRs.AddLocalFunc("initiate_transaction", ussdHandlers.InitiateTransaction) ls.DbRs.AddLocalFunc("initiate_transaction", appHandlers.InitiateTransaction)
ls.DbRs.AddLocalFunc("verify_new_pin", ussdHandlers.VerifyNewPin) ls.DbRs.AddLocalFunc("verify_new_pin", appHandlers.VerifyNewPin)
ls.DbRs.AddLocalFunc("confirm_pin_change", ussdHandlers.ConfirmPinChange) ls.DbRs.AddLocalFunc("confirm_pin_change", appHandlers.ConfirmPinChange)
ls.DbRs.AddLocalFunc("quit_with_help", ussdHandlers.QuitWithHelp) ls.DbRs.AddLocalFunc("quit_with_help", appHandlers.QuitWithHelp)
ls.DbRs.AddLocalFunc("fetch_community_balance", ussdHandlers.FetchCommunityBalance) ls.DbRs.AddLocalFunc("fetch_community_balance", appHandlers.FetchCommunityBalance)
ls.DbRs.AddLocalFunc("set_default_voucher", ussdHandlers.SetDefaultVoucher) ls.DbRs.AddLocalFunc("set_default_voucher", appHandlers.SetDefaultVoucher)
ls.DbRs.AddLocalFunc("check_vouchers", ussdHandlers.CheckVouchers) ls.DbRs.AddLocalFunc("check_vouchers", appHandlers.CheckVouchers)
ls.DbRs.AddLocalFunc("get_vouchers", ussdHandlers.GetVoucherList) ls.DbRs.AddLocalFunc("get_vouchers", appHandlers.GetVoucherList)
ls.DbRs.AddLocalFunc("view_voucher", ussdHandlers.ViewVoucher) ls.DbRs.AddLocalFunc("view_voucher", appHandlers.ViewVoucher)
ls.DbRs.AddLocalFunc("set_voucher", ussdHandlers.SetVoucher) ls.DbRs.AddLocalFunc("set_voucher", appHandlers.SetVoucher)
ls.DbRs.AddLocalFunc("get_voucher_details", ussdHandlers.GetVoucherDetails) ls.DbRs.AddLocalFunc("get_voucher_details", appHandlers.GetVoucherDetails)
ls.DbRs.AddLocalFunc("reset_valid_pin", ussdHandlers.ResetValidPin) ls.DbRs.AddLocalFunc("reset_valid_pin", appHandlers.ResetValidPin)
ls.DbRs.AddLocalFunc("check_pin_mismatch", ussdHandlers.CheckBlockedNumPinMisMatch) ls.DbRs.AddLocalFunc("check_pin_mismatch", appHandlers.CheckBlockedNumPinMisMatch)
ls.DbRs.AddLocalFunc("validate_blocked_number", ussdHandlers.ValidateBlockedNumber) ls.DbRs.AddLocalFunc("validate_blocked_number", appHandlers.ValidateBlockedNumber)
ls.DbRs.AddLocalFunc("retrieve_blocked_number", ussdHandlers.RetrieveBlockedNumber) ls.DbRs.AddLocalFunc("retrieve_blocked_number", appHandlers.RetrieveBlockedNumber)
ls.DbRs.AddLocalFunc("reset_unregistered_number", ussdHandlers.ResetUnregisteredNumber) ls.DbRs.AddLocalFunc("reset_unregistered_number", appHandlers.ResetUnregisteredNumber)
ls.DbRs.AddLocalFunc("reset_others_pin", ussdHandlers.ResetOthersPin) ls.DbRs.AddLocalFunc("reset_others_pin", appHandlers.ResetOthersPin)
ls.DbRs.AddLocalFunc("save_others_temporary_pin", ussdHandlers.SaveOthersTemporaryPin) ls.DbRs.AddLocalFunc("save_others_temporary_pin", appHandlers.SaveOthersTemporaryPin)
ls.DbRs.AddLocalFunc("get_current_profile_info", ussdHandlers.GetCurrentProfileInfo) ls.DbRs.AddLocalFunc("get_current_profile_info", appHandlers.GetCurrentProfileInfo)
ls.DbRs.AddLocalFunc("check_transactions", ussdHandlers.CheckTransactions) ls.DbRs.AddLocalFunc("check_transactions", appHandlers.CheckTransactions)
ls.DbRs.AddLocalFunc("get_transactions", ussdHandlers.GetTransactionsList) ls.DbRs.AddLocalFunc("get_transactions", appHandlers.GetTransactionsList)
ls.DbRs.AddLocalFunc("view_statement", ussdHandlers.ViewTransactionStatement) ls.DbRs.AddLocalFunc("view_statement", appHandlers.ViewTransactionStatement)
ls.DbRs.AddLocalFunc("update_all_profile_items", ussdHandlers.UpdateAllProfileItems) ls.DbRs.AddLocalFunc("update_all_profile_items", appHandlers.UpdateAllProfileItems)
ls.DbRs.AddLocalFunc("set_back", ussdHandlers.SetBack) ls.DbRs.AddLocalFunc("set_back", appHandlers.SetBack)
ls.DbRs.AddLocalFunc("show_blocked_account", ussdHandlers.ShowBlockedAccount) ls.DbRs.AddLocalFunc("show_blocked_account", appHandlers.ShowBlockedAccount)
return ussdHandlers, nil return appHandlers, nil
} }
// TODO: enable setting of sessionId on engine init time // TODO: enable setting of sessionId on engine init time

View File

@ -41,6 +41,7 @@ func NewAuther(ctx context.Context, keyStore *SshKeyStore) *auther {
} }
func(a *auther) Check(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { func(a *auther) Check(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
logg.TraceCtxf(a.Ctx, "looking for publickey", "pubkey", fmt.Sprintf("%x", pubKey))
va, err := a.keyStore.Get(a.Ctx, pubKey) va, err := a.keyStore.Get(a.Ctx, pubKey)
if err != nil { if err != nil {
return nil, err return nil, err
@ -71,6 +72,20 @@ func(a *auther) Get(k []byte) (string, error) {
return v, nil return v, nil
} }
type SshRunner struct {
Ctx context.Context
Cfg engine.Config
FlagFile string
Conn storage.ConnData
ResourceDir string
Debug bool
SrvKeyFile string
Host string
Port uint
wg sync.WaitGroup
lst net.Listener
}
func(s *SshRunner) serve(ctx context.Context, sessionId string, ch ssh.NewChannel, en engine.Engine) error { func(s *SshRunner) serve(ctx context.Context, sessionId string, ch ssh.NewChannel, en engine.Engine) error {
if ch == nil { if ch == nil {
return errors.New("nil channel") return errors.New("nil channel")
@ -128,32 +143,13 @@ func(s *SshRunner) serve(ctx context.Context, sessionId string, ch ssh.NewChanne
return nil return nil
} }
type SshRunner struct {
Ctx context.Context
Cfg engine.Config
FlagFile string
DbDir string
ResourceDir string
Debug bool
SrvKeyFile string
Host string
Port uint
wg sync.WaitGroup
lst net.Listener
}
func(s *SshRunner) Stop() error { func(s *SshRunner) Stop() error {
return s.lst.Close() return s.lst.Close()
} }
func(s *SshRunner) GetEngine(sessionId string) (engine.Engine, func(), error) { func(s *SshRunner) GetEngine(sessionId string) (engine.Engine, func(), error) {
ctx := s.Ctx ctx := s.Ctx
menuStorageService := storage.NewMenuStorageService(s.DbDir, s.ResourceDir) menuStorageService := storage.NewMenuStorageService(s.Conn, s.ResourceDir)
err := menuStorageService.EnsureDbDir()
if err != nil {
return nil, nil, err
}
rs, err := menuStorageService.GetResource(ctx) rs, err := menuStorageService.GetResource(ctx)
if err != nil { if err != nil {
@ -208,6 +204,7 @@ func(s *SshRunner) GetEngine(sessionId string) (engine.Engine, func(), error) {
// adapted example from crypto/ssh package, NewServerConn doc // adapted example from crypto/ssh package, NewServerConn doc
func(s *SshRunner) Run(ctx context.Context, keyStore *SshKeyStore) { func(s *SshRunner) Run(ctx context.Context, keyStore *SshKeyStore) {
s.Ctx = ctx
running := true running := true
// TODO: waitgroup should probably not be global // TODO: waitgroup should probably not be global

69
internal/storage/parse.go Normal file
View File

@ -0,0 +1,69 @@
package storage
import (
"fmt"
"net/url"
"path"
)
const (
DBTYPE_MEM = iota
DBTYPE_GDBM
DBTYPE_POSTGRES
)
type ConnData struct {
typ int
str string
}
func (cd *ConnData) DbType() int {
return cd.typ
}
func (cd *ConnData) String() string {
return cd.str
}
func probePostgres(s string) (string, bool) {
v, err := url.Parse(s)
if err != nil {
return "", false
}
if v.Scheme != "postgres" {
return "", false
}
return s, true
}
func probeGdbm(s string) (string, bool) {
if !path.IsAbs(s) {
return "", false
}
s = path.Clean(s)
return s, true
}
func ToConnData(connStr string) (ConnData, error) {
var o ConnData
if connStr == "" {
return o, nil
}
v, ok := probePostgres(connStr)
if ok {
o.typ = DBTYPE_POSTGRES
o.str = v
return o, nil
}
v, ok = probeGdbm(connStr)
if ok {
o.typ = DBTYPE_GDBM
o.str = v
return o, nil
}
return o, fmt.Errorf("invalid connection string: %s", connStr)
}

View File

@ -0,0 +1,28 @@
package storage
import (
"testing"
)
func TestParseConnStr(t *testing.T) {
_, err := ToConnData("postgres://foo:bar@localhost:5432/baz")
if err != nil {
t.Fatal(err)
}
_, err = ToConnData("/foo/bar")
if err != nil {
t.Fatal(err)
}
_, err = ToConnData("/foo/bar/")
if err != nil {
t.Fatal(err)
}
_, err = ToConnData("foo/bar")
if err == nil {
t.Fatalf("expected error")
}
_, err = ToConnData("http://foo/bar")
if err == nil {
t.Fatalf("expected error")
}
}

View File

@ -13,7 +13,6 @@ import (
"git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/logging"
"git.defalsify.org/vise.git/persist" "git.defalsify.org/vise.git/persist"
"git.defalsify.org/vise.git/resource" "git.defalsify.org/vise.git/resource"
"git.grassecon.net/urdt/ussd/initializers"
gdbmstorage "git.grassecon.net/urdt/ussd/internal/storage/db/gdbm" gdbmstorage "git.grassecon.net/urdt/ussd/internal/storage/db/gdbm"
) )
@ -25,11 +24,10 @@ type StorageService interface {
GetPersister(ctx context.Context) (*persist.Persister, error) GetPersister(ctx context.Context) (*persist.Persister, error)
GetUserdataDb(ctx context.Context) db.Db GetUserdataDb(ctx context.Context) db.Db
GetResource(ctx context.Context) (resource.Resource, error) GetResource(ctx context.Context) (resource.Resource, error)
EnsureDbDir() error
} }
type MenuStorageService struct { type MenuStorageService struct {
dbDir string conn ConnData
resourceDir string resourceDir string
poResource resource.Resource poResource resource.Resource
resourceStore db.Db resourceStore db.Db
@ -37,29 +35,45 @@ type MenuStorageService struct {
userDataStore db.Db userDataStore db.Db
} }
func buildConnStr() string { func NewMenuStorageService(conn ConnData, resourceDir string) *MenuStorageService {
host := initializers.GetEnv("DB_HOST", "localhost")
user := initializers.GetEnv("DB_USER", "postgres")
password := initializers.GetEnv("DB_PASSWORD", "")
dbName := initializers.GetEnv("DB_NAME", "")
port := initializers.GetEnv("DB_PORT", "5432")
connString := fmt.Sprintf(
"postgres://%s:%s@%s:%s/%s",
user, password, host, port, dbName,
)
logg.Debugf("pg conn string", "conn", connString)
return connString
}
func NewMenuStorageService(dbDir string, resourceDir string) *MenuStorageService {
return &MenuStorageService{ return &MenuStorageService{
dbDir: dbDir, conn: conn,
resourceDir: resourceDir, resourceDir: resourceDir,
} }
} }
func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.Db, section string) (db.Db, error) {
var newDb db.Db
var err error
if existingDb != nil {
return existingDb, nil
}
connStr := ms.conn.String()
dbTyp := ms.conn.DbType()
if dbTyp == DBTYPE_POSTGRES {
newDb = postgres.NewPgDb()
} else if dbTyp == DBTYPE_GDBM {
err = ms.ensureDbDir()
if err != nil {
return nil, err
}
connStr = path.Join(connStr, section)
newDb = gdbmstorage.NewThreadGdbmDb()
} else {
return nil, fmt.Errorf("unsupported connection string: '%s'\n", ms.conn.String())
}
logg.DebugCtxf(ctx, "connecting to db", "conn", connStr)
err = newDb.Connect(ctx, connStr)
if err != nil {
return nil, err
}
return newDb, nil
}
// WithGettext triggers use of gettext for translation of templates and menus. // WithGettext triggers use of gettext for translation of templates and menus.
// //
// The first language in `lns` will be used as default language, to resolve node keys to // The first language in `lns` will be used as default language, to resolve node keys to
@ -82,36 +96,6 @@ func (ms *MenuStorageService) WithGettext(path string, lns []lang.Language) *Men
return ms return ms
} }
func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.Db, fileName string) (db.Db, error) {
database, ok := ctx.Value("Database").(string)
if !ok {
return nil, fmt.Errorf("failed to select the database")
}
if existingDb != nil {
return existingDb, nil
}
var newDb db.Db
var err error
if database == "postgres" {
newDb = postgres.NewPgDb()
connStr := buildConnStr()
err = newDb.Connect(ctx, connStr)
} else {
newDb = gdbmstorage.NewThreadGdbmDb()
storeFile := path.Join(ms.dbDir, fileName)
err = newDb.Connect(ctx, storeFile)
}
if err != nil {
return nil, err
}
return newDb, nil
}
func (ms *MenuStorageService) GetPersister(ctx context.Context) (*persist.Persister, error) { func (ms *MenuStorageService) GetPersister(ctx context.Context) (*persist.Persister, error) {
stateStore, err := ms.GetStateStore(ctx) stateStore, err := ms.GetStateStore(ctx)
if err != nil { if err != nil {
@ -166,8 +150,8 @@ func (ms *MenuStorageService) GetStateStore(ctx context.Context) (db.Db, error)
return ms.stateStore, nil return ms.stateStore, nil
} }
func (ms *MenuStorageService) EnsureDbDir() error { func (ms *MenuStorageService) ensureDbDir() error {
err := os.MkdirAll(ms.dbDir, 0700) err := os.MkdirAll(ms.conn.String(), 0700)
if err != nil { if err != nil {
return fmt.Errorf("state dir create exited with error: %v\n", err) return fmt.Errorf("state dir create exited with error: %v\n", err)
} }

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"os" "os"
"path" "path"
"path/filepath"
"time" "time"
"git.defalsify.org/vise.git/engine" "git.defalsify.org/vise.git/engine"
@ -27,7 +28,6 @@ var (
func TestEngine(sessionId string) (engine.Engine, func(), chan bool) { func TestEngine(sessionId string) (engine.Engine, func(), chan bool) {
ctx := context.Background() ctx := context.Background()
ctx = context.WithValue(ctx, "SessionId", sessionId) ctx = context.WithValue(ctx, "SessionId", sessionId)
ctx = context.WithValue(ctx, "Database", "gdbm")
pfp := path.Join(scriptDir, "pp.csv") pfp := path.Join(scriptDir, "pp.csv")
var eventChannel = make(chan bool) var eventChannel = make(chan bool)
@ -39,37 +39,40 @@ func TestEngine(sessionId string) (engine.Engine, func(), chan bool) {
FlagCount: uint32(128), FlagCount: uint32(128),
} }
dbDir := ".test_state" connStr, err := filepath.Abs(".test_state/state.gdbm")
resourceDir := scriptDir
menuStorageService := storage.NewMenuStorageService(dbDir, resourceDir)
err := menuStorageService.EnsureDbDir()
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, err.Error()) fmt.Fprintf(os.Stderr, "connstr err: %v", err)
os.Exit(1) os.Exit(1)
} }
conn, err := storage.ToConnData(connStr)
if err != nil {
fmt.Fprintf(os.Stderr, "connstr parse err: %v", err)
os.Exit(1)
}
resourceDir := scriptDir
menuStorageService := storage.NewMenuStorageService(conn, resourceDir)
rs, err := menuStorageService.GetResource(ctx) rs, err := menuStorageService.GetResource(ctx)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, err.Error()) fmt.Fprintf(os.Stderr, "resource error: %v", err)
os.Exit(1) os.Exit(1)
} }
pe, err := menuStorageService.GetPersister(ctx) pe, err := menuStorageService.GetPersister(ctx)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, err.Error()) fmt.Fprintf(os.Stderr, "persister error: %v", err)
os.Exit(1) os.Exit(1)
} }
userDataStore, err := menuStorageService.GetUserdataDb(ctx) userDataStore, err := menuStorageService.GetUserdataDb(ctx)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, err.Error()) fmt.Fprintf(os.Stderr, "userdb error: %v", err)
os.Exit(1) os.Exit(1)
} }
dbResource, ok := rs.(*resource.DbResource) dbResource, ok := rs.(*resource.DbResource)
if !ok { if !ok {
fmt.Fprintf(os.Stderr, err.Error()) fmt.Fprintf(os.Stderr, "dbresource cast error")
os.Exit(1) os.Exit(1)
} }

View File

@ -0,0 +1,15 @@
package testutil
import (
"testing"
)
func TestCreateEngine(t *testing.T) {
o, clean, eventC := TestEngine("foo")
defer clean()
defer func() {
<-eventC
close(eventC)
}()
_ = o
}

View File

@ -7,6 +7,7 @@ import (
"log" "log"
"math/rand" "math/rand"
"os" "os"
"path/filepath"
"regexp" "regexp"
"testing" "testing"
@ -17,7 +18,6 @@ import (
var ( var (
testData = driver.ReadData() testData = driver.ReadData()
testStore = ".test_state"
sessionID string sessionID string
src = rand.NewSource(42) src = rand.NewSource(42)
g = rand.New(src) g = rand.New(src)
@ -25,6 +25,11 @@ var (
var groupTestFile = flag.String("test-file", "group_test.json", "The test file to use for running the group tests") var groupTestFile = flag.String("test-file", "group_test.json", "The test file to use for running the group tests")
func testStore() string {
v, _ := filepath.Abs(".test_state/state.gdbm")
return v
}
func GenerateSessionId() string { func GenerateSessionId() string {
uu := uuid.NewGenWithOptions(uuid.WithRandomReader(g)) uu := uuid.NewGenWithOptions(uuid.WithRandomReader(g))
v, err := uu.NewV4() v, err := uu.NewV4()
@ -81,8 +86,8 @@ func extractSendAmount(response []byte) string {
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
sessionID = GenerateSessionId() sessionID = GenerateSessionId()
defer func() { defer func() {
if err := os.RemoveAll(testStore); err != nil { if err := os.RemoveAll(testStore()); err != nil {
log.Fatalf("Failed to delete state store %s: %v", testStore, err) log.Fatalf("Failed to delete state store %s: %v", testStore(), err)
} }
}() }()
m.Run() m.Run()