From 3ee15497a5bc080f424feb0008933bb53143c8de Mon Sep 17 00:00:00 2001 From: alfred-mk Date: Mon, 6 Jan 2025 14:50:39 +0300 Subject: [PATCH 01/12] specify the base directory for loading the .env file --- initializers/load.go | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/initializers/load.go b/initializers/load.go index 4ea5980..4cbeb0e 100644 --- a/initializers/load.go +++ b/initializers/load.go @@ -3,24 +3,26 @@ package initializers import ( "log" "os" + "path" "strconv" "github.com/joho/godotenv" ) -func LoadEnvVariables() { - err := godotenv.Load() +func LoadEnvVariables(baseDir string) { + envDir := path.Join(baseDir, ".env") + err := godotenv.Load(envDir) if err != nil { - log.Fatal("Error loading .env file") + log.Fatal("Error loading .env file", err) } } // Helper to get environment variables with a default fallback func GetEnv(key, defaultVal string) string { - if value, exists := os.LookupEnv(key); exists { - return value + if value, exists := os.LookupEnv(key); exists { + return value } - return defaultVal + return defaultVal } // Helper to safely convert environment variables to uint -- 2.45.2 From 79de0a9092f9fa7f9e5f610203a0574353d1b9a9 Mon Sep 17 00:00:00 2001 From: alfred-mk Date: Mon, 6 Jan 2025 14:54:04 +0300 Subject: [PATCH 02/12] pass the base directory to load the .env file --- cmd/africastalking/main.go | 4 +++- cmd/async/main.go | 4 +++- cmd/http/main.go | 8 +++++--- cmd/main.go | 4 +++- devtools/gen/main.go | 12 ++++++------ devtools/store/main.go | 13 +++++++------ 6 files changed, 27 insertions(+), 18 deletions(-) diff --git a/cmd/africastalking/main.go b/cmd/africastalking/main.go index dfcaca1..4ca8400 100644 --- a/cmd/africastalking/main.go +++ b/cmd/africastalking/main.go @@ -14,6 +14,7 @@ import ( "git.defalsify.org/vise.git/engine" "git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/resource" + testdataloader "github.com/peteole/testdata-loader" "git.grassecon.net/urdt/ussd/config" "git.grassecon.net/urdt/ussd/initializers" @@ -26,13 +27,14 @@ import ( var ( logg = logging.NewVanilla().WithDomain("AfricasTalking").WithContextKey("at-session-id") + baseDir = testdataloader.GetBasePath() scriptDir = path.Join("services", "registration") build = "dev" menuSeparator = ": " ) func init() { - initializers.LoadEnvVariables() + initializers.LoadEnvVariables(baseDir) } func main() { config.LoadConfig() diff --git a/cmd/async/main.go b/cmd/async/main.go index bf23d9f..51b9e40 100644 --- a/cmd/async/main.go +++ b/cmd/async/main.go @@ -12,6 +12,7 @@ import ( "git.defalsify.org/vise.git/engine" "git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/resource" + testdataloader "github.com/peteole/testdata-loader" "git.grassecon.net/urdt/ussd/config" "git.grassecon.net/urdt/ussd/initializers" @@ -22,12 +23,13 @@ import ( var ( logg = logging.NewVanilla() + baseDir = testdataloader.GetBasePath() scriptDir = path.Join("services", "registration") menuSeparator = ": " ) func init() { - initializers.LoadEnvVariables() + initializers.LoadEnvVariables(baseDir) } type asyncRequestParser struct { diff --git a/cmd/http/main.go b/cmd/http/main.go index 6ddfded..46dbe91 100644 --- a/cmd/http/main.go +++ b/cmd/http/main.go @@ -14,6 +14,7 @@ import ( "git.defalsify.org/vise.git/engine" "git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/resource" + testdataloader "github.com/peteole/testdata-loader" "git.grassecon.net/urdt/ussd/config" "git.grassecon.net/urdt/ussd/initializers" @@ -24,13 +25,14 @@ import ( ) var ( - logg = logging.NewVanilla() - scriptDir = path.Join("services", "registration") + logg = logging.NewVanilla() + baseDir = testdataloader.GetBasePath() + scriptDir = path.Join("services", "registration") menuSeparator = ": " ) func init() { - initializers.LoadEnvVariables() + initializers.LoadEnvVariables(baseDir) } func main() { diff --git a/cmd/main.go b/cmd/main.go index 4fd084f..fc6f147 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -15,16 +15,18 @@ import ( "git.grassecon.net/urdt/ussd/internal/handlers" "git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/remote" + testdataloader "github.com/peteole/testdata-loader" ) var ( logg = logging.NewVanilla() + baseDir = testdataloader.GetBasePath() scriptDir = path.Join("services", "registration") menuSeparator = ": " ) func init() { - initializers.LoadEnvVariables() + initializers.LoadEnvVariables(baseDir) } func main() { diff --git a/devtools/gen/main.go b/devtools/gen/main.go index b9e2aed..f54afb7 100644 --- a/devtools/gen/main.go +++ b/devtools/gen/main.go @@ -9,22 +9,23 @@ import ( "path" "git.defalsify.org/vise.git/logging" - "git.grassecon.net/urdt/ussd/config" - "git.grassecon.net/urdt/ussd/internal/storage" - "git.grassecon.net/urdt/ussd/initializers" "git.grassecon.net/urdt/ussd/common" + "git.grassecon.net/urdt/ussd/config" + "git.grassecon.net/urdt/ussd/initializers" + "git.grassecon.net/urdt/ussd/internal/storage" + testdataloader "github.com/peteole/testdata-loader" ) var ( logg = logging.NewVanilla() + baseDir = testdataloader.GetBasePath() scriptDir = path.Join("services", "registration") ) func init() { - initializers.LoadEnvVariables() + initializers.LoadEnvVariables(baseDir) } - func main() { config.LoadConfig() @@ -75,5 +76,4 @@ func main() { fmt.Fprintf(os.Stderr, err.Error()) os.Exit(1) } - } diff --git a/devtools/store/main.go b/devtools/store/main.go index 8bd4d16..9f3e196 100644 --- a/devtools/store/main.go +++ b/devtools/store/main.go @@ -7,24 +7,25 @@ import ( "os" "path" - "git.grassecon.net/urdt/ussd/config" - "git.grassecon.net/urdt/ussd/initializers" - "git.grassecon.net/urdt/ussd/internal/storage" - "git.grassecon.net/urdt/ussd/debug" "git.defalsify.org/vise.git/db" "git.defalsify.org/vise.git/logging" + "git.grassecon.net/urdt/ussd/config" + "git.grassecon.net/urdt/ussd/debug" + "git.grassecon.net/urdt/ussd/initializers" + "git.grassecon.net/urdt/ussd/internal/storage" + testdataloader "github.com/peteole/testdata-loader" ) var ( logg = logging.NewVanilla() + baseDir = testdataloader.GetBasePath() scriptDir = path.Join("services", "registration") ) func init() { - initializers.LoadEnvVariables() + initializers.LoadEnvVariables(baseDir) } - func main() { config.LoadConfig() -- 2.45.2 From c12e867ac37e9d352a2c924622b79fb34885d41a Mon Sep 17 00:00:00 2001 From: alfred-mk Date: Mon, 6 Jan 2025 15:06:25 +0300 Subject: [PATCH 03/12] add a db flag to specify the database of choice --- internal/testutil/TestEngine.go | 21 ++++++++++++++++----- menutraversal_test/menu_traversal_test.go | 8 +++++++- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/internal/testutil/TestEngine.go b/internal/testutil/TestEngine.go index 3fcb307..40a744f 100644 --- a/internal/testutil/TestEngine.go +++ b/internal/testutil/TestEngine.go @@ -10,24 +10,35 @@ import ( "git.defalsify.org/vise.git/engine" "git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/resource" + "git.grassecon.net/urdt/ussd/initializers" "git.grassecon.net/urdt/ussd/internal/handlers" "git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/internal/testutil/testservice" "git.grassecon.net/urdt/ussd/internal/testutil/testtag" - testdataloader "github.com/peteole/testdata-loader" "git.grassecon.net/urdt/ussd/remote" + testdataloader "github.com/peteole/testdata-loader" ) var ( - baseDir = testdataloader.GetBasePath() - logg = logging.NewVanilla() - scriptDir = path.Join(baseDir, "services", "registration") + baseDir = testdataloader.GetBasePath() + logg = logging.NewVanilla() + scriptDir = path.Join(baseDir, "services", "registration") + selectedDatabase = "" ) +func init() { + initializers.LoadEnvVariables(baseDir) +} + +// SetDatabase updates the database used by TestEngine +func SetDatabase(dbType string) { + selectedDatabase = dbType +} + func TestEngine(sessionId string) (engine.Engine, func(), chan bool) { ctx := context.Background() ctx = context.WithValue(ctx, "SessionId", sessionId) - ctx = context.WithValue(ctx, "Database", "gdbm") + ctx = context.WithValue(ctx, "Database", selectedDatabase) pfp := path.Join(scriptDir, "pp.csv") var eventChannel = make(chan bool) diff --git a/menutraversal_test/menu_traversal_test.go b/menutraversal_test/menu_traversal_test.go index 6b6b3da..8cfe710 100644 --- a/menutraversal_test/menu_traversal_test.go +++ b/menutraversal_test/menu_traversal_test.go @@ -24,6 +24,7 @@ var ( ) var groupTestFile = flag.String("test-file", "group_test.json", "The test file to use for running the group tests") +var database = flag.String("db", "gdbm", "Specify the database (gdbm or postgres)") func GenerateSessionId() string { uu := uuid.NewGenWithOptions(uuid.WithRandomReader(g)) @@ -79,12 +80,18 @@ func extractSendAmount(response []byte) string { } func TestMain(m *testing.M) { + // Parse the flags + flag.Parse() + sessionID = GenerateSessionId() defer func() { if err := os.RemoveAll(testStore); err != nil { log.Fatalf("Failed to delete state store %s: %v", testStore, err) } }() + + // Set the selected database + testutil.SetDatabase(*database) m.Run() } @@ -121,7 +128,6 @@ func TestAccountCreationSuccessful(t *testing.T) { } } <-eventChannel - } func TestAccountRegistrationRejectTerms(t *testing.T) { -- 2.45.2 From 46a6d2bc6e2e732d800750cb67901d1887b119bc Mon Sep 17 00:00:00 2001 From: alfred-mk Date: Wed, 8 Jan 2025 10:37:47 +0300 Subject: [PATCH 04/12] create a schema if it does not exist and use it in the connection --- internal/storage/storageservice.go | 32 +++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/internal/storage/storageservice.go b/internal/storage/storageservice.go index 04e75ce..d252383 100644 --- a/internal/storage/storageservice.go +++ b/internal/storage/storageservice.go @@ -14,6 +14,7 @@ import ( "git.defalsify.org/vise.git/resource" "git.grassecon.net/urdt/ussd/initializers" gdbmstorage "git.grassecon.net/urdt/ussd/internal/storage/db/gdbm" + "github.com/jackc/pgx/v5/pgxpool" ) var ( @@ -64,6 +65,11 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D return nil, fmt.Errorf("failed to select the database") } + schema, ok := ctx.Value("Schema").(string) + if !ok { + return nil, fmt.Errorf("failed to select the schema") + } + if existingDb != nil { return existingDb, nil } @@ -72,8 +78,15 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D var err error if database == "postgres" { - newDb = postgres.NewPgDb() connStr := buildConnStr() + + // Ensure the schema exists + err = ensureSchemaExists(ctx, connStr, schema) + if err != nil { + return nil, fmt.Errorf("failed to ensure schema exists: %w", err) + } + + newDb = postgres.NewPgDb().WithSchema(schema) err = newDb.Connect(ctx, connStr) } else { newDb = gdbmstorage.NewThreadGdbmDb() @@ -88,6 +101,23 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D return newDb, nil } +// ensureSchemaExists creates a new schema if it does not exist +func ensureSchemaExists(ctx context.Context, connStr, schema string) error { + conn, err := pgxpool.New(ctx, connStr) + if err != nil { + return fmt.Errorf("failed to connect to the database: %w", err) + } + defer conn.Close() + + query := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema) + _, err = conn.Exec(ctx, query) + if err != nil { + return fmt.Errorf("failed to create schema: %w", err) + } + + return nil +} + func (ms *MenuStorageService) GetPersister(ctx context.Context) (*persist.Persister, error) { stateStore, err := ms.GetStateStore(ctx) if err != nil { -- 2.45.2 From 81c3378ea690cc12354145167be755ad7ea0b224 Mon Sep 17 00:00:00 2001 From: alfred-mk Date: Wed, 8 Jan 2025 10:55:43 +0300 Subject: [PATCH 05/12] use a flag to pass the schema to the context --- cmd/africastalking/main.go | 3 +++ cmd/async/main.go | 3 +++ cmd/http/main.go | 3 +++ cmd/main.go | 3 +++ 4 files changed, 12 insertions(+) diff --git a/cmd/africastalking/main.go b/cmd/africastalking/main.go index 4ca8400..0e330ae 100644 --- a/cmd/africastalking/main.go +++ b/cmd/africastalking/main.go @@ -43,12 +43,14 @@ func main() { var resourceDir string var size uint var database string + var dbSchema string var engineDebug bool var host string var port uint flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from") flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir") flag.StringVar(&database, "db", "gdbm", "database to be used") + flag.StringVar(&dbSchema, "schema", "public", "database schema to be used") flag.BoolVar(&engineDebug, "d", false, "use engine debug output") flag.UintVar(&size, "s", 160, "max size of output") flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host") @@ -59,6 +61,7 @@ func main() { ctx := context.Background() ctx = context.WithValue(ctx, "Database", database) + ctx = context.WithValue(ctx, "Schema", dbSchema) pfp := path.Join(scriptDir, "pp.csv") cfg := engine.Config{ diff --git a/cmd/async/main.go b/cmd/async/main.go index 51b9e40..91c28c6 100644 --- a/cmd/async/main.go +++ b/cmd/async/main.go @@ -53,6 +53,7 @@ func main() { var resourceDir string var size uint var database string + var dbSchema string var engineDebug bool var host string var port uint @@ -60,6 +61,7 @@ func main() { flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from") flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir") flag.StringVar(&database, "db", "gdbm", "database to be used") + flag.StringVar(&dbSchema, "schema", "public", "database schema to be used") flag.BoolVar(&engineDebug, "d", false, "use engine debug output") flag.UintVar(&size, "s", 160, "max size of output") flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host") @@ -70,6 +72,7 @@ func main() { ctx := context.Background() ctx = context.WithValue(ctx, "Database", database) + ctx = context.WithValue(ctx, "Schema", dbSchema) pfp := path.Join(scriptDir, "pp.csv") cfg := engine.Config{ diff --git a/cmd/http/main.go b/cmd/http/main.go index 46dbe91..69e31a1 100644 --- a/cmd/http/main.go +++ b/cmd/http/main.go @@ -42,12 +42,14 @@ func main() { var resourceDir string var size uint var database string + var dbSchema string var engineDebug bool var host string var port uint flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from") flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir") flag.StringVar(&database, "db", "gdbm", "database to be used") + flag.StringVar(&dbSchema, "schema", "public", "database schema to be used") flag.BoolVar(&engineDebug, "d", false, "use engine debug output") flag.UintVar(&size, "s", 160, "max size of output") flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host") @@ -58,6 +60,7 @@ func main() { ctx := context.Background() ctx = context.WithValue(ctx, "Database", database) + ctx = context.WithValue(ctx, "Schema", dbSchema) pfp := path.Join(scriptDir, "pp.csv") cfg := engine.Config{ diff --git a/cmd/main.go b/cmd/main.go index fc6f147..484b6c1 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -36,9 +36,11 @@ func main() { var size uint var sessionId string var database string + var dbSchema string var engineDebug bool flag.StringVar(&sessionId, "session-id", "075xx2123", "session id") flag.StringVar(&database, "db", "gdbm", "database to be used") + flag.StringVar(&dbSchema, "schema", "public", "database schema to be used") flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from") flag.BoolVar(&engineDebug, "d", false, "use engine debug output") flag.UintVar(&size, "s", 160, "max size of output") @@ -49,6 +51,7 @@ func main() { ctx := context.Background() ctx = context.WithValue(ctx, "SessionId", sessionId) ctx = context.WithValue(ctx, "Database", database) + ctx = context.WithValue(ctx, "Schema", dbSchema) pfp := path.Join(scriptDir, "pp.csv") cfg := engine.Config{ -- 2.45.2 From f59c3a53ef743a2adfbc4084dd8d2ae91d8d9e92 Mon Sep 17 00:00:00 2001 From: alfred-mk Date: Wed, 8 Jan 2025 10:56:59 +0300 Subject: [PATCH 06/12] allow the BuildConnStr to be accessed by different packages --- internal/storage/storageservice.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/storage/storageservice.go b/internal/storage/storageservice.go index d252383..d333a05 100644 --- a/internal/storage/storageservice.go +++ b/internal/storage/storageservice.go @@ -36,7 +36,7 @@ type MenuStorageService struct { userDataStore db.Db } -func buildConnStr() string { +func BuildConnStr() string { host := initializers.GetEnv("DB_HOST", "localhost") user := initializers.GetEnv("DB_USER", "postgres") password := initializers.GetEnv("DB_PASSWORD", "") @@ -78,7 +78,7 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D var err error if database == "postgres" { - connStr := buildConnStr() + connStr := BuildConnStr() // Ensure the schema exists err = ensureSchemaExists(ctx, connStr, schema) -- 2.45.2 From a37f6e6da30f0f21cf07718237169d47d9ce41e0 Mon Sep 17 00:00:00 2001 From: alfred-mk Date: Wed, 8 Jan 2025 10:57:58 +0300 Subject: [PATCH 07/12] pass the dbschema in the context --- internal/testutil/TestEngine.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/internal/testutil/TestEngine.go b/internal/testutil/TestEngine.go index 40a744f..25a1c15 100644 --- a/internal/testutil/TestEngine.go +++ b/internal/testutil/TestEngine.go @@ -24,6 +24,7 @@ var ( logg = logging.NewVanilla() scriptDir = path.Join(baseDir, "services", "registration") selectedDatabase = "" + selectedDbSchema = "" ) func init() { @@ -31,14 +32,17 @@ func init() { } // SetDatabase updates the database used by TestEngine -func SetDatabase(dbType string) { +func SetDatabase(dbType string, dbSchema string) { selectedDatabase = dbType + selectedDbSchema = dbSchema } func TestEngine(sessionId string) (engine.Engine, func(), chan bool) { ctx := context.Background() ctx = context.WithValue(ctx, "SessionId", sessionId) ctx = context.WithValue(ctx, "Database", selectedDatabase) + ctx = context.WithValue(ctx, "Schema", selectedDbSchema) + pfp := path.Join(scriptDir, "pp.csv") var eventChannel = make(chan bool) -- 2.45.2 From ea9cab930e120666715779f0e83f5c22e296ef2a Mon Sep 17 00:00:00 2001 From: alfred-mk Date: Wed, 8 Jan 2025 10:59:22 +0300 Subject: [PATCH 08/12] cleanup the generated test data for the schema --- menutraversal_test/menu_traversal_test.go | 28 ++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/menutraversal_test/menu_traversal_test.go b/menutraversal_test/menu_traversal_test.go index 8cfe710..d2e353d 100644 --- a/menutraversal_test/menu_traversal_test.go +++ b/menutraversal_test/menu_traversal_test.go @@ -4,15 +4,18 @@ import ( "bytes" "context" "flag" + "fmt" "log" "math/rand" "os" "regexp" "testing" + "git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/internal/testutil" "git.grassecon.net/urdt/ussd/internal/testutil/driver" "github.com/gofrs/uuid" + "github.com/jackc/pgx/v5/pgxpool" ) var ( @@ -25,6 +28,7 @@ var ( var groupTestFile = flag.String("test-file", "group_test.json", "The test file to use for running the group tests") var database = flag.String("db", "gdbm", "Specify the database (gdbm or postgres)") +var dbSchema = flag.String("schema", "test", "Specify the database schema (default test)") func GenerateSessionId() string { uu := uuid.NewGenWithOptions(uuid.WithRandomReader(g)) @@ -91,7 +95,29 @@ func TestMain(m *testing.M) { }() // Set the selected database - testutil.SetDatabase(*database) + testutil.SetDatabase(*database, *dbSchema) + + // Cleanup the schema table after tests + defer func() { + if *database == "postgres" { + ctx := context.Background() + connStr := storage.BuildConnStr() + dbConn, err := pgxpool.New(ctx, connStr) + if err != nil { + log.Fatalf("Failed to connect to database for cleanup: %v", err) + } + defer dbConn.Close() + + query := fmt.Sprintf("DELETE FROM %s.kv_vise;", *dbSchema) + _, execErr := dbConn.Exec(ctx, query) + if execErr != nil { + log.Printf("Failed to cleanup table %s.kv_vise: %v", *dbSchema, execErr) + } else { + log.Printf("Successfully cleaned up table %s.kv_vise", *dbSchema) + } + } + }() + m.Run() } -- 2.45.2 From df8c9aab0c056dce7a6ca68aa2b8490efa0daf02 Mon Sep 17 00:00:00 2001 From: lash Date: Wed, 8 Jan 2025 22:27:19 +0000 Subject: [PATCH 09/12] Rehabilitate tests --- cmd/africastalking/main.go | 8 +------- cmd/async/main.go | 8 +------- cmd/http/main.go | 8 +------- cmd/main.go | 8 +------- devtools/store/generate/main.go | 2 +- initializers/load.go | 10 +++++++--- internal/storage/storageservice.go | 8 -------- internal/testutil/engine.go | 4 +++- menutraversal_test/menu_traversal_test.go | 3 +-- 9 files changed, 16 insertions(+), 43 deletions(-) diff --git a/cmd/africastalking/main.go b/cmd/africastalking/main.go index 6be7a11..053eab9 100644 --- a/cmd/africastalking/main.go +++ b/cmd/africastalking/main.go @@ -15,7 +15,6 @@ import ( "git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/resource" "git.defalsify.org/vise.git/lang" - testdataloader "github.com/peteole/testdata-loader" "git.grassecon.net/urdt/ussd/config" "git.grassecon.net/urdt/ussd/initializers" @@ -28,14 +27,13 @@ import ( var ( logg = logging.NewVanilla().WithDomain("AfricasTalking").WithContextKey("at-session-id") - baseDir = testdataloader.GetBasePath() scriptDir = path.Join("services", "registration") build = "dev" menuSeparator = ": " ) func init() { - initializers.LoadEnvVariables(baseDir) + initializers.LoadEnvVariables() } func main() { @@ -44,8 +42,6 @@ func main() { var connStr string var resourceDir string var size uint - var database string - var dbSchema string var engineDebug bool var host string var port uint @@ -75,8 +71,6 @@ func main() { logg.Infof("start command", "build", build, "conn", connData, "resourcedir", resourceDir, "outputsize", size) ctx := context.Background() - ctx = context.WithValue(ctx, "Database", database) - ctx = context.WithValue(ctx, "Schema", dbSchema) ln, err := lang.LanguageFromCode(config.DefaultLanguage) if err != nil { fmt.Fprintf(os.Stderr, "default language set error: %v", err) diff --git a/cmd/async/main.go b/cmd/async/main.go index 59af2e5..0b9a233 100644 --- a/cmd/async/main.go +++ b/cmd/async/main.go @@ -13,7 +13,6 @@ import ( "git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/resource" "git.defalsify.org/vise.git/lang" - testdataloader "github.com/peteole/testdata-loader" "git.grassecon.net/urdt/ussd/config" "git.grassecon.net/urdt/ussd/initializers" @@ -25,13 +24,12 @@ import ( var ( logg = logging.NewVanilla() - baseDir = testdataloader.GetBasePath() scriptDir = path.Join("services", "registration") menuSeparator = ": " ) func init() { - initializers.LoadEnvVariables(baseDir) + initializers.LoadEnvVariables() } type asyncRequestParser struct { @@ -54,8 +52,6 @@ func main() { var sessionId string var resourceDir string var size uint - var database string - var dbSchema string var engineDebug bool var host string var port uint @@ -86,8 +82,6 @@ func main() { logg.Infof("start command", "conn", connData, "resourcedir", resourceDir, "outputsize", size, "sessionId", sessionId) ctx := context.Background() - ctx = context.WithValue(ctx, "Database", database) - ctx = context.WithValue(ctx, "Schema", dbSchema) ln, err := lang.LanguageFromCode(config.DefaultLanguage) if err != nil { diff --git a/cmd/http/main.go b/cmd/http/main.go index af0aefc..761ee72 100644 --- a/cmd/http/main.go +++ b/cmd/http/main.go @@ -15,7 +15,6 @@ import ( "git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/resource" "git.defalsify.org/vise.git/lang" - testdataloader "github.com/peteole/testdata-loader" "git.grassecon.net/urdt/ussd/config" "git.grassecon.net/urdt/ussd/initializers" @@ -28,13 +27,12 @@ import ( var ( logg = logging.NewVanilla() - baseDir = testdataloader.GetBasePath() scriptDir = path.Join("services", "registration") menuSeparator = ": " ) func init() { - initializers.LoadEnvVariables(baseDir) + initializers.LoadEnvVariables() } func main() { @@ -43,8 +41,6 @@ func main() { var connStr string var resourceDir string var size uint - var database string - var dbSchema string var engineDebug bool var host string var port uint @@ -74,8 +70,6 @@ func main() { logg.Infof("start command", "conn", connData, "resourcedir", resourceDir, "outputsize", size) ctx := context.Background() - ctx = context.WithValue(ctx, "Database", database) - ctx = context.WithValue(ctx, "Schema", dbSchema) ln, err := lang.LanguageFromCode(config.DefaultLanguage) if err != nil { diff --git a/cmd/main.go b/cmd/main.go index 5c89c05..8c06094 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -17,18 +17,16 @@ import ( "git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/internal/args" "git.grassecon.net/urdt/ussd/remote" - testdataloader "github.com/peteole/testdata-loader" ) var ( logg = logging.NewVanilla() - baseDir = testdataloader.GetBasePath() scriptDir = path.Join("services", "registration") menuSeparator = ": " ) func init() { - initializers.LoadEnvVariables(baseDir) + initializers.LoadEnvVariables() } // TODO: external script automatically generate language handler list from select language vise code OR consider dynamic menu generation script possibility @@ -38,8 +36,6 @@ func main() { var connStr string var size uint var sessionId string - var database string - var dbSchema string var engineDebug bool var resourceDir string var err error @@ -72,8 +68,6 @@ func main() { ctx := context.Background() ctx = context.WithValue(ctx, "SessionId", sessionId) - ctx = context.WithValue(ctx, "Database", database) - ctx = context.WithValue(ctx, "Schema", dbSchema) ln, err := lang.LanguageFromCode(config.DefaultLanguage) if err != nil { diff --git a/devtools/store/generate/main.go b/devtools/store/generate/main.go index 58a9808..c421d1a 100644 --- a/devtools/store/generate/main.go +++ b/devtools/store/generate/main.go @@ -23,7 +23,7 @@ var ( ) func init() { - initializers.LoadEnvVariables(baseDir) + initializers.LoadEnvVariables() } func main() { diff --git a/initializers/load.go b/initializers/load.go index 4cbeb0e..fc61746 100644 --- a/initializers/load.go +++ b/initializers/load.go @@ -9,9 +9,13 @@ import ( "github.com/joho/godotenv" ) -func LoadEnvVariables(baseDir string) { - envDir := path.Join(baseDir, ".env") - err := godotenv.Load(envDir) +func LoadEnvVariables() { + LoadEnvVariablesPath(".") +} + +func LoadEnvVariablesPath(dir string) { + fp := path.Join(dir, ".env") + err := godotenv.Load(fp) if err != nil { log.Fatal("Error loading .env file", err) } diff --git a/internal/storage/storageservice.go b/internal/storage/storageservice.go index f2ac273..617c0ef 100644 --- a/internal/storage/storageservice.go +++ b/internal/storage/storageservice.go @@ -55,14 +55,6 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D connStr := ms.conn.String() dbTyp := ms.conn.DbType() if dbTyp == DBTYPE_POSTGRES { -// // Ensure the schema exists -// err = ensureSchemaExists(ctx, connStr, schema) -// if err != nil { -// return nil, fmt.Errorf("failed to ensure schema exists: %w", err) -// } -// -// newDb = postgres.NewPgDb().WithSchema(schema) - newDb = postgres.NewPgDb() } else if dbTyp == DBTYPE_GDBM { err = ms.ensureDbDir() diff --git a/internal/testutil/engine.go b/internal/testutil/engine.go index a1eefa8..8d5d3e4 100644 --- a/internal/testutil/engine.go +++ b/internal/testutil/engine.go @@ -17,17 +17,19 @@ import ( "git.grassecon.net/urdt/ussd/internal/testutil/testservice" "git.grassecon.net/urdt/ussd/internal/testutil/testtag" "git.grassecon.net/urdt/ussd/remote" + testdataloader "github.com/peteole/testdata-loader" ) var ( logg = logging.NewVanilla() + baseDir = testdataloader.GetBasePath() scriptDir = path.Join(baseDir, "services", "registration") selectedDatabase = "" selectedDbSchema = "" ) func init() { - initializers.LoadEnvVariables() + initializers.LoadEnvVariablesPath(baseDir) } // SetDatabase updates the database used by TestEngine diff --git a/menutraversal_test/menu_traversal_test.go b/menutraversal_test/menu_traversal_test.go index db4586a..4003641 100644 --- a/menutraversal_test/menu_traversal_test.go +++ b/menutraversal_test/menu_traversal_test.go @@ -12,7 +12,6 @@ import ( "regexp" "testing" - "git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/internal/testutil" "git.grassecon.net/urdt/ussd/internal/testutil/driver" "github.com/gofrs/uuid" @@ -106,7 +105,7 @@ func TestMain(m *testing.M) { defer func() { if *database == "postgres" { ctx := context.Background() - connStr := storage.BuildConnStr() + connStr := "postgres://" //storage.BuildConnStr() dbConn, err := pgxpool.New(ctx, connStr) if err != nil { log.Fatalf("Failed to connect to database for cleanup: %v", err) -- 2.45.2 From b50a51df9b29ad5a0d4df51537b68f15f525a60f Mon Sep 17 00:00:00 2001 From: lash Date: Thu, 9 Jan 2025 07:42:09 +0000 Subject: [PATCH 10/12] Implement postgres schema --- internal/storage/parse.go | 45 ++++++++++++++++++++---------- internal/storage/storageservice.go | 19 ++++++++----- 2 files changed, 43 insertions(+), 21 deletions(-) diff --git a/internal/storage/parse.go b/internal/storage/parse.go index bb25627..41dac6b 100644 --- a/internal/storage/parse.go +++ b/internal/storage/parse.go @@ -15,6 +15,7 @@ const ( type ConnData struct { typ int str string + domain string } func (cd *ConnData) DbType() int { @@ -25,23 +26,38 @@ func (cd *ConnData) String() string { return cd.str } -func probePostgres(s string) (string, bool) { - v, err := url.Parse(s) - if err != nil { - return "", false - } - if v.Scheme != "postgres" { - return "", false - } - return s, true +func (cd *ConnData) Domain() string { + return cd.domain } -func probeGdbm(s string) (string, bool) { +func (cd *ConnData) Path() string { + v, _ := url.Parse(cd.str) + v.RawQuery = "" + return v.String() +} + +func probePostgres(s string) (string, string, bool) { + domain := "public" + v, err := url.Parse(s) + if err != nil { + return "", "", false + } + if v.Scheme != "postgres" { + return "", "", false + } + vv := v.Query() + if vv.Has("search_path") { + domain = vv.Get("search_path") + } + return s, domain, true +} + +func probeGdbm(s string) (string, string, bool) { if !path.IsAbs(s) { - return "", false + return "", "", false } s = path.Clean(s) - return s, true + return s, "", true } func ToConnData(connStr string) (ConnData, error) { @@ -51,14 +67,15 @@ func ToConnData(connStr string) (ConnData, error) { return o, nil } - v, ok := probePostgres(connStr) + v, domain, ok := probePostgres(connStr) if ok { o.typ = DBTYPE_POSTGRES o.str = v + o.domain = domain return o, nil } - v, ok = probeGdbm(connStr) + v, _, ok = probeGdbm(connStr) if ok { o.typ = DBTYPE_GDBM o.str = v diff --git a/internal/storage/storageservice.go b/internal/storage/storageservice.go index 617c0ef..374af74 100644 --- a/internal/storage/storageservice.go +++ b/internal/storage/storageservice.go @@ -55,7 +55,12 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D connStr := ms.conn.String() dbTyp := ms.conn.DbType() if dbTyp == DBTYPE_POSTGRES { - newDb = postgres.NewPgDb() + // TODO: move to vise + err = ensureSchemaExists(ctx, ms.conn) + if err != nil { + return nil, err + } + newDb = postgres.NewPgDb().WithSchema(ms.conn.Domain()) } else if dbTyp == DBTYPE_GDBM { err = ms.ensureDbDir() if err != nil { @@ -66,7 +71,7 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D } else { return nil, fmt.Errorf("unsupported connection string: '%s'\n", ms.conn.String()) } - logg.DebugCtxf(ctx, "connecting to db", "conn", connStr) + logg.DebugCtxf(ctx, "connecting to db", "conn", connStr, "conndata", ms.conn) err = newDb.Connect(ctx, connStr) if err != nil { return nil, err @@ -98,15 +103,15 @@ func (ms *MenuStorageService) WithGettext(path string, lns []lang.Language) *Men } // ensureSchemaExists creates a new schema if it does not exist -func ensureSchemaExists(ctx context.Context, connStr, schema string) error { - conn, err := pgxpool.New(ctx, connStr) +func ensureSchemaExists(ctx context.Context, conn ConnData) error { + h, err := pgxpool.New(ctx, conn.Path()) if err != nil { return fmt.Errorf("failed to connect to the database: %w", err) } - defer conn.Close() + defer h.Close() - query := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema) - _, err = conn.Exec(ctx, query) + query := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", conn.Domain()) + _, err = h.Exec(ctx, query) if err != nil { return fmt.Errorf("failed to create schema: %w", err) } -- 2.45.2 From 3fccfaab618f3f68e37750bae225b6079a03ac2b Mon Sep 17 00:00:00 2001 From: alfred-mk Date: Thu, 9 Jan 2025 13:01:28 +0300 Subject: [PATCH 11/12] Replace the connStr if it is not set --- cmd/africastalking/main.go | 6 +++--- cmd/async/main.go | 7 +++---- cmd/http/main.go | 8 ++++---- cmd/main.go | 8 ++++---- 4 files changed, 14 insertions(+), 15 deletions(-) diff --git a/cmd/africastalking/main.go b/cmd/africastalking/main.go index 053eab9..24812a1 100644 --- a/cmd/africastalking/main.go +++ b/cmd/africastalking/main.go @@ -12,17 +12,17 @@ import ( "syscall" "git.defalsify.org/vise.git/engine" + "git.defalsify.org/vise.git/lang" "git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/resource" - "git.defalsify.org/vise.git/lang" "git.grassecon.net/urdt/ussd/config" "git.grassecon.net/urdt/ussd/initializers" + "git.grassecon.net/urdt/ussd/internal/args" "git.grassecon.net/urdt/ussd/internal/handlers" "git.grassecon.net/urdt/ussd/internal/http/at" "git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/remote" - "git.grassecon.net/urdt/ussd/internal/args" ) var ( @@ -59,7 +59,7 @@ func main() { flag.Var(&langs, "language", "add symbol resolution for language") flag.Parse() - if connStr != "" { + if connStr == "" { connStr = config.DbConn } connData, err := storage.ToConnData(connStr) diff --git a/cmd/async/main.go b/cmd/async/main.go index 0b9a233..27db453 100644 --- a/cmd/async/main.go +++ b/cmd/async/main.go @@ -10,16 +10,16 @@ import ( "syscall" "git.defalsify.org/vise.git/engine" + "git.defalsify.org/vise.git/lang" "git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/resource" - "git.defalsify.org/vise.git/lang" "git.grassecon.net/urdt/ussd/config" "git.grassecon.net/urdt/ussd/initializers" + "git.grassecon.net/urdt/ussd/internal/args" "git.grassecon.net/urdt/ussd/internal/handlers" "git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/remote" - "git.grassecon.net/urdt/ussd/internal/args" ) var ( @@ -70,7 +70,7 @@ func main() { flag.Var(&langs, "language", "add symbol resolution for language") flag.Parse() - if connStr != "" { + if connStr == "" { connStr = config.DbConn } connData, err := storage.ToConnData(connStr) @@ -115,7 +115,6 @@ func main() { os.Exit(1) } - userdataStore, err := menuStorageService.GetUserdataDb(ctx) if err != nil { fmt.Fprintf(os.Stderr, err.Error()) diff --git a/cmd/http/main.go b/cmd/http/main.go index 761ee72..6617ca5 100644 --- a/cmd/http/main.go +++ b/cmd/http/main.go @@ -12,17 +12,17 @@ import ( "syscall" "git.defalsify.org/vise.git/engine" + "git.defalsify.org/vise.git/lang" "git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/resource" - "git.defalsify.org/vise.git/lang" "git.grassecon.net/urdt/ussd/config" "git.grassecon.net/urdt/ussd/initializers" + "git.grassecon.net/urdt/ussd/internal/args" "git.grassecon.net/urdt/ussd/internal/handlers" httpserver "git.grassecon.net/urdt/ussd/internal/http" "git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/remote" - "git.grassecon.net/urdt/ussd/internal/args" ) var ( @@ -58,7 +58,7 @@ func main() { flag.Var(&langs, "language", "add symbol resolution for language") flag.Parse() - if connStr != "" { + if connStr == "" { connStr = config.DbConn } connData, err := storage.ToConnData(connStr) @@ -92,7 +92,7 @@ func main() { } menuStorageService := storage.NewMenuStorageService(connData, resourceDir) - + rs, err := menuStorageService.GetResource(ctx) if err != nil { fmt.Fprintf(os.Stderr, err.Error()) diff --git a/cmd/main.go b/cmd/main.go index 8c06094..d2fe0ba 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -8,14 +8,14 @@ import ( "path" "git.defalsify.org/vise.git/engine" + "git.defalsify.org/vise.git/lang" "git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/resource" - "git.defalsify.org/vise.git/lang" "git.grassecon.net/urdt/ussd/config" "git.grassecon.net/urdt/ussd/initializers" + "git.grassecon.net/urdt/ussd/internal/args" "git.grassecon.net/urdt/ussd/internal/handlers" "git.grassecon.net/urdt/ussd/internal/storage" - "git.grassecon.net/urdt/ussd/internal/args" "git.grassecon.net/urdt/ussd/remote" ) @@ -51,7 +51,7 @@ func main() { flag.Var(&langs, "language", "add symbol resolution for language") flag.Parse() - if connStr != "" { + if connStr == "" { connStr = config.DbConn } connData, err := storage.ToConnData(connStr) @@ -87,7 +87,7 @@ func main() { } menuStorageService := storage.NewMenuStorageService(connData, resourceDir) - + if gettextDir != "" { menuStorageService = menuStorageService.WithGettext(gettextDir, langs.Langs()) } -- 2.45.2 From 9a6d8e51589f53ac9480ef3951944f1ca6c44b50 Mon Sep 17 00:00:00 2001 From: alfred-mk Date: Fri, 10 Jan 2025 13:41:05 +0300 Subject: [PATCH 12/12] Refactored the code to switch between postgres and gdbm, with db cleanup --- internal/testutil/engine.go | 95 +++++++++++++++++++---- menutraversal_test/menu_traversal_test.go | 43 ++-------- 2 files changed, 87 insertions(+), 51 deletions(-) diff --git a/internal/testutil/engine.go b/internal/testutil/engine.go index 8d5d3e4..5d581ba 100644 --- a/internal/testutil/engine.go +++ b/internal/testutil/engine.go @@ -3,6 +3,8 @@ package testutil import ( "context" "fmt" + "log" + "net/url" "os" "path" "path/filepath" @@ -11,34 +13,90 @@ import ( "git.defalsify.org/vise.git/engine" "git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/resource" + "git.grassecon.net/urdt/ussd/config" "git.grassecon.net/urdt/ussd/initializers" "git.grassecon.net/urdt/ussd/internal/handlers" "git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/internal/testutil/testservice" "git.grassecon.net/urdt/ussd/internal/testutil/testtag" "git.grassecon.net/urdt/ussd/remote" + "github.com/jackc/pgx/v5/pgxpool" testdataloader "github.com/peteole/testdata-loader" ) var ( - logg = logging.NewVanilla() - baseDir = testdataloader.GetBasePath() - scriptDir = path.Join(baseDir, "services", "registration") - selectedDatabase = "" - selectedDbSchema = "" + logg = logging.NewVanilla() + baseDir = testdataloader.GetBasePath() + scriptDir = path.Join(baseDir, "services", "registration") + setDbType string + setConnStr string + setDbSchema string ) func init() { initializers.LoadEnvVariablesPath(baseDir) + config.LoadConfig() } // SetDatabase updates the database used by TestEngine -func SetDatabase(dbType string, dbSchema string) { - selectedDatabase = dbType - selectedDbSchema = dbSchema +func SetDatabase(database, connStr, dbSchema string) { + setDbType = database + setConnStr = connStr + setDbSchema = dbSchema +} + +// CleanDatabase removes all test data from the database +func CleanDatabase() { + if setDbType == "postgres" { + ctx := context.Background() + // Update the connection string with the new search path + updatedConnStr, err := updateSearchPath(setConnStr, setDbSchema) + if err != nil { + log.Fatalf("Failed to update search path: %v", err) + } + + dbConn, err := pgxpool.New(ctx, updatedConnStr) + if err != nil { + log.Fatalf("Failed to connect to database for cleanup: %v", err) + } + defer dbConn.Close() + + query := fmt.Sprintf("DELETE FROM %s.kv_vise;", setDbSchema) + _, execErr := dbConn.Exec(ctx, query) + if execErr != nil { + log.Printf("Failed to cleanup table %s.kv_vise: %v", setDbSchema, execErr) + } else { + log.Printf("Successfully cleaned up table %s.kv_vise", setDbSchema) + } + } else { + setConnStr, _ := filepath.Abs(setConnStr) + if err := os.RemoveAll(setConnStr); err != nil { + log.Fatalf("Failed to delete state store %s: %v", setConnStr, err) + } + } +} + +// updateSearchPath updates the search_path (schema) to be used in the connection +func updateSearchPath(connStr string, newSearchPath string) (string, error) { + u, err := url.Parse(connStr) + if err != nil { + return "", fmt.Errorf("invalid connection string: %w", err) + } + + // Parse the query parameters + q := u.Query() + + // Update or add the search_path parameter + q.Set("search_path", newSearchPath) + + // Rebuild the connection string with updated parameters + u.RawQuery = q.Encode() + + return u.String(), nil } func TestEngine(sessionId string) (engine.Engine, func(), chan bool) { + var err error ctx := context.Background() ctx = context.WithValue(ctx, "SessionId", sessionId) pfp := path.Join(scriptDir, "pp.csv") @@ -52,16 +110,27 @@ func TestEngine(sessionId string) (engine.Engine, func(), chan bool) { FlagCount: uint32(128), } - connStr, err := filepath.Abs(".test_state/state.gdbm") - if err != nil { - fmt.Fprintf(os.Stderr, "connstr err: %v", err) - os.Exit(1) + if setDbType == "postgres" { + setConnStr = config.DbConn + setConnStr, err = updateSearchPath(setConnStr, setDbSchema) + if err != nil { + fmt.Println("Error:", err) + os.Exit(1) + } + } else { + setConnStr, err = filepath.Abs(setConnStr) + if err != nil { + fmt.Fprintf(os.Stderr, "connstr err: %v", err) + os.Exit(1) + } } - conn, err := storage.ToConnData(connStr) + + conn, err := storage.ToConnData(setConnStr) if err != nil { fmt.Fprintf(os.Stderr, "connstr parse err: %v", err) os.Exit(1) } + resourceDir := scriptDir menuStorageService := storage.NewMenuStorageService(conn, resourceDir) diff --git a/menutraversal_test/menu_traversal_test.go b/menutraversal_test/menu_traversal_test.go index 4003641..4aee26e 100644 --- a/menutraversal_test/menu_traversal_test.go +++ b/menutraversal_test/menu_traversal_test.go @@ -4,18 +4,14 @@ import ( "bytes" "context" "flag" - "fmt" "log" "math/rand" - "os" - "path/filepath" "regexp" "testing" "git.grassecon.net/urdt/ussd/internal/testutil" "git.grassecon.net/urdt/ussd/internal/testutil/driver" "github.com/gofrs/uuid" - "github.com/jackc/pgx/v5/pgxpool" ) var ( @@ -27,13 +23,9 @@ var ( var groupTestFile = flag.String("test-file", "group_test.json", "The test file to use for running the group tests") var database = flag.String("db", "gdbm", "Specify the database (gdbm or postgres)") +var connStr = flag.String("conn", ".test_state", "connection string") var dbSchema = flag.String("schema", "test", "Specify the database schema (default test)") -func testStore() string { - v, _ := filepath.Abs(".test_state/state.gdbm") - return v -} - func GenerateSessionId() string { uu := uuid.NewGenWithOptions(uuid.WithRandomReader(g)) v, err := uu.NewV4() @@ -90,37 +82,12 @@ func extractSendAmount(response []byte) string { func TestMain(m *testing.M) { // Parse the flags flag.Parse() - sessionID = GenerateSessionId() - defer func() { - if err := os.RemoveAll(testStore()); err != nil { - log.Fatalf("Failed to delete state store %s: %v", testStore(), err) - } - }() + // set the db + testutil.SetDatabase(*database, *connStr, *dbSchema) - // Set the selected database - testutil.SetDatabase(*database, *dbSchema) - - // Cleanup the schema table after tests - defer func() { - if *database == "postgres" { - ctx := context.Background() - connStr := "postgres://" //storage.BuildConnStr() - dbConn, err := pgxpool.New(ctx, connStr) - if err != nil { - log.Fatalf("Failed to connect to database for cleanup: %v", err) - } - defer dbConn.Close() - - query := fmt.Sprintf("DELETE FROM %s.kv_vise;", *dbSchema) - _, execErr := dbConn.Exec(ctx, query) - if execErr != nil { - log.Printf("Failed to cleanup table %s.kv_vise: %v", *dbSchema, execErr) - } else { - log.Printf("Successfully cleaned up table %s.kv_vise", *dbSchema) - } - } - }() + // Cleanup the db after tests + defer testutil.CleanDatabase() m.Run() } -- 2.45.2