add a db flag to specify the database of choice

This commit is contained in:
Alfred Kamanda 2025-01-06 15:06:25 +03:00 committed by konstantinmds
parent 66110439a0
commit 0f6c486ee0
2 changed files with 23 additions and 5 deletions

View File

@ -11,23 +11,35 @@ import (
"git.defalsify.org/vise.git/engine" "git.defalsify.org/vise.git/engine"
"git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/logging"
"git.defalsify.org/vise.git/resource" "git.defalsify.org/vise.git/resource"
"git.grassecon.net/urdt/ussd/initializers"
"git.grassecon.net/urdt/ussd/internal/handlers" "git.grassecon.net/urdt/ussd/internal/handlers"
"git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/internal/storage"
"git.grassecon.net/urdt/ussd/internal/testutil/testservice" "git.grassecon.net/urdt/ussd/internal/testutil/testservice"
"git.grassecon.net/urdt/ussd/internal/testutil/testtag" "git.grassecon.net/urdt/ussd/internal/testutil/testtag"
testdataloader "github.com/peteole/testdata-loader"
"git.grassecon.net/urdt/ussd/remote" "git.grassecon.net/urdt/ussd/remote"
testdataloader "github.com/peteole/testdata-loader"
) )
var ( var (
baseDir = testdataloader.GetBasePath() baseDir = testdataloader.GetBasePath()
logg = logging.NewVanilla() logg = logging.NewVanilla()
scriptDir = path.Join(baseDir, "services", "registration") scriptDir = path.Join(baseDir, "services", "registration")
selectedDatabase = ""
) )
func init() {
initializers.LoadEnvVariables(baseDir)
}
// SetDatabase updates the database used by TestEngine
func SetDatabase(dbType string) {
selectedDatabase = dbType
}
func TestEngine(sessionId string) (engine.Engine, func(), chan bool) { func TestEngine(sessionId string) (engine.Engine, func(), chan bool) {
ctx := context.Background() ctx := context.Background()
ctx = context.WithValue(ctx, "SessionId", sessionId) ctx = context.WithValue(ctx, "SessionId", sessionId)
ctx = context.WithValue(ctx, "Database", selectedDatabase)
pfp := path.Join(scriptDir, "pp.csv") pfp := path.Join(scriptDir, "pp.csv")
var eventChannel = make(chan bool) var eventChannel = make(chan bool)

View File

@ -24,6 +24,7 @@ var (
) )
var groupTestFile = flag.String("test-file", "group_test.json", "The test file to use for running the group tests") var groupTestFile = flag.String("test-file", "group_test.json", "The test file to use for running the group tests")
var database = flag.String("db", "gdbm", "Specify the database (gdbm or postgres)")
func testStore() string { func testStore() string {
v, _ := filepath.Abs(".test_state/state.gdbm") v, _ := filepath.Abs(".test_state/state.gdbm")
@ -84,12 +85,18 @@ func extractSendAmount(response []byte) string {
} }
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
// Parse the flags
flag.Parse()
sessionID = GenerateSessionId() sessionID = GenerateSessionId()
defer func() { defer func() {
if err := os.RemoveAll(testStore()); err != nil { if err := os.RemoveAll(testStore()); err != nil {
log.Fatalf("Failed to delete state store %s: %v", testStore(), err) log.Fatalf("Failed to delete state store %s: %v", testStore(), err)
} }
}() }()
// Set the selected database
testutil.SetDatabase(*database)
m.Run() m.Run()
} }
@ -126,7 +133,6 @@ func TestAccountCreationSuccessful(t *testing.T) {
} }
} }
<-eventChannel <-eventChannel
} }
func TestAccountRegistrationRejectTerms(t *testing.T) { func TestAccountRegistrationRejectTerms(t *testing.T) {