Compare commits

...

16 Commits

15 changed files with 298 additions and 116 deletions

View File

@ -12,17 +12,17 @@ import (
"syscall" "syscall"
"git.defalsify.org/vise.git/engine" "git.defalsify.org/vise.git/engine"
"git.defalsify.org/vise.git/lang"
"git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/logging"
"git.defalsify.org/vise.git/resource" "git.defalsify.org/vise.git/resource"
"git.defalsify.org/vise.git/lang"
"git.grassecon.net/urdt/ussd/config" "git.grassecon.net/urdt/ussd/config"
"git.grassecon.net/urdt/ussd/initializers" "git.grassecon.net/urdt/ussd/initializers"
"git.grassecon.net/urdt/ussd/internal/args"
"git.grassecon.net/urdt/ussd/internal/handlers" "git.grassecon.net/urdt/ussd/internal/handlers"
"git.grassecon.net/urdt/ussd/internal/http/at" "git.grassecon.net/urdt/ussd/internal/http/at"
"git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/internal/storage"
"git.grassecon.net/urdt/ussd/remote" "git.grassecon.net/urdt/ussd/remote"
"git.grassecon.net/urdt/ussd/internal/args"
) )
var ( var (
@ -42,7 +42,6 @@ func main() {
var connStr string var connStr string
var resourceDir string var resourceDir string
var size uint var size uint
var database string
var engineDebug bool var engineDebug bool
var host string var host string
var port uint var port uint
@ -51,7 +50,8 @@ func main() {
var langs args.LangVar var langs args.LangVar
flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir") flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir")
flag.StringVar(&connStr, "c", "", "connection string") flag.StringVar(&database, "db", "gdbm", "database to be used")
flag.StringVar(&dbSchema, "schema", "public", "database schema to be used")
flag.BoolVar(&engineDebug, "d", false, "use engine debug output") flag.BoolVar(&engineDebug, "d", false, "use engine debug output")
flag.UintVar(&size, "s", 160, "max size of output") flag.UintVar(&size, "s", 160, "max size of output")
flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host") flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host")
@ -60,7 +60,7 @@ func main() {
flag.Var(&langs, "language", "add symbol resolution for language") flag.Var(&langs, "language", "add symbol resolution for language")
flag.Parse() flag.Parse()
if connStr != "" { if connStr == "" {
connStr = config.DbConn connStr = config.DbConn
} }
connData, err := storage.ToConnData(connStr) connData, err := storage.ToConnData(connStr)
@ -72,7 +72,6 @@ func main() {
logg.Infof("start command", "build", build, "conn", connData, "resourcedir", resourceDir, "outputsize", size) logg.Infof("start command", "build", build, "conn", connData, "resourcedir", resourceDir, "outputsize", size)
ctx := context.Background() ctx := context.Background()
ctx = context.WithValue(ctx, "Database", database)
ln, err := lang.LanguageFromCode(config.DefaultLanguage) ln, err := lang.LanguageFromCode(config.DefaultLanguage)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "default language set error: %v", err) fmt.Fprintf(os.Stderr, "default language set error: %v", err)

View File

@ -10,16 +10,16 @@ import (
"syscall" "syscall"
"git.defalsify.org/vise.git/engine" "git.defalsify.org/vise.git/engine"
"git.defalsify.org/vise.git/lang"
"git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/logging"
"git.defalsify.org/vise.git/resource" "git.defalsify.org/vise.git/resource"
"git.defalsify.org/vise.git/lang"
"git.grassecon.net/urdt/ussd/config" "git.grassecon.net/urdt/ussd/config"
"git.grassecon.net/urdt/ussd/initializers" "git.grassecon.net/urdt/ussd/initializers"
"git.grassecon.net/urdt/ussd/internal/args"
"git.grassecon.net/urdt/ussd/internal/handlers" "git.grassecon.net/urdt/ussd/internal/handlers"
"git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/internal/storage"
"git.grassecon.net/urdt/ussd/remote" "git.grassecon.net/urdt/ussd/remote"
"git.grassecon.net/urdt/ussd/internal/args"
) )
var ( var (
@ -52,7 +52,6 @@ func main() {
var sessionId string var sessionId string
var resourceDir string var resourceDir string
var size uint var size uint
var database string
var engineDebug bool var engineDebug bool
var host string var host string
var port uint var port uint
@ -62,7 +61,8 @@ func main() {
flag.StringVar(&sessionId, "session-id", "075xx2123", "session id") flag.StringVar(&sessionId, "session-id", "075xx2123", "session id")
flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir") flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir")
flag.StringVar(&connStr, "c", "", "connection string") flag.StringVar(&database, "db", "gdbm", "database to be used")
flag.StringVar(&dbSchema, "schema", "public", "database schema to be used")
flag.BoolVar(&engineDebug, "d", false, "use engine debug output") flag.BoolVar(&engineDebug, "d", false, "use engine debug output")
flag.UintVar(&size, "s", 160, "max size of output") flag.UintVar(&size, "s", 160, "max size of output")
flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host") flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host")
@ -71,7 +71,7 @@ func main() {
flag.Var(&langs, "language", "add symbol resolution for language") flag.Var(&langs, "language", "add symbol resolution for language")
flag.Parse() flag.Parse()
if connStr != "" { if connStr == "" {
connStr = config.DbConn connStr = config.DbConn
} }
connData, err := storage.ToConnData(connStr) connData, err := storage.ToConnData(connStr)
@ -83,7 +83,6 @@ func main() {
logg.Infof("start command", "conn", connData, "resourcedir", resourceDir, "outputsize", size, "sessionId", sessionId) logg.Infof("start command", "conn", connData, "resourcedir", resourceDir, "outputsize", size, "sessionId", sessionId)
ctx := context.Background() ctx := context.Background()
ctx = context.WithValue(ctx, "Database", database)
ln, err := lang.LanguageFromCode(config.DefaultLanguage) ln, err := lang.LanguageFromCode(config.DefaultLanguage)
if err != nil { if err != nil {
@ -117,7 +116,6 @@ func main() {
os.Exit(1) os.Exit(1)
} }
userdataStore, err := menuStorageService.GetUserdataDb(ctx) userdataStore, err := menuStorageService.GetUserdataDb(ctx)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, err.Error()) fmt.Fprintf(os.Stderr, err.Error())

View File

@ -12,22 +12,22 @@ import (
"syscall" "syscall"
"git.defalsify.org/vise.git/engine" "git.defalsify.org/vise.git/engine"
"git.defalsify.org/vise.git/lang"
"git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/logging"
"git.defalsify.org/vise.git/resource" "git.defalsify.org/vise.git/resource"
"git.defalsify.org/vise.git/lang"
"git.grassecon.net/urdt/ussd/config" "git.grassecon.net/urdt/ussd/config"
"git.grassecon.net/urdt/ussd/initializers" "git.grassecon.net/urdt/ussd/initializers"
"git.grassecon.net/urdt/ussd/internal/args"
"git.grassecon.net/urdt/ussd/internal/handlers" "git.grassecon.net/urdt/ussd/internal/handlers"
httpserver "git.grassecon.net/urdt/ussd/internal/http" httpserver "git.grassecon.net/urdt/ussd/internal/http"
"git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/internal/storage"
"git.grassecon.net/urdt/ussd/remote" "git.grassecon.net/urdt/ussd/remote"
"git.grassecon.net/urdt/ussd/internal/args"
) )
var ( var (
logg = logging.NewVanilla() logg = logging.NewVanilla()
scriptDir = path.Join("services", "registration") scriptDir = path.Join("services", "registration")
menuSeparator = ": " menuSeparator = ": "
) )
@ -41,7 +41,6 @@ func main() {
var connStr string var connStr string
var resourceDir string var resourceDir string
var size uint var size uint
var database string
var engineDebug bool var engineDebug bool
var host string var host string
var port uint var port uint
@ -50,7 +49,8 @@ func main() {
var langs args.LangVar var langs args.LangVar
flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir") flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir")
flag.StringVar(&connStr, "c", "", "connection string") flag.StringVar(&database, "db", "gdbm", "database to be used")
flag.StringVar(&dbSchema, "schema", "public", "database schema to be used")
flag.BoolVar(&engineDebug, "d", false, "use engine debug output") flag.BoolVar(&engineDebug, "d", false, "use engine debug output")
flag.UintVar(&size, "s", 160, "max size of output") flag.UintVar(&size, "s", 160, "max size of output")
flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host") flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host")
@ -59,7 +59,7 @@ func main() {
flag.Var(&langs, "language", "add symbol resolution for language") flag.Var(&langs, "language", "add symbol resolution for language")
flag.Parse() flag.Parse()
if connStr != "" { if connStr == "" {
connStr = config.DbConn connStr = config.DbConn
} }
connData, err := storage.ToConnData(connStr) connData, err := storage.ToConnData(connStr)
@ -71,7 +71,6 @@ func main() {
logg.Infof("start command", "conn", connData, "resourcedir", resourceDir, "outputsize", size) logg.Infof("start command", "conn", connData, "resourcedir", resourceDir, "outputsize", size)
ctx := context.Background() ctx := context.Background()
ctx = context.WithValue(ctx, "Database", database)
ln, err := lang.LanguageFromCode(config.DefaultLanguage) ln, err := lang.LanguageFromCode(config.DefaultLanguage)
if err != nil { if err != nil {
@ -94,7 +93,7 @@ func main() {
} }
menuStorageService := storage.NewMenuStorageService(connData, resourceDir) menuStorageService := storage.NewMenuStorageService(connData, resourceDir)
rs, err := menuStorageService.GetResource(ctx) rs, err := menuStorageService.GetResource(ctx)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, err.Error()) fmt.Fprintf(os.Stderr, err.Error())

View File

@ -8,14 +8,14 @@ import (
"path" "path"
"git.defalsify.org/vise.git/engine" "git.defalsify.org/vise.git/engine"
"git.defalsify.org/vise.git/lang"
"git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/logging"
"git.defalsify.org/vise.git/resource" "git.defalsify.org/vise.git/resource"
"git.defalsify.org/vise.git/lang"
"git.grassecon.net/urdt/ussd/config" "git.grassecon.net/urdt/ussd/config"
"git.grassecon.net/urdt/ussd/initializers" "git.grassecon.net/urdt/ussd/initializers"
"git.grassecon.net/urdt/ussd/internal/args"
"git.grassecon.net/urdt/ussd/internal/handlers" "git.grassecon.net/urdt/ussd/internal/handlers"
"git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/internal/storage"
"git.grassecon.net/urdt/ussd/internal/args"
"git.grassecon.net/urdt/ussd/remote" "git.grassecon.net/urdt/ussd/remote"
) )
@ -36,7 +36,6 @@ func main() {
var connStr string var connStr string
var size uint var size uint
var sessionId string var sessionId string
var database string
var engineDebug bool var engineDebug bool
var resourceDir string var resourceDir string
var err error var err error
@ -45,14 +44,16 @@ func main() {
flag.StringVar(&resourceDir, "resourcedir", scriptDir, "resource dir") flag.StringVar(&resourceDir, "resourcedir", scriptDir, "resource dir")
flag.StringVar(&sessionId, "session-id", "075xx2123", "session id") flag.StringVar(&sessionId, "session-id", "075xx2123", "session id")
flag.StringVar(&connStr, "c", "", "connection string") flag.StringVar(&database, "db", "gdbm", "database to be used")
flag.StringVar(&dbSchema, "schema", "public", "database schema to be used")
flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
flag.BoolVar(&engineDebug, "d", false, "use engine debug output") flag.BoolVar(&engineDebug, "d", false, "use engine debug output")
flag.UintVar(&size, "s", 160, "max size of output") flag.UintVar(&size, "s", 160, "max size of output")
flag.StringVar(&gettextDir, "gettext", "", "use gettext translations from given directory") flag.StringVar(&gettextDir, "gettext", "", "use gettext translations from given directory")
flag.Var(&langs, "language", "add symbol resolution for language") flag.Var(&langs, "language", "add symbol resolution for language")
flag.Parse() flag.Parse()
if connStr != "" { if connStr == "" {
connStr = config.DbConn connStr = config.DbConn
} }
connData, err := storage.ToConnData(connStr) connData, err := storage.ToConnData(connStr)
@ -69,7 +70,6 @@ func main() {
ctx := context.Background() ctx := context.Background()
ctx = context.WithValue(ctx, "SessionId", sessionId) ctx = context.WithValue(ctx, "SessionId", sessionId)
ctx = context.WithValue(ctx, "Database", database)
ln, err := lang.LanguageFromCode(config.DefaultLanguage) ln, err := lang.LanguageFromCode(config.DefaultLanguage)
if err != nil { if err != nil {
@ -89,7 +89,7 @@ func main() {
} }
menuStorageService := storage.NewMenuStorageService(connData, resourceDir) menuStorageService := storage.NewMenuStorageService(connData, resourceDir)
if gettextDir != "" { if gettextDir != "" {
menuStorageService = menuStorageService.WithGettext(gettextDir, langs.Langs()) menuStorageService = menuStorageService.WithGettext(gettextDir, langs.Langs())
} }

View File

@ -57,6 +57,8 @@ const (
DATA_ACTIVE_ADDRESS DATA_ACTIVE_ADDRESS
//Holds count of the number of incorrect PIN attempts //Holds count of the number of incorrect PIN attempts
DATA_INCORRECT_PIN_ATTEMPTS DATA_INCORRECT_PIN_ATTEMPTS
//ISO 639 code for the selected language.
DATA_SELECTED_LANGUAGE_CODE
) )
const ( const (

View File

@ -7,31 +7,22 @@ import (
"os" "os"
"path" "path"
"git.defalsify.org/vise.git/db"
"git.defalsify.org/vise.git/logging"
"git.grassecon.net/urdt/ussd/config" "git.grassecon.net/urdt/ussd/config"
"git.grassecon.net/urdt/ussd/initializers" "git.grassecon.net/urdt/ussd/initializers"
"git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/internal/storage"
"git.grassecon.net/urdt/ussd/debug" testdataloader "github.com/peteole/testdata-loader"
"git.defalsify.org/vise.git/db"
"git.defalsify.org/vise.git/logging"
) )
var ( var (
logg = logging.NewVanilla() logg = logging.NewVanilla()
baseDir = testdataloader.GetBasePath()
scriptDir = path.Join("services", "registration") scriptDir = path.Join("services", "registration")
) )
func init() { func init() {
initializers.LoadEnvVariables() initializers.LoadEnvVariables(baseDir)
}
func formatItem(k []byte, v []byte) (string, error) {
o, err := debug.FromKey(k)
if err != nil {
return "", err
}
s := fmt.Sprintf("%vValue: %v\n\n", o, string(v))
return s, nil
} }
func main() { func main() {

View File

@ -9,14 +9,16 @@ import (
"path" "path"
"git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/logging"
"git.grassecon.net/urdt/ussd/config"
"git.grassecon.net/urdt/ussd/internal/storage"
"git.grassecon.net/urdt/ussd/initializers"
"git.grassecon.net/urdt/ussd/common" "git.grassecon.net/urdt/ussd/common"
"git.grassecon.net/urdt/ussd/config"
"git.grassecon.net/urdt/ussd/initializers"
"git.grassecon.net/urdt/ussd/internal/storage"
testdataloader "github.com/peteole/testdata-loader"
) )
var ( var (
logg = logging.NewVanilla() logg = logging.NewVanilla()
baseDir = testdataloader.GetBasePath()
scriptDir = path.Join("services", "registration") scriptDir = path.Join("services", "registration")
) )
@ -24,7 +26,6 @@ func init() {
initializers.LoadEnvVariables() initializers.LoadEnvVariables()
} }
func main() { func main() {
config.LoadConfig() config.LoadConfig()
@ -86,5 +87,4 @@ func main() {
fmt.Fprintf(os.Stderr, err.Error()) fmt.Fprintf(os.Stderr, err.Error())
os.Exit(1) os.Exit(1)
} }
} }

View File

@ -3,24 +3,30 @@ package initializers
import ( import (
"log" "log"
"os" "os"
"path"
"strconv" "strconv"
"github.com/joho/godotenv" "github.com/joho/godotenv"
) )
func LoadEnvVariables() { func LoadEnvVariables() {
err := godotenv.Load() LoadEnvVariablesPath(".")
}
func LoadEnvVariablesPath(dir string) {
fp := path.Join(dir, ".env")
err := godotenv.Load(fp)
if err != nil { if err != nil {
log.Fatal("Error loading .env file") log.Fatal("Error loading .env file", err)
} }
} }
// Helper to get environment variables with a default fallback // Helper to get environment variables with a default fallback
func GetEnv(key, defaultVal string) string { func GetEnv(key, defaultVal string) string {
if value, exists := os.LookupEnv(key); exists { if value, exists := os.LookupEnv(key); exists {
return value return value
} }
return defaultVal return defaultVal
} }
// Helper to safely convert environment variables to uint // Helper to safely convert environment variables to uint

View File

@ -161,9 +161,12 @@ func (h *Handlers) SetLanguage(ctx context.Context, sym string, input []byte) (r
//Fallback to english instead? //Fallback to english instead?
code = "eng" code = "eng"
} }
res.FlagSet = append(res.FlagSet, state.FLAG_LANG) err := h.persistLanguageCode(ctx, code)
if err != nil {
return res, err
}
res.Content = code res.Content = code
res.FlagSet = append(res.FlagSet, state.FLAG_LANG)
languageSetFlag, err := h.flagManager.GetFlag("flag_language_set") languageSetFlag, err := h.flagManager.GetFlag("flag_language_set")
if err != nil { if err != nil {
logg.ErrorCtxf(ctx, "Error setting the languageSetFlag", "error", err) logg.ErrorCtxf(ctx, "Error setting the languageSetFlag", "error", err)
@ -2173,3 +2176,18 @@ func (h *Handlers) resetIncorrectPINAttempts(ctx context.Context, sessionId stri
} }
return nil return nil
} }
// persistLanguageCode persists the selected ISO 639 language code
func (h *Handlers) persistLanguageCode(ctx context.Context, code string) error {
store := h.userdataStore
sessionId, ok := ctx.Value("SessionId").(string)
if !ok {
return fmt.Errorf("missing session")
}
err := store.WriteEntry(ctx, sessionId, common.DATA_SELECTED_LANGUAGE_CODE, []byte(code))
if err != nil {
logg.ErrorCtxf(ctx, "failed to persist language code", "key", common.DATA_SELECTED_LANGUAGE_CODE, "value", code, "error", err)
return err
}
return nil
}

View File

@ -775,6 +775,11 @@ func TestSetLanguage(t *testing.T) {
log.Fatal(err) log.Fatal(err)
} }
sessionId := "session123"
ctx, store := InitializeTestStore(t)
ctx = context.WithValue(ctx, "SessionId", sessionId)
// Define test cases // Define test cases
tests := []struct { tests := []struct {
name string name string
@ -807,12 +812,13 @@ func TestSetLanguage(t *testing.T) {
// Create the Handlers instance with the mock flag manager // Create the Handlers instance with the mock flag manager
h := &Handlers{ h := &Handlers{
flagManager: fm.parser, flagManager: fm.parser,
st: mockState, userdataStore: store,
st: mockState,
} }
// Call the method // Call the method
res, err := h.SetLanguage(context.Background(), "set_language", nil) res, err := h.SetLanguage(ctx, "set_language", nil)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -2285,3 +2291,41 @@ func TestResetIncorrectPINAttempts(t *testing.T) {
assert.Equal(t, "0", string(incorrectAttempts)) assert.Equal(t, "0", string(incorrectAttempts))
} }
func TestPersistLanguageCode(t *testing.T) {
ctx, store := InitializeTestStore(t)
sessionId := "session123"
ctx = context.WithValue(ctx, "SessionId", sessionId)
h := &Handlers{
userdataStore: store,
}
tests := []struct {
name string
code string
expectedLanguageCode string
}{
{
name: "Set Default Language (English)",
code: "eng",
expectedLanguageCode: "eng",
},
{
name: "Set Swahili Language",
code: "swa",
expectedLanguageCode: "swa",
},
}
for _, test := range tests {
err := h.persistLanguageCode(ctx, test.code)
if err != nil {
t.Logf(err.Error())
}
code, err := store.ReadEntry(ctx, sessionId, common.DATA_SELECTED_LANGUAGE_CODE)
assert.Equal(t, test.expectedLanguageCode, string(code))
}
}

View File

@ -15,6 +15,7 @@ const (
type ConnData struct { type ConnData struct {
typ int typ int
str string str string
domain string
} }
func (cd *ConnData) DbType() int { func (cd *ConnData) DbType() int {
@ -25,23 +26,38 @@ func (cd *ConnData) String() string {
return cd.str return cd.str
} }
func probePostgres(s string) (string, bool) { func (cd *ConnData) Domain() string {
v, err := url.Parse(s) return cd.domain
if err != nil {
return "", false
}
if v.Scheme != "postgres" {
return "", false
}
return s, true
} }
func probeGdbm(s string) (string, bool) { func (cd *ConnData) Path() string {
v, _ := url.Parse(cd.str)
v.RawQuery = ""
return v.String()
}
func probePostgres(s string) (string, string, bool) {
domain := "public"
v, err := url.Parse(s)
if err != nil {
return "", "", false
}
if v.Scheme != "postgres" {
return "", "", false
}
vv := v.Query()
if vv.Has("search_path") {
domain = vv.Get("search_path")
}
return s, domain, true
}
func probeGdbm(s string) (string, string, bool) {
if !path.IsAbs(s) { if !path.IsAbs(s) {
return "", false return "", "", false
} }
s = path.Clean(s) s = path.Clean(s)
return s, true return s, "", true
} }
func ToConnData(connStr string) (ConnData, error) { func ToConnData(connStr string) (ConnData, error) {
@ -51,14 +67,15 @@ func ToConnData(connStr string) (ConnData, error) {
return o, nil return o, nil
} }
v, ok := probePostgres(connStr) v, domain, ok := probePostgres(connStr)
if ok { if ok {
o.typ = DBTYPE_POSTGRES o.typ = DBTYPE_POSTGRES
o.str = v o.str = v
o.domain = domain
return o, nil return o, nil
} }
v, ok = probeGdbm(connStr) v, _, ok = probeGdbm(connStr)
if ok { if ok {
o.typ = DBTYPE_GDBM o.typ = DBTYPE_GDBM
o.str = v o.str = v

View File

@ -14,6 +14,7 @@ import (
"git.defalsify.org/vise.git/persist" "git.defalsify.org/vise.git/persist"
"git.defalsify.org/vise.git/resource" "git.defalsify.org/vise.git/resource"
gdbmstorage "git.grassecon.net/urdt/ussd/internal/storage/db/gdbm" gdbmstorage "git.grassecon.net/urdt/ussd/internal/storage/db/gdbm"
"github.com/jackc/pgx/v5/pgxpool"
) )
var ( var (
@ -35,7 +36,23 @@ type MenuStorageService struct {
userDataStore db.Db userDataStore db.Db
} }
func NewMenuStorageService(conn ConnData, resourceDir string) *MenuStorageService { func BuildConnStr() string {
host := initializers.GetEnv("DB_HOST", "localhost")
user := initializers.GetEnv("DB_USER", "postgres")
password := initializers.GetEnv("DB_PASSWORD", "")
dbName := initializers.GetEnv("DB_NAME", "")
port := initializers.GetEnv("DB_PORT", "5432")
connString := fmt.Sprintf(
"postgres://%s:%s@%s:%s/%s",
user, password, host, port, dbName,
)
logg.Debugf("pg conn string", "conn", connString)
return connString
}
func NewMenuStorageService(dbDir string, resourceDir string) *MenuStorageService {
return &MenuStorageService{ return &MenuStorageService{
conn: conn, conn: conn,
resourceDir: resourceDir, resourceDir: resourceDir,
@ -46,6 +63,11 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D
var newDb db.Db var newDb db.Db
var err error var err error
schema, ok := ctx.Value("Schema").(string)
if !ok {
return nil, fmt.Errorf("failed to select the schema")
}
if existingDb != nil { if existingDb != nil {
return existingDb, nil return existingDb, nil
} }
@ -54,18 +76,26 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D
connStr := ms.conn.String() connStr := ms.conn.String()
dbTyp := ms.conn.DbType() dbTyp := ms.conn.DbType()
if dbTyp == DBTYPE_POSTGRES { if dbTyp == DBTYPE_POSTGRES {
newDb = postgres.NewPgDb() // TODO: move to vise
} else if dbTyp == DBTYPE_GDBM { err = ensureSchemaExists(ctx, ms.conn)
err = ms.ensureDbDir()
if err != nil { if err != nil {
return nil, err return nil, err
} }
connStr = path.Join(connStr, section) newDb = postgres.NewPgDb().WithSchema(ms.conn.Domain())
} else if dbTyp == DBTYPE_GDBM {
err = ms.ensureDbDir()
if err != nil {
return nil, fmt.Errorf("failed to ensure schema exists: %w", err)
}
newDb = postgres.NewPgDb().WithSchema(schema)
err = newDb.Connect(ctx, connStr)
} else {
newDb = gdbmstorage.NewThreadGdbmDb() newDb = gdbmstorage.NewThreadGdbmDb()
} else { } else {
return nil, fmt.Errorf("unsupported connection string: '%s'\n", ms.conn.String()) return nil, fmt.Errorf("unsupported connection string: '%s'\n", ms.conn.String())
} }
logg.DebugCtxf(ctx, "connecting to db", "conn", connStr) logg.DebugCtxf(ctx, "connecting to db", "conn", connStr, "conndata", ms.conn)
err = newDb.Connect(ctx, connStr) err = newDb.Connect(ctx, connStr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -74,26 +104,21 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D
return newDb, nil return newDb, nil
} }
// WithGettext triggers use of gettext for translation of templates and menus. // ensureSchemaExists creates a new schema if it does not exist
// func ensureSchemaExists(ctx context.Context, conn ConnData) error {
// The first language in `lns` will be used as default language, to resolve node keys to h, err := pgxpool.New(ctx, conn.Path())
// language strings. if err != nil {
// return fmt.Errorf("failed to connect to the database: %w", err)
// If `lns` is an empty array, gettext will not be used.
func (ms *MenuStorageService) WithGettext(path string, lns []lang.Language) *MenuStorageService {
if len(lns) == 0 {
logg.Warnf("Gettext requested but no languages supplied")
return ms
} }
rs := resource.NewPoResource(lns[0], path) defer h.Close()
for _, ln := range(lns) { query := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", conn.Domain())
rs = rs.WithLanguage(ln) _, err = h.Exec(ctx, query)
if err != nil {
return fmt.Errorf("failed to create schema: %w", err)
} }
ms.poResource = rs return nil
return ms
} }
func (ms *MenuStorageService) GetPersister(ctx context.Context) (*persist.Persister, error) { func (ms *MenuStorageService) GetPersister(ctx context.Context) (*persist.Persister, error) {

View File

@ -3,6 +3,8 @@ package testutil
import ( import (
"context" "context"
"fmt" "fmt"
"log"
"net/url"
"os" "os"
"path" "path"
"path/filepath" "path/filepath"
@ -11,23 +13,95 @@ import (
"git.defalsify.org/vise.git/engine" "git.defalsify.org/vise.git/engine"
"git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/logging"
"git.defalsify.org/vise.git/resource" "git.defalsify.org/vise.git/resource"
"git.grassecon.net/urdt/ussd/config"
"git.grassecon.net/urdt/ussd/initializers"
"git.grassecon.net/urdt/ussd/internal/handlers" "git.grassecon.net/urdt/ussd/internal/handlers"
"git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/internal/storage"
"git.grassecon.net/urdt/ussd/internal/testutil/testservice" "git.grassecon.net/urdt/ussd/internal/testutil/testservice"
"git.grassecon.net/urdt/ussd/internal/testutil/testtag" "git.grassecon.net/urdt/ussd/internal/testutil/testtag"
testdataloader "github.com/peteole/testdata-loader"
"git.grassecon.net/urdt/ussd/remote" "git.grassecon.net/urdt/ussd/remote"
"github.com/jackc/pgx/v5/pgxpool"
testdataloader "github.com/peteole/testdata-loader"
) )
var ( var (
baseDir = testdataloader.GetBasePath() logg = logging.NewVanilla()
logg = logging.NewVanilla() baseDir = testdataloader.GetBasePath()
scriptDir = path.Join(baseDir, "services", "registration") scriptDir = path.Join(baseDir, "services", "registration")
setDbType string
setConnStr string
setDbSchema string
) )
func init() {
initializers.LoadEnvVariablesPath(baseDir)
config.LoadConfig()
}
// SetDatabase updates the database used by TestEngine
func SetDatabase(database, connStr, dbSchema string) {
setDbType = database
setConnStr = connStr
setDbSchema = dbSchema
}
// CleanDatabase removes all test data from the database
func CleanDatabase() {
if setDbType == "postgres" {
ctx := context.Background()
// Update the connection string with the new search path
updatedConnStr, err := updateSearchPath(setConnStr, setDbSchema)
if err != nil {
log.Fatalf("Failed to update search path: %v", err)
}
dbConn, err := pgxpool.New(ctx, updatedConnStr)
if err != nil {
log.Fatalf("Failed to connect to database for cleanup: %v", err)
}
defer dbConn.Close()
query := fmt.Sprintf("DELETE FROM %s.kv_vise;", setDbSchema)
_, execErr := dbConn.Exec(ctx, query)
if execErr != nil {
log.Printf("Failed to cleanup table %s.kv_vise: %v", setDbSchema, execErr)
} else {
log.Printf("Successfully cleaned up table %s.kv_vise", setDbSchema)
}
} else {
setConnStr, _ := filepath.Abs(setConnStr)
if err := os.RemoveAll(setConnStr); err != nil {
log.Fatalf("Failed to delete state store %s: %v", setConnStr, err)
}
}
}
// updateSearchPath updates the search_path (schema) to be used in the connection
func updateSearchPath(connStr string, newSearchPath string) (string, error) {
u, err := url.Parse(connStr)
if err != nil {
return "", fmt.Errorf("invalid connection string: %w", err)
}
// Parse the query parameters
q := u.Query()
// Update or add the search_path parameter
q.Set("search_path", newSearchPath)
// Rebuild the connection string with updated parameters
u.RawQuery = q.Encode()
return u.String(), nil
}
func TestEngine(sessionId string) (engine.Engine, func(), chan bool) { func TestEngine(sessionId string) (engine.Engine, func(), chan bool) {
var err error
ctx := context.Background() ctx := context.Background()
ctx = context.WithValue(ctx, "SessionId", sessionId) ctx = context.WithValue(ctx, "SessionId", sessionId)
ctx = context.WithValue(ctx, "Database", selectedDatabase)
ctx = context.WithValue(ctx, "Schema", selectedDbSchema)
pfp := path.Join(scriptDir, "pp.csv") pfp := path.Join(scriptDir, "pp.csv")
var eventChannel = make(chan bool) var eventChannel = make(chan bool)
@ -39,16 +113,27 @@ func TestEngine(sessionId string) (engine.Engine, func(), chan bool) {
FlagCount: uint32(128), FlagCount: uint32(128),
} }
connStr, err := filepath.Abs(".test_state/state.gdbm") if setDbType == "postgres" {
if err != nil { setConnStr = config.DbConn
fmt.Fprintf(os.Stderr, "connstr err: %v", err) setConnStr, err = updateSearchPath(setConnStr, setDbSchema)
os.Exit(1) if err != nil {
fmt.Println("Error:", err)
os.Exit(1)
}
} else {
setConnStr, err = filepath.Abs(setConnStr)
if err != nil {
fmt.Fprintf(os.Stderr, "connstr err: %v", err)
os.Exit(1)
}
} }
conn, err := storage.ToConnData(connStr)
conn, err := storage.ToConnData(setConnStr)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "connstr parse err: %v", err) fmt.Fprintf(os.Stderr, "connstr parse err: %v", err)
os.Exit(1) os.Exit(1)
} }
resourceDir := scriptDir resourceDir := scriptDir
menuStorageService := storage.NewMenuStorageService(conn, resourceDir) menuStorageService := storage.NewMenuStorageService(conn, resourceDir)

View File

@ -6,8 +6,6 @@ import (
"flag" "flag"
"log" "log"
"math/rand" "math/rand"
"os"
"path/filepath"
"regexp" "regexp"
"testing" "testing"
@ -24,11 +22,9 @@ var (
) )
var groupTestFile = flag.String("test-file", "group_test.json", "The test file to use for running the group tests") var groupTestFile = flag.String("test-file", "group_test.json", "The test file to use for running the group tests")
var database = flag.String("db", "gdbm", "Specify the database (gdbm or postgres)")
func testStore() string { var connStr = flag.String("conn", ".test_state", "connection string")
v, _ := filepath.Abs(".test_state/state.gdbm") var dbSchema = flag.String("schema", "test", "Specify the database schema (default test)")
return v
}
func GenerateSessionId() string { func GenerateSessionId() string {
uu := uuid.NewGenWithOptions(uuid.WithRandomReader(g)) uu := uuid.NewGenWithOptions(uuid.WithRandomReader(g))
@ -84,12 +80,15 @@ func extractSendAmount(response []byte) string {
} }
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
// Parse the flags
flag.Parse()
sessionID = GenerateSessionId() sessionID = GenerateSessionId()
defer func() { // set the db
if err := os.RemoveAll(testStore()); err != nil { testutil.SetDatabase(*database, *connStr, *dbSchema)
log.Fatalf("Failed to delete state store %s: %v", testStore(), err)
} // Cleanup the db after tests
}() defer testutil.CleanDatabase()
m.Run() m.Run()
} }
@ -126,7 +125,6 @@ func TestAccountCreationSuccessful(t *testing.T) {
} }
} }
<-eventChannel <-eventChannel
} }
func TestAccountRegistrationRejectTerms(t *testing.T) { func TestAccountRegistrationRejectTerms(t *testing.T) {