diff --git a/cmd/africastalking/main.go b/cmd/africastalking/main.go index 72d3944..24812a1 100644 --- a/cmd/africastalking/main.go +++ b/cmd/africastalking/main.go @@ -12,17 +12,17 @@ import ( "syscall" "git.defalsify.org/vise.git/engine" + "git.defalsify.org/vise.git/lang" "git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/resource" - "git.defalsify.org/vise.git/lang" "git.grassecon.net/urdt/ussd/config" "git.grassecon.net/urdt/ussd/initializers" + "git.grassecon.net/urdt/ussd/internal/args" "git.grassecon.net/urdt/ussd/internal/handlers" "git.grassecon.net/urdt/ussd/internal/http/at" "git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/remote" - "git.grassecon.net/urdt/ussd/internal/args" ) var ( @@ -42,7 +42,6 @@ func main() { var connStr string var resourceDir string var size uint - var database string var engineDebug bool var host string var port uint @@ -60,7 +59,7 @@ func main() { flag.Var(&langs, "language", "add symbol resolution for language") flag.Parse() - if connStr != "" { + if connStr == "" { connStr = config.DbConn } connData, err := storage.ToConnData(connStr) @@ -72,7 +71,6 @@ func main() { logg.Infof("start command", "build", build, "conn", connData, "resourcedir", resourceDir, "outputsize", size) ctx := context.Background() - ctx = context.WithValue(ctx, "Database", database) ln, err := lang.LanguageFromCode(config.DefaultLanguage) if err != nil { fmt.Fprintf(os.Stderr, "default language set error: %v", err) diff --git a/cmd/async/main.go b/cmd/async/main.go index dc293e6..27db453 100644 --- a/cmd/async/main.go +++ b/cmd/async/main.go @@ -10,16 +10,16 @@ import ( "syscall" "git.defalsify.org/vise.git/engine" + "git.defalsify.org/vise.git/lang" "git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/resource" - "git.defalsify.org/vise.git/lang" "git.grassecon.net/urdt/ussd/config" "git.grassecon.net/urdt/ussd/initializers" + "git.grassecon.net/urdt/ussd/internal/args" "git.grassecon.net/urdt/ussd/internal/handlers" "git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/remote" - "git.grassecon.net/urdt/ussd/internal/args" ) var ( @@ -52,7 +52,6 @@ func main() { var sessionId string var resourceDir string var size uint - var database string var engineDebug bool var host string var port uint @@ -71,7 +70,7 @@ func main() { flag.Var(&langs, "language", "add symbol resolution for language") flag.Parse() - if connStr != "" { + if connStr == "" { connStr = config.DbConn } connData, err := storage.ToConnData(connStr) @@ -83,7 +82,6 @@ func main() { logg.Infof("start command", "conn", connData, "resourcedir", resourceDir, "outputsize", size, "sessionId", sessionId) ctx := context.Background() - ctx = context.WithValue(ctx, "Database", database) ln, err := lang.LanguageFromCode(config.DefaultLanguage) if err != nil { @@ -117,7 +115,6 @@ func main() { os.Exit(1) } - userdataStore, err := menuStorageService.GetUserdataDb(ctx) if err != nil { fmt.Fprintf(os.Stderr, err.Error()) diff --git a/cmd/http/main.go b/cmd/http/main.go index 8e65232..6617ca5 100644 --- a/cmd/http/main.go +++ b/cmd/http/main.go @@ -12,22 +12,22 @@ import ( "syscall" "git.defalsify.org/vise.git/engine" + "git.defalsify.org/vise.git/lang" "git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/resource" - "git.defalsify.org/vise.git/lang" "git.grassecon.net/urdt/ussd/config" "git.grassecon.net/urdt/ussd/initializers" + "git.grassecon.net/urdt/ussd/internal/args" "git.grassecon.net/urdt/ussd/internal/handlers" httpserver "git.grassecon.net/urdt/ussd/internal/http" "git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/remote" - "git.grassecon.net/urdt/ussd/internal/args" ) var ( - logg = logging.NewVanilla() - scriptDir = path.Join("services", "registration") + logg = logging.NewVanilla() + scriptDir = path.Join("services", "registration") menuSeparator = ": " ) @@ -41,7 +41,6 @@ func main() { var connStr string var resourceDir string var size uint - var database string var engineDebug bool var host string var port uint @@ -59,7 +58,7 @@ func main() { flag.Var(&langs, "language", "add symbol resolution for language") flag.Parse() - if connStr != "" { + if connStr == "" { connStr = config.DbConn } connData, err := storage.ToConnData(connStr) @@ -71,7 +70,6 @@ func main() { logg.Infof("start command", "conn", connData, "resourcedir", resourceDir, "outputsize", size) ctx := context.Background() - ctx = context.WithValue(ctx, "Database", database) ln, err := lang.LanguageFromCode(config.DefaultLanguage) if err != nil { @@ -94,7 +92,7 @@ func main() { } menuStorageService := storage.NewMenuStorageService(connData, resourceDir) - + rs, err := menuStorageService.GetResource(ctx) if err != nil { fmt.Fprintf(os.Stderr, err.Error()) diff --git a/cmd/main.go b/cmd/main.go index 3939c9d..d2fe0ba 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -8,14 +8,14 @@ import ( "path" "git.defalsify.org/vise.git/engine" + "git.defalsify.org/vise.git/lang" "git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/resource" - "git.defalsify.org/vise.git/lang" "git.grassecon.net/urdt/ussd/config" "git.grassecon.net/urdt/ussd/initializers" + "git.grassecon.net/urdt/ussd/internal/args" "git.grassecon.net/urdt/ussd/internal/handlers" "git.grassecon.net/urdt/ussd/internal/storage" - "git.grassecon.net/urdt/ussd/internal/args" "git.grassecon.net/urdt/ussd/remote" ) @@ -36,7 +36,6 @@ func main() { var connStr string var size uint var sessionId string - var database string var engineDebug bool var resourceDir string var err error @@ -52,7 +51,7 @@ func main() { flag.Var(&langs, "language", "add symbol resolution for language") flag.Parse() - if connStr != "" { + if connStr == "" { connStr = config.DbConn } connData, err := storage.ToConnData(connStr) @@ -69,7 +68,6 @@ func main() { ctx := context.Background() ctx = context.WithValue(ctx, "SessionId", sessionId) - ctx = context.WithValue(ctx, "Database", database) ln, err := lang.LanguageFromCode(config.DefaultLanguage) if err != nil { @@ -89,7 +87,7 @@ func main() { } menuStorageService := storage.NewMenuStorageService(connData, resourceDir) - + if gettextDir != "" { menuStorageService = menuStorageService.WithGettext(gettextDir, langs.Langs()) } diff --git a/devtools/store/generate/main.go b/devtools/store/generate/main.go index 749f340..c421d1a 100644 --- a/devtools/store/generate/main.go +++ b/devtools/store/generate/main.go @@ -9,14 +9,16 @@ import ( "path" "git.defalsify.org/vise.git/logging" - "git.grassecon.net/urdt/ussd/config" - "git.grassecon.net/urdt/ussd/internal/storage" - "git.grassecon.net/urdt/ussd/initializers" "git.grassecon.net/urdt/ussd/common" + "git.grassecon.net/urdt/ussd/config" + "git.grassecon.net/urdt/ussd/initializers" + "git.grassecon.net/urdt/ussd/internal/storage" + testdataloader "github.com/peteole/testdata-loader" ) var ( logg = logging.NewVanilla() + baseDir = testdataloader.GetBasePath() scriptDir = path.Join("services", "registration") ) @@ -24,7 +26,6 @@ func init() { initializers.LoadEnvVariables() } - func main() { config.LoadConfig() @@ -86,5 +87,4 @@ func main() { fmt.Fprintf(os.Stderr, err.Error()) os.Exit(1) } - } diff --git a/initializers/load.go b/initializers/load.go index 4ea5980..fc61746 100644 --- a/initializers/load.go +++ b/initializers/load.go @@ -3,24 +3,30 @@ package initializers import ( "log" "os" + "path" "strconv" "github.com/joho/godotenv" ) func LoadEnvVariables() { - err := godotenv.Load() + LoadEnvVariablesPath(".") +} + +func LoadEnvVariablesPath(dir string) { + fp := path.Join(dir, ".env") + err := godotenv.Load(fp) if err != nil { - log.Fatal("Error loading .env file") + log.Fatal("Error loading .env file", err) } } // Helper to get environment variables with a default fallback func GetEnv(key, defaultVal string) string { - if value, exists := os.LookupEnv(key); exists { - return value + if value, exists := os.LookupEnv(key); exists { + return value } - return defaultVal + return defaultVal } // Helper to safely convert environment variables to uint diff --git a/internal/storage/parse.go b/internal/storage/parse.go index bb25627..41dac6b 100644 --- a/internal/storage/parse.go +++ b/internal/storage/parse.go @@ -15,6 +15,7 @@ const ( type ConnData struct { typ int str string + domain string } func (cd *ConnData) DbType() int { @@ -25,23 +26,38 @@ func (cd *ConnData) String() string { return cd.str } -func probePostgres(s string) (string, bool) { - v, err := url.Parse(s) - if err != nil { - return "", false - } - if v.Scheme != "postgres" { - return "", false - } - return s, true +func (cd *ConnData) Domain() string { + return cd.domain } -func probeGdbm(s string) (string, bool) { +func (cd *ConnData) Path() string { + v, _ := url.Parse(cd.str) + v.RawQuery = "" + return v.String() +} + +func probePostgres(s string) (string, string, bool) { + domain := "public" + v, err := url.Parse(s) + if err != nil { + return "", "", false + } + if v.Scheme != "postgres" { + return "", "", false + } + vv := v.Query() + if vv.Has("search_path") { + domain = vv.Get("search_path") + } + return s, domain, true +} + +func probeGdbm(s string) (string, string, bool) { if !path.IsAbs(s) { - return "", false + return "", "", false } s = path.Clean(s) - return s, true + return s, "", true } func ToConnData(connStr string) (ConnData, error) { @@ -51,14 +67,15 @@ func ToConnData(connStr string) (ConnData, error) { return o, nil } - v, ok := probePostgres(connStr) + v, domain, ok := probePostgres(connStr) if ok { o.typ = DBTYPE_POSTGRES o.str = v + o.domain = domain return o, nil } - v, ok = probeGdbm(connStr) + v, _, ok = probeGdbm(connStr) if ok { o.typ = DBTYPE_GDBM o.str = v diff --git a/internal/storage/storageservice.go b/internal/storage/storageservice.go index 2e093a5..374af74 100644 --- a/internal/storage/storageservice.go +++ b/internal/storage/storageservice.go @@ -14,6 +14,7 @@ import ( "git.defalsify.org/vise.git/persist" "git.defalsify.org/vise.git/resource" gdbmstorage "git.grassecon.net/urdt/ussd/internal/storage/db/gdbm" + "github.com/jackc/pgx/v5/pgxpool" ) var ( @@ -54,7 +55,12 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D connStr := ms.conn.String() dbTyp := ms.conn.DbType() if dbTyp == DBTYPE_POSTGRES { - newDb = postgres.NewPgDb() + // TODO: move to vise + err = ensureSchemaExists(ctx, ms.conn) + if err != nil { + return nil, err + } + newDb = postgres.NewPgDb().WithSchema(ms.conn.Domain()) } else if dbTyp == DBTYPE_GDBM { err = ms.ensureDbDir() if err != nil { @@ -65,7 +71,7 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D } else { return nil, fmt.Errorf("unsupported connection string: '%s'\n", ms.conn.String()) } - logg.DebugCtxf(ctx, "connecting to db", "conn", connStr) + logg.DebugCtxf(ctx, "connecting to db", "conn", connStr, "conndata", ms.conn) err = newDb.Connect(ctx, connStr) if err != nil { return nil, err @@ -96,6 +102,23 @@ func (ms *MenuStorageService) WithGettext(path string, lns []lang.Language) *Men return ms } +// ensureSchemaExists creates a new schema if it does not exist +func ensureSchemaExists(ctx context.Context, conn ConnData) error { + h, err := pgxpool.New(ctx, conn.Path()) + if err != nil { + return fmt.Errorf("failed to connect to the database: %w", err) + } + defer h.Close() + + query := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", conn.Domain()) + _, err = h.Exec(ctx, query) + if err != nil { + return fmt.Errorf("failed to create schema: %w", err) + } + + return nil +} + func (ms *MenuStorageService) GetPersister(ctx context.Context) (*persist.Persister, error) { stateStore, err := ms.GetStateStore(ctx) if err != nil { diff --git a/internal/testutil/engine.go b/internal/testutil/engine.go index 2372ce9..5d581ba 100644 --- a/internal/testutil/engine.go +++ b/internal/testutil/engine.go @@ -3,6 +3,8 @@ package testutil import ( "context" "fmt" + "log" + "net/url" "os" "path" "path/filepath" @@ -11,21 +13,90 @@ import ( "git.defalsify.org/vise.git/engine" "git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/resource" + "git.grassecon.net/urdt/ussd/config" + "git.grassecon.net/urdt/ussd/initializers" "git.grassecon.net/urdt/ussd/internal/handlers" "git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/internal/testutil/testservice" "git.grassecon.net/urdt/ussd/internal/testutil/testtag" - testdataloader "github.com/peteole/testdata-loader" "git.grassecon.net/urdt/ussd/remote" + "github.com/jackc/pgx/v5/pgxpool" + testdataloader "github.com/peteole/testdata-loader" ) var ( - baseDir = testdataloader.GetBasePath() - logg = logging.NewVanilla() - scriptDir = path.Join(baseDir, "services", "registration") + logg = logging.NewVanilla() + baseDir = testdataloader.GetBasePath() + scriptDir = path.Join(baseDir, "services", "registration") + setDbType string + setConnStr string + setDbSchema string ) +func init() { + initializers.LoadEnvVariablesPath(baseDir) + config.LoadConfig() +} + +// SetDatabase updates the database used by TestEngine +func SetDatabase(database, connStr, dbSchema string) { + setDbType = database + setConnStr = connStr + setDbSchema = dbSchema +} + +// CleanDatabase removes all test data from the database +func CleanDatabase() { + if setDbType == "postgres" { + ctx := context.Background() + // Update the connection string with the new search path + updatedConnStr, err := updateSearchPath(setConnStr, setDbSchema) + if err != nil { + log.Fatalf("Failed to update search path: %v", err) + } + + dbConn, err := pgxpool.New(ctx, updatedConnStr) + if err != nil { + log.Fatalf("Failed to connect to database for cleanup: %v", err) + } + defer dbConn.Close() + + query := fmt.Sprintf("DELETE FROM %s.kv_vise;", setDbSchema) + _, execErr := dbConn.Exec(ctx, query) + if execErr != nil { + log.Printf("Failed to cleanup table %s.kv_vise: %v", setDbSchema, execErr) + } else { + log.Printf("Successfully cleaned up table %s.kv_vise", setDbSchema) + } + } else { + setConnStr, _ := filepath.Abs(setConnStr) + if err := os.RemoveAll(setConnStr); err != nil { + log.Fatalf("Failed to delete state store %s: %v", setConnStr, err) + } + } +} + +// updateSearchPath updates the search_path (schema) to be used in the connection +func updateSearchPath(connStr string, newSearchPath string) (string, error) { + u, err := url.Parse(connStr) + if err != nil { + return "", fmt.Errorf("invalid connection string: %w", err) + } + + // Parse the query parameters + q := u.Query() + + // Update or add the search_path parameter + q.Set("search_path", newSearchPath) + + // Rebuild the connection string with updated parameters + u.RawQuery = q.Encode() + + return u.String(), nil +} + func TestEngine(sessionId string) (engine.Engine, func(), chan bool) { + var err error ctx := context.Background() ctx = context.WithValue(ctx, "SessionId", sessionId) pfp := path.Join(scriptDir, "pp.csv") @@ -39,16 +110,27 @@ func TestEngine(sessionId string) (engine.Engine, func(), chan bool) { FlagCount: uint32(128), } - connStr, err := filepath.Abs(".test_state/state.gdbm") - if err != nil { - fmt.Fprintf(os.Stderr, "connstr err: %v", err) - os.Exit(1) + if setDbType == "postgres" { + setConnStr = config.DbConn + setConnStr, err = updateSearchPath(setConnStr, setDbSchema) + if err != nil { + fmt.Println("Error:", err) + os.Exit(1) + } + } else { + setConnStr, err = filepath.Abs(setConnStr) + if err != nil { + fmt.Fprintf(os.Stderr, "connstr err: %v", err) + os.Exit(1) + } } - conn, err := storage.ToConnData(connStr) + + conn, err := storage.ToConnData(setConnStr) if err != nil { fmt.Fprintf(os.Stderr, "connstr parse err: %v", err) os.Exit(1) } + resourceDir := scriptDir menuStorageService := storage.NewMenuStorageService(conn, resourceDir) diff --git a/menutraversal_test/menu_traversal_test.go b/menutraversal_test/menu_traversal_test.go index 52e2273..4aee26e 100644 --- a/menutraversal_test/menu_traversal_test.go +++ b/menutraversal_test/menu_traversal_test.go @@ -6,8 +6,6 @@ import ( "flag" "log" "math/rand" - "os" - "path/filepath" "regexp" "testing" @@ -24,11 +22,9 @@ var ( ) var groupTestFile = flag.String("test-file", "group_test.json", "The test file to use for running the group tests") - -func testStore() string { - v, _ := filepath.Abs(".test_state/state.gdbm") - return v -} +var database = flag.String("db", "gdbm", "Specify the database (gdbm or postgres)") +var connStr = flag.String("conn", ".test_state", "connection string") +var dbSchema = flag.String("schema", "test", "Specify the database schema (default test)") func GenerateSessionId() string { uu := uuid.NewGenWithOptions(uuid.WithRandomReader(g)) @@ -84,12 +80,15 @@ func extractSendAmount(response []byte) string { } func TestMain(m *testing.M) { + // Parse the flags + flag.Parse() sessionID = GenerateSessionId() - defer func() { - if err := os.RemoveAll(testStore()); err != nil { - log.Fatalf("Failed to delete state store %s: %v", testStore(), err) - } - }() + // set the db + testutil.SetDatabase(*database, *connStr, *dbSchema) + + // Cleanup the db after tests + defer testutil.CleanDatabase() + m.Run() } @@ -126,7 +125,6 @@ func TestAccountCreationSuccessful(t *testing.T) { } } <-eventChannel - } func TestAccountRegistrationRejectTerms(t *testing.T) {