Compare commits
5 Commits
c12e867ac3
...
ea9cab930e
| Author | SHA1 | Date | |
|---|---|---|---|
| ea9cab930e | |||
| a37f6e6da3 | |||
| f59c3a53ef | |||
| 81c3378ea6 | |||
| 46a6d2bc6e |
@ -43,12 +43,14 @@ func main() {
|
|||||||
var resourceDir string
|
var resourceDir string
|
||||||
var size uint
|
var size uint
|
||||||
var database string
|
var database string
|
||||||
|
var dbSchema string
|
||||||
var engineDebug bool
|
var engineDebug bool
|
||||||
var host string
|
var host string
|
||||||
var port uint
|
var port uint
|
||||||
flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
|
flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
|
||||||
flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir")
|
flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir")
|
||||||
flag.StringVar(&database, "db", "gdbm", "database to be used")
|
flag.StringVar(&database, "db", "gdbm", "database to be used")
|
||||||
|
flag.StringVar(&dbSchema, "schema", "public", "database schema to be used")
|
||||||
flag.BoolVar(&engineDebug, "d", false, "use engine debug output")
|
flag.BoolVar(&engineDebug, "d", false, "use engine debug output")
|
||||||
flag.UintVar(&size, "s", 160, "max size of output")
|
flag.UintVar(&size, "s", 160, "max size of output")
|
||||||
flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host")
|
flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host")
|
||||||
@ -59,6 +61,7 @@ func main() {
|
|||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
ctx = context.WithValue(ctx, "Database", database)
|
ctx = context.WithValue(ctx, "Database", database)
|
||||||
|
ctx = context.WithValue(ctx, "Schema", dbSchema)
|
||||||
pfp := path.Join(scriptDir, "pp.csv")
|
pfp := path.Join(scriptDir, "pp.csv")
|
||||||
|
|
||||||
cfg := engine.Config{
|
cfg := engine.Config{
|
||||||
|
|||||||
@ -53,6 +53,7 @@ func main() {
|
|||||||
var resourceDir string
|
var resourceDir string
|
||||||
var size uint
|
var size uint
|
||||||
var database string
|
var database string
|
||||||
|
var dbSchema string
|
||||||
var engineDebug bool
|
var engineDebug bool
|
||||||
var host string
|
var host string
|
||||||
var port uint
|
var port uint
|
||||||
@ -60,6 +61,7 @@ func main() {
|
|||||||
flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
|
flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
|
||||||
flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir")
|
flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir")
|
||||||
flag.StringVar(&database, "db", "gdbm", "database to be used")
|
flag.StringVar(&database, "db", "gdbm", "database to be used")
|
||||||
|
flag.StringVar(&dbSchema, "schema", "public", "database schema to be used")
|
||||||
flag.BoolVar(&engineDebug, "d", false, "use engine debug output")
|
flag.BoolVar(&engineDebug, "d", false, "use engine debug output")
|
||||||
flag.UintVar(&size, "s", 160, "max size of output")
|
flag.UintVar(&size, "s", 160, "max size of output")
|
||||||
flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host")
|
flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host")
|
||||||
@ -70,6 +72,7 @@ func main() {
|
|||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
ctx = context.WithValue(ctx, "Database", database)
|
ctx = context.WithValue(ctx, "Database", database)
|
||||||
|
ctx = context.WithValue(ctx, "Schema", dbSchema)
|
||||||
pfp := path.Join(scriptDir, "pp.csv")
|
pfp := path.Join(scriptDir, "pp.csv")
|
||||||
|
|
||||||
cfg := engine.Config{
|
cfg := engine.Config{
|
||||||
|
|||||||
@ -42,12 +42,14 @@ func main() {
|
|||||||
var resourceDir string
|
var resourceDir string
|
||||||
var size uint
|
var size uint
|
||||||
var database string
|
var database string
|
||||||
|
var dbSchema string
|
||||||
var engineDebug bool
|
var engineDebug bool
|
||||||
var host string
|
var host string
|
||||||
var port uint
|
var port uint
|
||||||
flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
|
flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
|
||||||
flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir")
|
flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir")
|
||||||
flag.StringVar(&database, "db", "gdbm", "database to be used")
|
flag.StringVar(&database, "db", "gdbm", "database to be used")
|
||||||
|
flag.StringVar(&dbSchema, "schema", "public", "database schema to be used")
|
||||||
flag.BoolVar(&engineDebug, "d", false, "use engine debug output")
|
flag.BoolVar(&engineDebug, "d", false, "use engine debug output")
|
||||||
flag.UintVar(&size, "s", 160, "max size of output")
|
flag.UintVar(&size, "s", 160, "max size of output")
|
||||||
flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host")
|
flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host")
|
||||||
@ -58,6 +60,7 @@ func main() {
|
|||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
ctx = context.WithValue(ctx, "Database", database)
|
ctx = context.WithValue(ctx, "Database", database)
|
||||||
|
ctx = context.WithValue(ctx, "Schema", dbSchema)
|
||||||
pfp := path.Join(scriptDir, "pp.csv")
|
pfp := path.Join(scriptDir, "pp.csv")
|
||||||
|
|
||||||
cfg := engine.Config{
|
cfg := engine.Config{
|
||||||
|
|||||||
@ -36,9 +36,11 @@ func main() {
|
|||||||
var size uint
|
var size uint
|
||||||
var sessionId string
|
var sessionId string
|
||||||
var database string
|
var database string
|
||||||
|
var dbSchema string
|
||||||
var engineDebug bool
|
var engineDebug bool
|
||||||
flag.StringVar(&sessionId, "session-id", "075xx2123", "session id")
|
flag.StringVar(&sessionId, "session-id", "075xx2123", "session id")
|
||||||
flag.StringVar(&database, "db", "gdbm", "database to be used")
|
flag.StringVar(&database, "db", "gdbm", "database to be used")
|
||||||
|
flag.StringVar(&dbSchema, "schema", "public", "database schema to be used")
|
||||||
flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
|
flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
|
||||||
flag.BoolVar(&engineDebug, "d", false, "use engine debug output")
|
flag.BoolVar(&engineDebug, "d", false, "use engine debug output")
|
||||||
flag.UintVar(&size, "s", 160, "max size of output")
|
flag.UintVar(&size, "s", 160, "max size of output")
|
||||||
@ -49,6 +51,7 @@ func main() {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
ctx = context.WithValue(ctx, "SessionId", sessionId)
|
ctx = context.WithValue(ctx, "SessionId", sessionId)
|
||||||
ctx = context.WithValue(ctx, "Database", database)
|
ctx = context.WithValue(ctx, "Database", database)
|
||||||
|
ctx = context.WithValue(ctx, "Schema", dbSchema)
|
||||||
pfp := path.Join(scriptDir, "pp.csv")
|
pfp := path.Join(scriptDir, "pp.csv")
|
||||||
|
|
||||||
cfg := engine.Config{
|
cfg := engine.Config{
|
||||||
|
|||||||
@ -14,6 +14,7 @@ import (
|
|||||||
"git.defalsify.org/vise.git/resource"
|
"git.defalsify.org/vise.git/resource"
|
||||||
"git.grassecon.net/urdt/ussd/initializers"
|
"git.grassecon.net/urdt/ussd/initializers"
|
||||||
gdbmstorage "git.grassecon.net/urdt/ussd/internal/storage/db/gdbm"
|
gdbmstorage "git.grassecon.net/urdt/ussd/internal/storage/db/gdbm"
|
||||||
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -35,7 +36,7 @@ type MenuStorageService struct {
|
|||||||
userDataStore db.Db
|
userDataStore db.Db
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildConnStr() string {
|
func BuildConnStr() string {
|
||||||
host := initializers.GetEnv("DB_HOST", "localhost")
|
host := initializers.GetEnv("DB_HOST", "localhost")
|
||||||
user := initializers.GetEnv("DB_USER", "postgres")
|
user := initializers.GetEnv("DB_USER", "postgres")
|
||||||
password := initializers.GetEnv("DB_PASSWORD", "")
|
password := initializers.GetEnv("DB_PASSWORD", "")
|
||||||
@ -64,6 +65,11 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D
|
|||||||
return nil, fmt.Errorf("failed to select the database")
|
return nil, fmt.Errorf("failed to select the database")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
schema, ok := ctx.Value("Schema").(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("failed to select the schema")
|
||||||
|
}
|
||||||
|
|
||||||
if existingDb != nil {
|
if existingDb != nil {
|
||||||
return existingDb, nil
|
return existingDb, nil
|
||||||
}
|
}
|
||||||
@ -72,8 +78,15 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D
|
|||||||
var err error
|
var err error
|
||||||
|
|
||||||
if database == "postgres" {
|
if database == "postgres" {
|
||||||
newDb = postgres.NewPgDb()
|
connStr := BuildConnStr()
|
||||||
connStr := buildConnStr()
|
|
||||||
|
// Ensure the schema exists
|
||||||
|
err = ensureSchemaExists(ctx, connStr, schema)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to ensure schema exists: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
newDb = postgres.NewPgDb().WithSchema(schema)
|
||||||
err = newDb.Connect(ctx, connStr)
|
err = newDb.Connect(ctx, connStr)
|
||||||
} else {
|
} else {
|
||||||
newDb = gdbmstorage.NewThreadGdbmDb()
|
newDb = gdbmstorage.NewThreadGdbmDb()
|
||||||
@ -88,6 +101,23 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D
|
|||||||
return newDb, nil
|
return newDb, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ensureSchemaExists creates a new schema if it does not exist
|
||||||
|
func ensureSchemaExists(ctx context.Context, connStr, schema string) error {
|
||||||
|
conn, err := pgxpool.New(ctx, connStr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to connect to the database: %w", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
query := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema)
|
||||||
|
_, err = conn.Exec(ctx, query)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create schema: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (ms *MenuStorageService) GetPersister(ctx context.Context) (*persist.Persister, error) {
|
func (ms *MenuStorageService) GetPersister(ctx context.Context) (*persist.Persister, error) {
|
||||||
stateStore, err := ms.GetStateStore(ctx)
|
stateStore, err := ms.GetStateStore(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@ -24,6 +24,7 @@ var (
|
|||||||
logg = logging.NewVanilla()
|
logg = logging.NewVanilla()
|
||||||
scriptDir = path.Join(baseDir, "services", "registration")
|
scriptDir = path.Join(baseDir, "services", "registration")
|
||||||
selectedDatabase = ""
|
selectedDatabase = ""
|
||||||
|
selectedDbSchema = ""
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@ -31,14 +32,17 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SetDatabase updates the database used by TestEngine
|
// SetDatabase updates the database used by TestEngine
|
||||||
func SetDatabase(dbType string) {
|
func SetDatabase(dbType string, dbSchema string) {
|
||||||
selectedDatabase = dbType
|
selectedDatabase = dbType
|
||||||
|
selectedDbSchema = dbSchema
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEngine(sessionId string) (engine.Engine, func(), chan bool) {
|
func TestEngine(sessionId string) (engine.Engine, func(), chan bool) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
ctx = context.WithValue(ctx, "SessionId", sessionId)
|
ctx = context.WithValue(ctx, "SessionId", sessionId)
|
||||||
ctx = context.WithValue(ctx, "Database", selectedDatabase)
|
ctx = context.WithValue(ctx, "Database", selectedDatabase)
|
||||||
|
ctx = context.WithValue(ctx, "Schema", selectedDbSchema)
|
||||||
|
|
||||||
pfp := path.Join(scriptDir, "pp.csv")
|
pfp := path.Join(scriptDir, "pp.csv")
|
||||||
|
|
||||||
var eventChannel = make(chan bool)
|
var eventChannel = make(chan bool)
|
||||||
|
|||||||
@ -4,15 +4,18 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"flag"
|
"flag"
|
||||||
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"os"
|
"os"
|
||||||
"regexp"
|
"regexp"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"git.grassecon.net/urdt/ussd/internal/storage"
|
||||||
"git.grassecon.net/urdt/ussd/internal/testutil"
|
"git.grassecon.net/urdt/ussd/internal/testutil"
|
||||||
"git.grassecon.net/urdt/ussd/internal/testutil/driver"
|
"git.grassecon.net/urdt/ussd/internal/testutil/driver"
|
||||||
"github.com/gofrs/uuid"
|
"github.com/gofrs/uuid"
|
||||||
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -25,6 +28,7 @@ var (
|
|||||||
|
|
||||||
var groupTestFile = flag.String("test-file", "group_test.json", "The test file to use for running the group tests")
|
var groupTestFile = flag.String("test-file", "group_test.json", "The test file to use for running the group tests")
|
||||||
var database = flag.String("db", "gdbm", "Specify the database (gdbm or postgres)")
|
var database = flag.String("db", "gdbm", "Specify the database (gdbm or postgres)")
|
||||||
|
var dbSchema = flag.String("schema", "test", "Specify the database schema (default test)")
|
||||||
|
|
||||||
func GenerateSessionId() string {
|
func GenerateSessionId() string {
|
||||||
uu := uuid.NewGenWithOptions(uuid.WithRandomReader(g))
|
uu := uuid.NewGenWithOptions(uuid.WithRandomReader(g))
|
||||||
@ -91,7 +95,29 @@ func TestMain(m *testing.M) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
// Set the selected database
|
// Set the selected database
|
||||||
testutil.SetDatabase(*database)
|
testutil.SetDatabase(*database, *dbSchema)
|
||||||
|
|
||||||
|
// Cleanup the schema table after tests
|
||||||
|
defer func() {
|
||||||
|
if *database == "postgres" {
|
||||||
|
ctx := context.Background()
|
||||||
|
connStr := storage.BuildConnStr()
|
||||||
|
dbConn, err := pgxpool.New(ctx, connStr)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to connect to database for cleanup: %v", err)
|
||||||
|
}
|
||||||
|
defer dbConn.Close()
|
||||||
|
|
||||||
|
query := fmt.Sprintf("DELETE FROM %s.kv_vise;", *dbSchema)
|
||||||
|
_, execErr := dbConn.Exec(ctx, query)
|
||||||
|
if execErr != nil {
|
||||||
|
log.Printf("Failed to cleanup table %s.kv_vise: %v", *dbSchema, execErr)
|
||||||
|
} else {
|
||||||
|
log.Printf("Successfully cleaned up table %s.kv_vise", *dbSchema)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
m.Run()
|
m.Run()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user