Compare commits

...

5 Commits

7 changed files with 77 additions and 5 deletions

View File

@ -43,12 +43,14 @@ func main() {
var resourceDir string var resourceDir string
var size uint var size uint
var database string var database string
var dbSchema string
var engineDebug bool var engineDebug bool
var host string var host string
var port uint var port uint
flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from") flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir") flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir")
flag.StringVar(&database, "db", "gdbm", "database to be used") flag.StringVar(&database, "db", "gdbm", "database to be used")
flag.StringVar(&dbSchema, "schema", "public", "database schema to be used")
flag.BoolVar(&engineDebug, "d", false, "use engine debug output") flag.BoolVar(&engineDebug, "d", false, "use engine debug output")
flag.UintVar(&size, "s", 160, "max size of output") flag.UintVar(&size, "s", 160, "max size of output")
flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host") flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host")
@ -59,6 +61,7 @@ func main() {
ctx := context.Background() ctx := context.Background()
ctx = context.WithValue(ctx, "Database", database) ctx = context.WithValue(ctx, "Database", database)
ctx = context.WithValue(ctx, "Schema", dbSchema)
pfp := path.Join(scriptDir, "pp.csv") pfp := path.Join(scriptDir, "pp.csv")
cfg := engine.Config{ cfg := engine.Config{

View File

@ -53,6 +53,7 @@ func main() {
var resourceDir string var resourceDir string
var size uint var size uint
var database string var database string
var dbSchema string
var engineDebug bool var engineDebug bool
var host string var host string
var port uint var port uint
@ -60,6 +61,7 @@ func main() {
flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from") flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir") flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir")
flag.StringVar(&database, "db", "gdbm", "database to be used") flag.StringVar(&database, "db", "gdbm", "database to be used")
flag.StringVar(&dbSchema, "schema", "public", "database schema to be used")
flag.BoolVar(&engineDebug, "d", false, "use engine debug output") flag.BoolVar(&engineDebug, "d", false, "use engine debug output")
flag.UintVar(&size, "s", 160, "max size of output") flag.UintVar(&size, "s", 160, "max size of output")
flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host") flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host")
@ -70,6 +72,7 @@ func main() {
ctx := context.Background() ctx := context.Background()
ctx = context.WithValue(ctx, "Database", database) ctx = context.WithValue(ctx, "Database", database)
ctx = context.WithValue(ctx, "Schema", dbSchema)
pfp := path.Join(scriptDir, "pp.csv") pfp := path.Join(scriptDir, "pp.csv")
cfg := engine.Config{ cfg := engine.Config{

View File

@ -42,12 +42,14 @@ func main() {
var resourceDir string var resourceDir string
var size uint var size uint
var database string var database string
var dbSchema string
var engineDebug bool var engineDebug bool
var host string var host string
var port uint var port uint
flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from") flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir") flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir")
flag.StringVar(&database, "db", "gdbm", "database to be used") flag.StringVar(&database, "db", "gdbm", "database to be used")
flag.StringVar(&dbSchema, "schema", "public", "database schema to be used")
flag.BoolVar(&engineDebug, "d", false, "use engine debug output") flag.BoolVar(&engineDebug, "d", false, "use engine debug output")
flag.UintVar(&size, "s", 160, "max size of output") flag.UintVar(&size, "s", 160, "max size of output")
flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host") flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host")
@ -58,6 +60,7 @@ func main() {
ctx := context.Background() ctx := context.Background()
ctx = context.WithValue(ctx, "Database", database) ctx = context.WithValue(ctx, "Database", database)
ctx = context.WithValue(ctx, "Schema", dbSchema)
pfp := path.Join(scriptDir, "pp.csv") pfp := path.Join(scriptDir, "pp.csv")
cfg := engine.Config{ cfg := engine.Config{

View File

@ -36,9 +36,11 @@ func main() {
var size uint var size uint
var sessionId string var sessionId string
var database string var database string
var dbSchema string
var engineDebug bool var engineDebug bool
flag.StringVar(&sessionId, "session-id", "075xx2123", "session id") flag.StringVar(&sessionId, "session-id", "075xx2123", "session id")
flag.StringVar(&database, "db", "gdbm", "database to be used") flag.StringVar(&database, "db", "gdbm", "database to be used")
flag.StringVar(&dbSchema, "schema", "public", "database schema to be used")
flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from") flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
flag.BoolVar(&engineDebug, "d", false, "use engine debug output") flag.BoolVar(&engineDebug, "d", false, "use engine debug output")
flag.UintVar(&size, "s", 160, "max size of output") flag.UintVar(&size, "s", 160, "max size of output")
@ -49,6 +51,7 @@ func main() {
ctx := context.Background() ctx := context.Background()
ctx = context.WithValue(ctx, "SessionId", sessionId) ctx = context.WithValue(ctx, "SessionId", sessionId)
ctx = context.WithValue(ctx, "Database", database) ctx = context.WithValue(ctx, "Database", database)
ctx = context.WithValue(ctx, "Schema", dbSchema)
pfp := path.Join(scriptDir, "pp.csv") pfp := path.Join(scriptDir, "pp.csv")
cfg := engine.Config{ cfg := engine.Config{

View File

@ -14,6 +14,7 @@ import (
"git.defalsify.org/vise.git/resource" "git.defalsify.org/vise.git/resource"
"git.grassecon.net/urdt/ussd/initializers" "git.grassecon.net/urdt/ussd/initializers"
gdbmstorage "git.grassecon.net/urdt/ussd/internal/storage/db/gdbm" gdbmstorage "git.grassecon.net/urdt/ussd/internal/storage/db/gdbm"
"github.com/jackc/pgx/v5/pgxpool"
) )
var ( var (
@ -35,7 +36,7 @@ type MenuStorageService struct {
userDataStore db.Db userDataStore db.Db
} }
func buildConnStr() string { func BuildConnStr() string {
host := initializers.GetEnv("DB_HOST", "localhost") host := initializers.GetEnv("DB_HOST", "localhost")
user := initializers.GetEnv("DB_USER", "postgres") user := initializers.GetEnv("DB_USER", "postgres")
password := initializers.GetEnv("DB_PASSWORD", "") password := initializers.GetEnv("DB_PASSWORD", "")
@ -64,6 +65,11 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D
return nil, fmt.Errorf("failed to select the database") return nil, fmt.Errorf("failed to select the database")
} }
schema, ok := ctx.Value("Schema").(string)
if !ok {
return nil, fmt.Errorf("failed to select the schema")
}
if existingDb != nil { if existingDb != nil {
return existingDb, nil return existingDb, nil
} }
@ -72,8 +78,15 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D
var err error var err error
if database == "postgres" { if database == "postgres" {
newDb = postgres.NewPgDb() connStr := BuildConnStr()
connStr := buildConnStr()
// Ensure the schema exists
err = ensureSchemaExists(ctx, connStr, schema)
if err != nil {
return nil, fmt.Errorf("failed to ensure schema exists: %w", err)
}
newDb = postgres.NewPgDb().WithSchema(schema)
err = newDb.Connect(ctx, connStr) err = newDb.Connect(ctx, connStr)
} else { } else {
newDb = gdbmstorage.NewThreadGdbmDb() newDb = gdbmstorage.NewThreadGdbmDb()
@ -88,6 +101,23 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D
return newDb, nil return newDb, nil
} }
// ensureSchemaExists creates a new schema if it does not exist
func ensureSchemaExists(ctx context.Context, connStr, schema string) error {
conn, err := pgxpool.New(ctx, connStr)
if err != nil {
return fmt.Errorf("failed to connect to the database: %w", err)
}
defer conn.Close()
query := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema)
_, err = conn.Exec(ctx, query)
if err != nil {
return fmt.Errorf("failed to create schema: %w", err)
}
return nil
}
func (ms *MenuStorageService) GetPersister(ctx context.Context) (*persist.Persister, error) { func (ms *MenuStorageService) GetPersister(ctx context.Context) (*persist.Persister, error) {
stateStore, err := ms.GetStateStore(ctx) stateStore, err := ms.GetStateStore(ctx)
if err != nil { if err != nil {

View File

@ -24,6 +24,7 @@ var (
logg = logging.NewVanilla() logg = logging.NewVanilla()
scriptDir = path.Join(baseDir, "services", "registration") scriptDir = path.Join(baseDir, "services", "registration")
selectedDatabase = "" selectedDatabase = ""
selectedDbSchema = ""
) )
func init() { func init() {
@ -31,14 +32,17 @@ func init() {
} }
// SetDatabase updates the database used by TestEngine // SetDatabase updates the database used by TestEngine
func SetDatabase(dbType string) { func SetDatabase(dbType string, dbSchema string) {
selectedDatabase = dbType selectedDatabase = dbType
selectedDbSchema = dbSchema
} }
func TestEngine(sessionId string) (engine.Engine, func(), chan bool) { func TestEngine(sessionId string) (engine.Engine, func(), chan bool) {
ctx := context.Background() ctx := context.Background()
ctx = context.WithValue(ctx, "SessionId", sessionId) ctx = context.WithValue(ctx, "SessionId", sessionId)
ctx = context.WithValue(ctx, "Database", selectedDatabase) ctx = context.WithValue(ctx, "Database", selectedDatabase)
ctx = context.WithValue(ctx, "Schema", selectedDbSchema)
pfp := path.Join(scriptDir, "pp.csv") pfp := path.Join(scriptDir, "pp.csv")
var eventChannel = make(chan bool) var eventChannel = make(chan bool)

View File

@ -4,15 +4,18 @@ import (
"bytes" "bytes"
"context" "context"
"flag" "flag"
"fmt"
"log" "log"
"math/rand" "math/rand"
"os" "os"
"regexp" "regexp"
"testing" "testing"
"git.grassecon.net/urdt/ussd/internal/storage"
"git.grassecon.net/urdt/ussd/internal/testutil" "git.grassecon.net/urdt/ussd/internal/testutil"
"git.grassecon.net/urdt/ussd/internal/testutil/driver" "git.grassecon.net/urdt/ussd/internal/testutil/driver"
"github.com/gofrs/uuid" "github.com/gofrs/uuid"
"github.com/jackc/pgx/v5/pgxpool"
) )
var ( var (
@ -25,6 +28,7 @@ var (
var groupTestFile = flag.String("test-file", "group_test.json", "The test file to use for running the group tests") var groupTestFile = flag.String("test-file", "group_test.json", "The test file to use for running the group tests")
var database = flag.String("db", "gdbm", "Specify the database (gdbm or postgres)") var database = flag.String("db", "gdbm", "Specify the database (gdbm or postgres)")
var dbSchema = flag.String("schema", "test", "Specify the database schema (default test)")
func GenerateSessionId() string { func GenerateSessionId() string {
uu := uuid.NewGenWithOptions(uuid.WithRandomReader(g)) uu := uuid.NewGenWithOptions(uuid.WithRandomReader(g))
@ -91,7 +95,29 @@ func TestMain(m *testing.M) {
}() }()
// Set the selected database // Set the selected database
testutil.SetDatabase(*database) testutil.SetDatabase(*database, *dbSchema)
// Cleanup the schema table after tests
defer func() {
if *database == "postgres" {
ctx := context.Background()
connStr := storage.BuildConnStr()
dbConn, err := pgxpool.New(ctx, connStr)
if err != nil {
log.Fatalf("Failed to connect to database for cleanup: %v", err)
}
defer dbConn.Close()
query := fmt.Sprintf("DELETE FROM %s.kv_vise;", *dbSchema)
_, execErr := dbConn.Exec(ctx, query)
if execErr != nil {
log.Printf("Failed to cleanup table %s.kv_vise: %v", *dbSchema, execErr)
} else {
log.Printf("Successfully cleaned up table %s.kv_vise", *dbSchema)
}
}
}()
m.Run() m.Run()
} }