diff --git a/internal/testutil/TestEngine.go b/internal/testutil/TestEngine.go index 3fcb307..40a744f 100644 --- a/internal/testutil/TestEngine.go +++ b/internal/testutil/TestEngine.go @@ -10,24 +10,35 @@ import ( "git.defalsify.org/vise.git/engine" "git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/resource" + "git.grassecon.net/urdt/ussd/initializers" "git.grassecon.net/urdt/ussd/internal/handlers" "git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/internal/testutil/testservice" "git.grassecon.net/urdt/ussd/internal/testutil/testtag" - testdataloader "github.com/peteole/testdata-loader" "git.grassecon.net/urdt/ussd/remote" + testdataloader "github.com/peteole/testdata-loader" ) var ( - baseDir = testdataloader.GetBasePath() - logg = logging.NewVanilla() - scriptDir = path.Join(baseDir, "services", "registration") + baseDir = testdataloader.GetBasePath() + logg = logging.NewVanilla() + scriptDir = path.Join(baseDir, "services", "registration") + selectedDatabase = "" ) +func init() { + initializers.LoadEnvVariables(baseDir) +} + +// SetDatabase updates the database used by TestEngine +func SetDatabase(dbType string) { + selectedDatabase = dbType +} + func TestEngine(sessionId string) (engine.Engine, func(), chan bool) { ctx := context.Background() ctx = context.WithValue(ctx, "SessionId", sessionId) - ctx = context.WithValue(ctx, "Database", "gdbm") + ctx = context.WithValue(ctx, "Database", selectedDatabase) pfp := path.Join(scriptDir, "pp.csv") var eventChannel = make(chan bool) diff --git a/menutraversal_test/menu_traversal_test.go b/menutraversal_test/menu_traversal_test.go index 6b6b3da..8cfe710 100644 --- a/menutraversal_test/menu_traversal_test.go +++ b/menutraversal_test/menu_traversal_test.go @@ -24,6 +24,7 @@ var ( ) var groupTestFile = flag.String("test-file", "group_test.json", "The test file to use for running the group tests") +var database = flag.String("db", "gdbm", "Specify the database (gdbm or postgres)") func GenerateSessionId() string { uu := uuid.NewGenWithOptions(uuid.WithRandomReader(g)) @@ -79,12 +80,18 @@ func extractSendAmount(response []byte) string { } func TestMain(m *testing.M) { + // Parse the flags + flag.Parse() + sessionID = GenerateSessionId() defer func() { if err := os.RemoveAll(testStore); err != nil { log.Fatalf("Failed to delete state store %s: %v", testStore, err) } }() + + // Set the selected database + testutil.SetDatabase(*database) m.Run() } @@ -121,7 +128,6 @@ func TestAccountCreationSuccessful(t *testing.T) { } } <-eventChannel - } func TestAccountRegistrationRejectTerms(t *testing.T) {