Compare commits

...

5 Commits

7 changed files with 77 additions and 5 deletions

View File

@ -43,12 +43,14 @@ func main() {
var resourceDir string
var size uint
var database string
var dbSchema string
var engineDebug bool
var host string
var port uint
flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir")
flag.StringVar(&database, "db", "gdbm", "database to be used")
flag.StringVar(&dbSchema, "schema", "public", "database schema to be used")
flag.BoolVar(&engineDebug, "d", false, "use engine debug output")
flag.UintVar(&size, "s", 160, "max size of output")
flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host")
@ -59,6 +61,7 @@ func main() {
ctx := context.Background()
ctx = context.WithValue(ctx, "Database", database)
ctx = context.WithValue(ctx, "Schema", dbSchema)
pfp := path.Join(scriptDir, "pp.csv")
cfg := engine.Config{

View File

@ -53,6 +53,7 @@ func main() {
var resourceDir string
var size uint
var database string
var dbSchema string
var engineDebug bool
var host string
var port uint
@ -60,6 +61,7 @@ func main() {
flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir")
flag.StringVar(&database, "db", "gdbm", "database to be used")
flag.StringVar(&dbSchema, "schema", "public", "database schema to be used")
flag.BoolVar(&engineDebug, "d", false, "use engine debug output")
flag.UintVar(&size, "s", 160, "max size of output")
flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host")
@ -70,6 +72,7 @@ func main() {
ctx := context.Background()
ctx = context.WithValue(ctx, "Database", database)
ctx = context.WithValue(ctx, "Schema", dbSchema)
pfp := path.Join(scriptDir, "pp.csv")
cfg := engine.Config{

View File

@ -42,12 +42,14 @@ func main() {
var resourceDir string
var size uint
var database string
var dbSchema string
var engineDebug bool
var host string
var port uint
flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir")
flag.StringVar(&database, "db", "gdbm", "database to be used")
flag.StringVar(&dbSchema, "schema", "public", "database schema to be used")
flag.BoolVar(&engineDebug, "d", false, "use engine debug output")
flag.UintVar(&size, "s", 160, "max size of output")
flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host")
@ -58,6 +60,7 @@ func main() {
ctx := context.Background()
ctx = context.WithValue(ctx, "Database", database)
ctx = context.WithValue(ctx, "Schema", dbSchema)
pfp := path.Join(scriptDir, "pp.csv")
cfg := engine.Config{

View File

@ -36,9 +36,11 @@ func main() {
var size uint
var sessionId string
var database string
var dbSchema string
var engineDebug bool
flag.StringVar(&sessionId, "session-id", "075xx2123", "session id")
flag.StringVar(&database, "db", "gdbm", "database to be used")
flag.StringVar(&dbSchema, "schema", "public", "database schema to be used")
flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
flag.BoolVar(&engineDebug, "d", false, "use engine debug output")
flag.UintVar(&size, "s", 160, "max size of output")
@ -49,6 +51,7 @@ func main() {
ctx := context.Background()
ctx = context.WithValue(ctx, "SessionId", sessionId)
ctx = context.WithValue(ctx, "Database", database)
ctx = context.WithValue(ctx, "Schema", dbSchema)
pfp := path.Join(scriptDir, "pp.csv")
cfg := engine.Config{

View File

@ -14,6 +14,7 @@ import (
"git.defalsify.org/vise.git/resource"
"git.grassecon.net/urdt/ussd/initializers"
gdbmstorage "git.grassecon.net/urdt/ussd/internal/storage/db/gdbm"
"github.com/jackc/pgx/v5/pgxpool"
)
var (
@ -35,7 +36,7 @@ type MenuStorageService struct {
userDataStore db.Db
}
func buildConnStr() string {
func BuildConnStr() string {
host := initializers.GetEnv("DB_HOST", "localhost")
user := initializers.GetEnv("DB_USER", "postgres")
password := initializers.GetEnv("DB_PASSWORD", "")
@ -64,6 +65,11 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D
return nil, fmt.Errorf("failed to select the database")
}
schema, ok := ctx.Value("Schema").(string)
if !ok {
return nil, fmt.Errorf("failed to select the schema")
}
if existingDb != nil {
return existingDb, nil
}
@ -72,8 +78,15 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D
var err error
if database == "postgres" {
newDb = postgres.NewPgDb()
connStr := buildConnStr()
connStr := BuildConnStr()
// Ensure the schema exists
err = ensureSchemaExists(ctx, connStr, schema)
if err != nil {
return nil, fmt.Errorf("failed to ensure schema exists: %w", err)
}
newDb = postgres.NewPgDb().WithSchema(schema)
err = newDb.Connect(ctx, connStr)
} else {
newDb = gdbmstorage.NewThreadGdbmDb()
@ -88,6 +101,23 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D
return newDb, nil
}
// ensureSchemaExists creates a new schema if it does not exist
func ensureSchemaExists(ctx context.Context, connStr, schema string) error {
conn, err := pgxpool.New(ctx, connStr)
if err != nil {
return fmt.Errorf("failed to connect to the database: %w", err)
}
defer conn.Close()
query := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema)
_, err = conn.Exec(ctx, query)
if err != nil {
return fmt.Errorf("failed to create schema: %w", err)
}
return nil
}
func (ms *MenuStorageService) GetPersister(ctx context.Context) (*persist.Persister, error) {
stateStore, err := ms.GetStateStore(ctx)
if err != nil {

View File

@ -24,6 +24,7 @@ var (
logg = logging.NewVanilla()
scriptDir = path.Join(baseDir, "services", "registration")
selectedDatabase = ""
selectedDbSchema = ""
)
func init() {
@ -31,14 +32,17 @@ func init() {
}
// SetDatabase updates the database used by TestEngine
func SetDatabase(dbType string) {
func SetDatabase(dbType string, dbSchema string) {
selectedDatabase = dbType
selectedDbSchema = dbSchema
}
func TestEngine(sessionId string) (engine.Engine, func(), chan bool) {
ctx := context.Background()
ctx = context.WithValue(ctx, "SessionId", sessionId)
ctx = context.WithValue(ctx, "Database", selectedDatabase)
ctx = context.WithValue(ctx, "Schema", selectedDbSchema)
pfp := path.Join(scriptDir, "pp.csv")
var eventChannel = make(chan bool)

View File

@ -4,15 +4,18 @@ import (
"bytes"
"context"
"flag"
"fmt"
"log"
"math/rand"
"os"
"regexp"
"testing"
"git.grassecon.net/urdt/ussd/internal/storage"
"git.grassecon.net/urdt/ussd/internal/testutil"
"git.grassecon.net/urdt/ussd/internal/testutil/driver"
"github.com/gofrs/uuid"
"github.com/jackc/pgx/v5/pgxpool"
)
var (
@ -25,6 +28,7 @@ var (
var groupTestFile = flag.String("test-file", "group_test.json", "The test file to use for running the group tests")
var database = flag.String("db", "gdbm", "Specify the database (gdbm or postgres)")
var dbSchema = flag.String("schema", "test", "Specify the database schema (default test)")
func GenerateSessionId() string {
uu := uuid.NewGenWithOptions(uuid.WithRandomReader(g))
@ -91,7 +95,29 @@ func TestMain(m *testing.M) {
}()
// Set the selected database
testutil.SetDatabase(*database)
testutil.SetDatabase(*database, *dbSchema)
// Cleanup the schema table after tests
defer func() {
if *database == "postgres" {
ctx := context.Background()
connStr := storage.BuildConnStr()
dbConn, err := pgxpool.New(ctx, connStr)
if err != nil {
log.Fatalf("Failed to connect to database for cleanup: %v", err)
}
defer dbConn.Close()
query := fmt.Sprintf("DELETE FROM %s.kv_vise;", *dbSchema)
_, execErr := dbConn.Exec(ctx, query)
if execErr != nil {
log.Printf("Failed to cleanup table %s.kv_vise: %v", *dbSchema, execErr)
} else {
log.Printf("Successfully cleaned up table %s.kv_vise", *dbSchema)
}
}
}()
m.Run()
}