diff --git a/internal/testutil/engine.go b/internal/testutil/engine.go index 2372ce9..9de07ec 100644 --- a/internal/testutil/engine.go +++ b/internal/testutil/engine.go @@ -11,23 +11,35 @@ import ( "git.defalsify.org/vise.git/engine" "git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/resource" + "git.grassecon.net/urdt/ussd/initializers" "git.grassecon.net/urdt/ussd/internal/handlers" "git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/internal/testutil/testservice" "git.grassecon.net/urdt/ussd/internal/testutil/testtag" - testdataloader "github.com/peteole/testdata-loader" "git.grassecon.net/urdt/ussd/remote" + testdataloader "github.com/peteole/testdata-loader" ) var ( - baseDir = testdataloader.GetBasePath() - logg = logging.NewVanilla() - scriptDir = path.Join(baseDir, "services", "registration") + baseDir = testdataloader.GetBasePath() + logg = logging.NewVanilla() + scriptDir = path.Join(baseDir, "services", "registration") + selectedDatabase = "" ) +func init() { + initializers.LoadEnvVariables(baseDir) +} + +// SetDatabase updates the database used by TestEngine +func SetDatabase(dbType string) { + selectedDatabase = dbType +} + func TestEngine(sessionId string) (engine.Engine, func(), chan bool) { ctx := context.Background() ctx = context.WithValue(ctx, "SessionId", sessionId) + ctx = context.WithValue(ctx, "Database", selectedDatabase) pfp := path.Join(scriptDir, "pp.csv") var eventChannel = make(chan bool) diff --git a/menutraversal_test/menu_traversal_test.go b/menutraversal_test/menu_traversal_test.go index 52e2273..8d9bc52 100644 --- a/menutraversal_test/menu_traversal_test.go +++ b/menutraversal_test/menu_traversal_test.go @@ -24,6 +24,7 @@ var ( ) var groupTestFile = flag.String("test-file", "group_test.json", "The test file to use for running the group tests") +var database = flag.String("db", "gdbm", "Specify the database (gdbm or postgres)") func testStore() string { v, _ := filepath.Abs(".test_state/state.gdbm") @@ -84,12 +85,18 @@ func extractSendAmount(response []byte) string { } func TestMain(m *testing.M) { + // Parse the flags + flag.Parse() + sessionID = GenerateSessionId() defer func() { if err := os.RemoveAll(testStore()); err != nil { log.Fatalf("Failed to delete state store %s: %v", testStore(), err) } }() + + // Set the selected database + testutil.SetDatabase(*database) m.Run() } @@ -126,7 +133,6 @@ func TestAccountCreationSuccessful(t *testing.T) { } } <-eventChannel - } func TestAccountRegistrationRejectTerms(t *testing.T) {