diff --git a/internal/testutil/engine.go b/internal/testutil/engine.go index 678b345..0bb07ec 100644 --- a/internal/testutil/engine.go +++ b/internal/testutil/engine.go @@ -3,6 +3,8 @@ package testutil import ( "context" "fmt" + "log" + "net/url" "os" "path" "path/filepath" @@ -11,35 +13,90 @@ import ( "git.defalsify.org/vise.git/engine" "git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/resource" + "git.grassecon.net/urdt/ussd/config" "git.grassecon.net/urdt/ussd/initializers" "git.grassecon.net/urdt/ussd/internal/handlers" "git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/internal/testutil/testservice" "git.grassecon.net/urdt/ussd/internal/testutil/testtag" "git.grassecon.net/urdt/ussd/remote" + "github.com/jackc/pgx/v5/pgxpool" testdataloader "github.com/peteole/testdata-loader" ) var ( - baseDir = testdataloader.GetBasePath() - logg = logging.NewVanilla() - baseDir = testdataloader.GetBasePath() - scriptDir = path.Join(baseDir, "services", "registration") - selectedDatabase = "" - selectedDbSchema = "" + logg = logging.NewVanilla() + baseDir = testdataloader.GetBasePath() + scriptDir = path.Join(baseDir, "services", "registration") + setDbType string + setConnStr string + setDbSchema string ) func init() { initializers.LoadEnvVariablesPath(baseDir) + config.LoadConfig() } // SetDatabase updates the database used by TestEngine -func SetDatabase(dbType string, dbSchema string) { - selectedDatabase = dbType - selectedDbSchema = dbSchema +func SetDatabase(database, connStr, dbSchema string) { + setDbType = database + setConnStr = connStr + setDbSchema = dbSchema +} + +// CleanDatabase removes all test data from the database +func CleanDatabase() { + if setDbType == "postgres" { + ctx := context.Background() + // Update the connection string with the new search path + updatedConnStr, err := updateSearchPath(setConnStr, setDbSchema) + if err != nil { + log.Fatalf("Failed to update search path: %v", err) + } + + dbConn, err := pgxpool.New(ctx, updatedConnStr) + if err != nil { + log.Fatalf("Failed to connect to database for cleanup: %v", err) + } + defer dbConn.Close() + + query := fmt.Sprintf("DELETE FROM %s.kv_vise;", setDbSchema) + _, execErr := dbConn.Exec(ctx, query) + if execErr != nil { + log.Printf("Failed to cleanup table %s.kv_vise: %v", setDbSchema, execErr) + } else { + log.Printf("Successfully cleaned up table %s.kv_vise", setDbSchema) + } + } else { + setConnStr, _ := filepath.Abs(setConnStr) + if err := os.RemoveAll(setConnStr); err != nil { + log.Fatalf("Failed to delete state store %s: %v", setConnStr, err) + } + } +} + +// updateSearchPath updates the search_path (schema) to be used in the connection +func updateSearchPath(connStr string, newSearchPath string) (string, error) { + u, err := url.Parse(connStr) + if err != nil { + return "", fmt.Errorf("invalid connection string: %w", err) + } + + // Parse the query parameters + q := u.Query() + + // Update or add the search_path parameter + q.Set("search_path", newSearchPath) + + // Rebuild the connection string with updated parameters + u.RawQuery = q.Encode() + + return u.String(), nil } func TestEngine(sessionId string) (engine.Engine, func(), chan bool) { + var err error ctx := context.Background() ctx = context.WithValue(ctx, "SessionId", sessionId) ctx = context.WithValue(ctx, "Database", selectedDatabase) @@ -56,16 +113,27 @@ func TestEngine(sessionId string) (engine.Engine, func(), chan bool) { FlagCount: uint32(128), } - connStr, err := filepath.Abs(".test_state/state.gdbm") - if err != nil { - fmt.Fprintf(os.Stderr, "connstr err: %v", err) - os.Exit(1) + if setDbType == "postgres" { + setConnStr = config.DbConn + setConnStr, err = updateSearchPath(setConnStr, setDbSchema) + if err != nil { + fmt.Println("Error:", err) + os.Exit(1) + } + } else { + setConnStr, err = filepath.Abs(setConnStr) + if err != nil { + fmt.Fprintf(os.Stderr, "connstr err: %v", err) + os.Exit(1) + } } - conn, err := storage.ToConnData(connStr) + + conn, err := storage.ToConnData(setConnStr) if err != nil { fmt.Fprintf(os.Stderr, "connstr parse err: %v", err) os.Exit(1) } + resourceDir := scriptDir menuStorageService := storage.NewMenuStorageService(conn, resourceDir) diff --git a/menutraversal_test/menu_traversal_test.go b/menutraversal_test/menu_traversal_test.go index 4003641..4aee26e 100644 --- a/menutraversal_test/menu_traversal_test.go +++ b/menutraversal_test/menu_traversal_test.go @@ -4,18 +4,14 @@ import ( "bytes" "context" "flag" - "fmt" "log" "math/rand" - "os" - "path/filepath" "regexp" "testing" "git.grassecon.net/urdt/ussd/internal/testutil" "git.grassecon.net/urdt/ussd/internal/testutil/driver" "github.com/gofrs/uuid" - "github.com/jackc/pgx/v5/pgxpool" ) var ( @@ -27,13 +23,9 @@ var ( var groupTestFile = flag.String("test-file", "group_test.json", "The test file to use for running the group tests") var database = flag.String("db", "gdbm", "Specify the database (gdbm or postgres)") +var connStr = flag.String("conn", ".test_state", "connection string") var dbSchema = flag.String("schema", "test", "Specify the database schema (default test)") -func testStore() string { - v, _ := filepath.Abs(".test_state/state.gdbm") - return v -} - func GenerateSessionId() string { uu := uuid.NewGenWithOptions(uuid.WithRandomReader(g)) v, err := uu.NewV4() @@ -90,37 +82,12 @@ func extractSendAmount(response []byte) string { func TestMain(m *testing.M) { // Parse the flags flag.Parse() - sessionID = GenerateSessionId() - defer func() { - if err := os.RemoveAll(testStore()); err != nil { - log.Fatalf("Failed to delete state store %s: %v", testStore(), err) - } - }() + // set the db + testutil.SetDatabase(*database, *connStr, *dbSchema) - // Set the selected database - testutil.SetDatabase(*database, *dbSchema) - - // Cleanup the schema table after tests - defer func() { - if *database == "postgres" { - ctx := context.Background() - connStr := "postgres://" //storage.BuildConnStr() - dbConn, err := pgxpool.New(ctx, connStr) - if err != nil { - log.Fatalf("Failed to connect to database for cleanup: %v", err) - } - defer dbConn.Close() - - query := fmt.Sprintf("DELETE FROM %s.kv_vise;", *dbSchema) - _, execErr := dbConn.Exec(ctx, query) - if execErr != nil { - log.Printf("Failed to cleanup table %s.kv_vise: %v", *dbSchema, execErr) - } else { - log.Printf("Successfully cleaned up table %s.kv_vise", *dbSchema) - } - } - }() + // Cleanup the db after tests + defer testutil.CleanDatabase() m.Run() }