Compare commits
5 Commits
c12e867ac3
...
ea9cab930e
| Author | SHA1 | Date | |
|---|---|---|---|
| ea9cab930e | |||
| a37f6e6da3 | |||
| f59c3a53ef | |||
| 81c3378ea6 | |||
| 46a6d2bc6e |
@ -43,12 +43,14 @@ func main() {
|
||||
var resourceDir string
|
||||
var size uint
|
||||
var database string
|
||||
var dbSchema string
|
||||
var engineDebug bool
|
||||
var host string
|
||||
var port uint
|
||||
flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
|
||||
flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir")
|
||||
flag.StringVar(&database, "db", "gdbm", "database to be used")
|
||||
flag.StringVar(&dbSchema, "schema", "public", "database schema to be used")
|
||||
flag.BoolVar(&engineDebug, "d", false, "use engine debug output")
|
||||
flag.UintVar(&size, "s", 160, "max size of output")
|
||||
flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host")
|
||||
@ -59,6 +61,7 @@ func main() {
|
||||
|
||||
ctx := context.Background()
|
||||
ctx = context.WithValue(ctx, "Database", database)
|
||||
ctx = context.WithValue(ctx, "Schema", dbSchema)
|
||||
pfp := path.Join(scriptDir, "pp.csv")
|
||||
|
||||
cfg := engine.Config{
|
||||
|
||||
@ -53,6 +53,7 @@ func main() {
|
||||
var resourceDir string
|
||||
var size uint
|
||||
var database string
|
||||
var dbSchema string
|
||||
var engineDebug bool
|
||||
var host string
|
||||
var port uint
|
||||
@ -60,6 +61,7 @@ func main() {
|
||||
flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
|
||||
flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir")
|
||||
flag.StringVar(&database, "db", "gdbm", "database to be used")
|
||||
flag.StringVar(&dbSchema, "schema", "public", "database schema to be used")
|
||||
flag.BoolVar(&engineDebug, "d", false, "use engine debug output")
|
||||
flag.UintVar(&size, "s", 160, "max size of output")
|
||||
flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host")
|
||||
@ -70,6 +72,7 @@ func main() {
|
||||
|
||||
ctx := context.Background()
|
||||
ctx = context.WithValue(ctx, "Database", database)
|
||||
ctx = context.WithValue(ctx, "Schema", dbSchema)
|
||||
pfp := path.Join(scriptDir, "pp.csv")
|
||||
|
||||
cfg := engine.Config{
|
||||
|
||||
@ -42,12 +42,14 @@ func main() {
|
||||
var resourceDir string
|
||||
var size uint
|
||||
var database string
|
||||
var dbSchema string
|
||||
var engineDebug bool
|
||||
var host string
|
||||
var port uint
|
||||
flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
|
||||
flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir")
|
||||
flag.StringVar(&database, "db", "gdbm", "database to be used")
|
||||
flag.StringVar(&dbSchema, "schema", "public", "database schema to be used")
|
||||
flag.BoolVar(&engineDebug, "d", false, "use engine debug output")
|
||||
flag.UintVar(&size, "s", 160, "max size of output")
|
||||
flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host")
|
||||
@ -58,6 +60,7 @@ func main() {
|
||||
|
||||
ctx := context.Background()
|
||||
ctx = context.WithValue(ctx, "Database", database)
|
||||
ctx = context.WithValue(ctx, "Schema", dbSchema)
|
||||
pfp := path.Join(scriptDir, "pp.csv")
|
||||
|
||||
cfg := engine.Config{
|
||||
|
||||
@ -36,9 +36,11 @@ func main() {
|
||||
var size uint
|
||||
var sessionId string
|
||||
var database string
|
||||
var dbSchema string
|
||||
var engineDebug bool
|
||||
flag.StringVar(&sessionId, "session-id", "075xx2123", "session id")
|
||||
flag.StringVar(&database, "db", "gdbm", "database to be used")
|
||||
flag.StringVar(&dbSchema, "schema", "public", "database schema to be used")
|
||||
flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
|
||||
flag.BoolVar(&engineDebug, "d", false, "use engine debug output")
|
||||
flag.UintVar(&size, "s", 160, "max size of output")
|
||||
@ -49,6 +51,7 @@ func main() {
|
||||
ctx := context.Background()
|
||||
ctx = context.WithValue(ctx, "SessionId", sessionId)
|
||||
ctx = context.WithValue(ctx, "Database", database)
|
||||
ctx = context.WithValue(ctx, "Schema", dbSchema)
|
||||
pfp := path.Join(scriptDir, "pp.csv")
|
||||
|
||||
cfg := engine.Config{
|
||||
|
||||
@ -14,6 +14,7 @@ import (
|
||||
"git.defalsify.org/vise.git/resource"
|
||||
"git.grassecon.net/urdt/ussd/initializers"
|
||||
gdbmstorage "git.grassecon.net/urdt/ussd/internal/storage/db/gdbm"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -35,7 +36,7 @@ type MenuStorageService struct {
|
||||
userDataStore db.Db
|
||||
}
|
||||
|
||||
func buildConnStr() string {
|
||||
func BuildConnStr() string {
|
||||
host := initializers.GetEnv("DB_HOST", "localhost")
|
||||
user := initializers.GetEnv("DB_USER", "postgres")
|
||||
password := initializers.GetEnv("DB_PASSWORD", "")
|
||||
@ -64,6 +65,11 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D
|
||||
return nil, fmt.Errorf("failed to select the database")
|
||||
}
|
||||
|
||||
schema, ok := ctx.Value("Schema").(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to select the schema")
|
||||
}
|
||||
|
||||
if existingDb != nil {
|
||||
return existingDb, nil
|
||||
}
|
||||
@ -72,8 +78,15 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D
|
||||
var err error
|
||||
|
||||
if database == "postgres" {
|
||||
newDb = postgres.NewPgDb()
|
||||
connStr := buildConnStr()
|
||||
connStr := BuildConnStr()
|
||||
|
||||
// Ensure the schema exists
|
||||
err = ensureSchemaExists(ctx, connStr, schema)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to ensure schema exists: %w", err)
|
||||
}
|
||||
|
||||
newDb = postgres.NewPgDb().WithSchema(schema)
|
||||
err = newDb.Connect(ctx, connStr)
|
||||
} else {
|
||||
newDb = gdbmstorage.NewThreadGdbmDb()
|
||||
@ -88,6 +101,23 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D
|
||||
return newDb, nil
|
||||
}
|
||||
|
||||
// ensureSchemaExists creates a new schema if it does not exist
|
||||
func ensureSchemaExists(ctx context.Context, connStr, schema string) error {
|
||||
conn, err := pgxpool.New(ctx, connStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to the database: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
query := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema)
|
||||
_, err = conn.Exec(ctx, query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create schema: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ms *MenuStorageService) GetPersister(ctx context.Context) (*persist.Persister, error) {
|
||||
stateStore, err := ms.GetStateStore(ctx)
|
||||
if err != nil {
|
||||
|
||||
@ -24,6 +24,7 @@ var (
|
||||
logg = logging.NewVanilla()
|
||||
scriptDir = path.Join(baseDir, "services", "registration")
|
||||
selectedDatabase = ""
|
||||
selectedDbSchema = ""
|
||||
)
|
||||
|
||||
func init() {
|
||||
@ -31,14 +32,17 @@ func init() {
|
||||
}
|
||||
|
||||
// SetDatabase updates the database used by TestEngine
|
||||
func SetDatabase(dbType string) {
|
||||
func SetDatabase(dbType string, dbSchema string) {
|
||||
selectedDatabase = dbType
|
||||
selectedDbSchema = dbSchema
|
||||
}
|
||||
|
||||
func TestEngine(sessionId string) (engine.Engine, func(), chan bool) {
|
||||
ctx := context.Background()
|
||||
ctx = context.WithValue(ctx, "SessionId", sessionId)
|
||||
ctx = context.WithValue(ctx, "Database", selectedDatabase)
|
||||
ctx = context.WithValue(ctx, "Schema", selectedDbSchema)
|
||||
|
||||
pfp := path.Join(scriptDir, "pp.csv")
|
||||
|
||||
var eventChannel = make(chan bool)
|
||||
|
||||
@ -4,15 +4,18 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"os"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"git.grassecon.net/urdt/ussd/internal/storage"
|
||||
"git.grassecon.net/urdt/ussd/internal/testutil"
|
||||
"git.grassecon.net/urdt/ussd/internal/testutil/driver"
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -25,6 +28,7 @@ var (
|
||||
|
||||
var groupTestFile = flag.String("test-file", "group_test.json", "The test file to use for running the group tests")
|
||||
var database = flag.String("db", "gdbm", "Specify the database (gdbm or postgres)")
|
||||
var dbSchema = flag.String("schema", "test", "Specify the database schema (default test)")
|
||||
|
||||
func GenerateSessionId() string {
|
||||
uu := uuid.NewGenWithOptions(uuid.WithRandomReader(g))
|
||||
@ -91,7 +95,29 @@ func TestMain(m *testing.M) {
|
||||
}()
|
||||
|
||||
// Set the selected database
|
||||
testutil.SetDatabase(*database)
|
||||
testutil.SetDatabase(*database, *dbSchema)
|
||||
|
||||
// Cleanup the schema table after tests
|
||||
defer func() {
|
||||
if *database == "postgres" {
|
||||
ctx := context.Background()
|
||||
connStr := storage.BuildConnStr()
|
||||
dbConn, err := pgxpool.New(ctx, connStr)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to connect to database for cleanup: %v", err)
|
||||
}
|
||||
defer dbConn.Close()
|
||||
|
||||
query := fmt.Sprintf("DELETE FROM %s.kv_vise;", *dbSchema)
|
||||
_, execErr := dbConn.Exec(ctx, query)
|
||||
if execErr != nil {
|
||||
log.Printf("Failed to cleanup table %s.kv_vise: %v", *dbSchema, execErr)
|
||||
} else {
|
||||
log.Printf("Successfully cleaned up table %s.kv_vise", *dbSchema)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
m.Run()
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user