diff --git a/cmd/store/generate/main.go b/cmd/store/generate/main.go index d7923ae..60dcea9 100644 --- a/cmd/store/generate/main.go +++ b/cmd/store/generate/main.go @@ -17,6 +17,7 @@ import ( var ( logg = logging.NewVanilla() + baseDir = testdataloader.GetBasePath() scriptDir = path.Join("services", "registration") ) @@ -24,7 +25,6 @@ func init() { initializers.LoadEnvVariables() } - func main() { config.LoadConfig() @@ -86,5 +86,4 @@ func main() { fmt.Fprintf(os.Stderr, err.Error()) os.Exit(1) } - } diff --git a/initializers/load.go b/initializers/load.go index 4ea5980..fc61746 100644 --- a/initializers/load.go +++ b/initializers/load.go @@ -3,24 +3,30 @@ package initializers import ( "log" "os" + "path" "strconv" "github.com/joho/godotenv" ) func LoadEnvVariables() { - err := godotenv.Load() + LoadEnvVariablesPath(".") +} + +func LoadEnvVariablesPath(dir string) { + fp := path.Join(dir, ".env") + err := godotenv.Load(fp) if err != nil { - log.Fatal("Error loading .env file") + log.Fatal("Error loading .env file", err) } } // Helper to get environment variables with a default fallback func GetEnv(key, defaultVal string) string { - if value, exists := os.LookupEnv(key); exists { - return value + if value, exists := os.LookupEnv(key); exists { + return value } - return defaultVal + return defaultVal } // Helper to safely convert environment variables to uint diff --git a/menutraversal_test/menu_traversal_test.go b/menutraversal_test/menu_traversal_test.go index d41cd3b..92a839a 100644 --- a/menutraversal_test/menu_traversal_test.go +++ b/menutraversal_test/menu_traversal_test.go @@ -6,8 +6,6 @@ import ( "flag" "log" "math/rand" - "os" - "path/filepath" "regexp" "testing" @@ -24,11 +22,9 @@ var ( ) var groupTestFile = flag.String("test-file", "group_test.json", "The test file to use for running the group tests") - -func testStore() string { - v, _ := filepath.Abs(".test_state/state.gdbm") - return v -} +var database = flag.String("db", "gdbm", "Specify the database (gdbm or postgres)") +var connStr = flag.String("conn", ".test_state", "connection string") +var dbSchema = flag.String("schema", "test", "Specify the database schema (default test)") func GenerateSessionId() string { uu := uuid.NewGenWithOptions(uuid.WithRandomReader(g)) @@ -84,12 +80,15 @@ func extractSendAmount(response []byte) string { } func TestMain(m *testing.M) { + // Parse the flags + flag.Parse() sessionID = GenerateSessionId() - defer func() { - if err := os.RemoveAll(testStore()); err != nil { - log.Fatalf("Failed to delete state store %s: %v", testStore(), err) - } - }() + // set the db + testutil.SetDatabase(*database, *connStr, *dbSchema) + + // Cleanup the db after tests + defer testutil.CleanDatabase() + m.Run() } @@ -126,7 +125,6 @@ func TestAccountCreationSuccessful(t *testing.T) { } } <-eventChannel - } func TestAccountRegistrationRejectTerms(t *testing.T) { diff --git a/storage/parse.go b/storage/parse.go index bb25627..41dac6b 100644 --- a/storage/parse.go +++ b/storage/parse.go @@ -15,6 +15,7 @@ const ( type ConnData struct { typ int str string + domain string } func (cd *ConnData) DbType() int { @@ -25,23 +26,38 @@ func (cd *ConnData) String() string { return cd.str } -func probePostgres(s string) (string, bool) { - v, err := url.Parse(s) - if err != nil { - return "", false - } - if v.Scheme != "postgres" { - return "", false - } - return s, true +func (cd *ConnData) Domain() string { + return cd.domain } -func probeGdbm(s string) (string, bool) { +func (cd *ConnData) Path() string { + v, _ := url.Parse(cd.str) + v.RawQuery = "" + return v.String() +} + +func probePostgres(s string) (string, string, bool) { + domain := "public" + v, err := url.Parse(s) + if err != nil { + return "", "", false + } + if v.Scheme != "postgres" { + return "", "", false + } + vv := v.Query() + if vv.Has("search_path") { + domain = vv.Get("search_path") + } + return s, domain, true +} + +func probeGdbm(s string) (string, string, bool) { if !path.IsAbs(s) { - return "", false + return "", "", false } s = path.Clean(s) - return s, true + return s, "", true } func ToConnData(connStr string) (ConnData, error) { @@ -51,14 +67,15 @@ func ToConnData(connStr string) (ConnData, error) { return o, nil } - v, ok := probePostgres(connStr) + v, domain, ok := probePostgres(connStr) if ok { o.typ = DBTYPE_POSTGRES o.str = v + o.domain = domain return o, nil } - v, ok = probeGdbm(connStr) + v, _, ok = probeGdbm(connStr) if ok { o.typ = DBTYPE_GDBM o.str = v diff --git a/storage/storageservice.go b/storage/storageservice.go index 6d4ad80..bc5a5e2 100644 --- a/storage/storageservice.go +++ b/storage/storageservice.go @@ -59,7 +59,12 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D connStr := ms.conn.String() dbTyp := ms.conn.DbType() if dbTyp == DBTYPE_POSTGRES { - newDb = postgres.NewPgDb() + // TODO: move to vise + err = ensureSchemaExists(ctx, ms.conn) + if err != nil { + return nil, err + } + newDb = postgres.NewPgDb().WithSchema(ms.conn.Domain()) } else if dbTyp == DBTYPE_GDBM { err = ms.ensureDbDir() if err != nil { @@ -70,7 +75,7 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D } else { return nil, fmt.Errorf("unsupported connection string: '%s'\n", ms.conn.String()) } - logg.DebugCtxf(ctx, "connecting to db", "conn", connStr) + logg.DebugCtxf(ctx, "connecting to db", "conn", connStr, "conndata", ms.conn) err = newDb.Connect(ctx, connStr) if err != nil { return nil, err @@ -101,6 +106,23 @@ func (ms *MenuStorageService) WithGettext(path string, lns []lang.Language) *Men return ms } +// ensureSchemaExists creates a new schema if it does not exist +func ensureSchemaExists(ctx context.Context, conn ConnData) error { + h, err := pgxpool.New(ctx, conn.Path()) + if err != nil { + return fmt.Errorf("failed to connect to the database: %w", err) + } + defer h.Close() + + query := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", conn.Domain()) + _, err = h.Exec(ctx, query) + if err != nil { + return fmt.Errorf("failed to create schema: %w", err) + } + + return nil +} + func (ms *MenuStorageService) GetPersister(ctx context.Context) (*persist.Persister, error) { stateStore, err := ms.GetStateStore(ctx) if err != nil { diff --git a/testutil/engine.go b/testutil/engine.go index c0cb845..310b4a0 100644 --- a/testutil/engine.go +++ b/testutil/engine.go @@ -3,14 +3,19 @@ package testutil import ( "context" "fmt" + "log" + "net/url" "os" "path" "path/filepath" "time" + "github.com/jackc/pgx/v5/pgxpool" "git.defalsify.org/vise.git/engine" "git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/resource" + "git.grassecon.net/grassrootseconomics/visedriver/initializers" + "git.grassecon.net/grassrootseconomics/visedriver/config" "git.grassecon.net/grassrootseconomics/visedriver/handlers" "git.grassecon.net/grassrootseconomics/visedriver/storage" "git.grassecon.net/grassrootseconomics/visedriver/internal/testutil/testservice" @@ -20,12 +25,78 @@ import ( ) var ( - baseDir = testdataloader.GetBasePath() - logg = logging.NewVanilla() - scriptDir = path.Join(baseDir, "services", "registration") + logg = logging.NewVanilla() + baseDir = testdataloader.GetBasePath() + scriptDir = path.Join(baseDir, "services", "registration") + setDbType string + setConnStr string + setDbSchema string ) +func init() { + initializers.LoadEnvVariablesPath(baseDir) + config.LoadConfig() +} + +// SetDatabase updates the database used by TestEngine +func SetDatabase(database, connStr, dbSchema string) { + setDbType = database + setConnStr = connStr + setDbSchema = dbSchema +} + +// CleanDatabase removes all test data from the database +func CleanDatabase() { + if setDbType == "postgres" { + ctx := context.Background() + // Update the connection string with the new search path + updatedConnStr, err := updateSearchPath(setConnStr, setDbSchema) + if err != nil { + log.Fatalf("Failed to update search path: %v", err) + } + + dbConn, err := pgxpool.New(ctx, updatedConnStr) + if err != nil { + log.Fatalf("Failed to connect to database for cleanup: %v", err) + } + defer dbConn.Close() + + query := fmt.Sprintf("DELETE FROM %s.kv_vise;", setDbSchema) + _, execErr := dbConn.Exec(ctx, query) + if execErr != nil { + log.Printf("Failed to cleanup table %s.kv_vise: %v", setDbSchema, execErr) + } else { + log.Printf("Successfully cleaned up table %s.kv_vise", setDbSchema) + } + } else { + setConnStr, _ := filepath.Abs(setConnStr) + if err := os.RemoveAll(setConnStr); err != nil { + log.Fatalf("Failed to delete state store %s: %v", setConnStr, err) + } + } +} + +// updateSearchPath updates the search_path (schema) to be used in the connection +func updateSearchPath(connStr string, newSearchPath string) (string, error) { + u, err := url.Parse(connStr) + if err != nil { + return "", fmt.Errorf("invalid connection string: %w", err) + } + + // Parse the query parameters + q := u.Query() + + // Update or add the search_path parameter + q.Set("search_path", newSearchPath) + + // Rebuild the connection string with updated parameters + u.RawQuery = q.Encode() + + return u.String(), nil +} + func TestEngine(sessionId string) (engine.Engine, func(), chan bool) { + var err error ctx := context.Background() ctx = context.WithValue(ctx, "SessionId", sessionId) pfp := path.Join(scriptDir, "pp.csv") @@ -39,16 +110,27 @@ func TestEngine(sessionId string) (engine.Engine, func(), chan bool) { FlagCount: uint32(128), } - connStr, err := filepath.Abs(".test_state/state.gdbm") - if err != nil { - fmt.Fprintf(os.Stderr, "connstr err: %v", err) - os.Exit(1) + if setDbType == "postgres" { + setConnStr = config.DbConn + setConnStr, err = updateSearchPath(setConnStr, setDbSchema) + if err != nil { + fmt.Println("Error:", err) + os.Exit(1) + } + } else { + setConnStr, err = filepath.Abs(setConnStr) + if err != nil { + fmt.Fprintf(os.Stderr, "connstr err: %v", err) + os.Exit(1) + } } - conn, err := storage.ToConnData(connStr) + + conn, err := storage.ToConnData(setConnStr) if err != nil { fmt.Fprintf(os.Stderr, "connstr parse err: %v", err) os.Exit(1) } + resourceDir := scriptDir menuStorageService := storage.NewMenuStorageService(conn, resourceDir)