Merge pull request 'postgres-switch-for-tests' (#255) from postgres-switch-for-tests into master

Reviewed-on: #255
This commit is contained in:
lash 2025-01-10 12:07:07 +01:00
commit 8f5ed0cd4f
10 changed files with 190 additions and 73 deletions

View File

@ -12,17 +12,17 @@ import (
"syscall" "syscall"
"git.defalsify.org/vise.git/engine" "git.defalsify.org/vise.git/engine"
"git.defalsify.org/vise.git/lang"
"git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/logging"
"git.defalsify.org/vise.git/resource" "git.defalsify.org/vise.git/resource"
"git.defalsify.org/vise.git/lang"
"git.grassecon.net/urdt/ussd/config" "git.grassecon.net/urdt/ussd/config"
"git.grassecon.net/urdt/ussd/initializers" "git.grassecon.net/urdt/ussd/initializers"
"git.grassecon.net/urdt/ussd/internal/args"
"git.grassecon.net/urdt/ussd/internal/handlers" "git.grassecon.net/urdt/ussd/internal/handlers"
"git.grassecon.net/urdt/ussd/internal/http/at" "git.grassecon.net/urdt/ussd/internal/http/at"
"git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/internal/storage"
"git.grassecon.net/urdt/ussd/remote" "git.grassecon.net/urdt/ussd/remote"
"git.grassecon.net/urdt/ussd/internal/args"
) )
var ( var (
@ -42,7 +42,6 @@ func main() {
var connStr string var connStr string
var resourceDir string var resourceDir string
var size uint var size uint
var database string
var engineDebug bool var engineDebug bool
var host string var host string
var port uint var port uint
@ -60,7 +59,7 @@ func main() {
flag.Var(&langs, "language", "add symbol resolution for language") flag.Var(&langs, "language", "add symbol resolution for language")
flag.Parse() flag.Parse()
if connStr != "" { if connStr == "" {
connStr = config.DbConn connStr = config.DbConn
} }
connData, err := storage.ToConnData(connStr) connData, err := storage.ToConnData(connStr)
@ -72,7 +71,6 @@ func main() {
logg.Infof("start command", "build", build, "conn", connData, "resourcedir", resourceDir, "outputsize", size) logg.Infof("start command", "build", build, "conn", connData, "resourcedir", resourceDir, "outputsize", size)
ctx := context.Background() ctx := context.Background()
ctx = context.WithValue(ctx, "Database", database)
ln, err := lang.LanguageFromCode(config.DefaultLanguage) ln, err := lang.LanguageFromCode(config.DefaultLanguage)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "default language set error: %v", err) fmt.Fprintf(os.Stderr, "default language set error: %v", err)

View File

@ -10,16 +10,16 @@ import (
"syscall" "syscall"
"git.defalsify.org/vise.git/engine" "git.defalsify.org/vise.git/engine"
"git.defalsify.org/vise.git/lang"
"git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/logging"
"git.defalsify.org/vise.git/resource" "git.defalsify.org/vise.git/resource"
"git.defalsify.org/vise.git/lang"
"git.grassecon.net/urdt/ussd/config" "git.grassecon.net/urdt/ussd/config"
"git.grassecon.net/urdt/ussd/initializers" "git.grassecon.net/urdt/ussd/initializers"
"git.grassecon.net/urdt/ussd/internal/args"
"git.grassecon.net/urdt/ussd/internal/handlers" "git.grassecon.net/urdt/ussd/internal/handlers"
"git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/internal/storage"
"git.grassecon.net/urdt/ussd/remote" "git.grassecon.net/urdt/ussd/remote"
"git.grassecon.net/urdt/ussd/internal/args"
) )
var ( var (
@ -52,7 +52,6 @@ func main() {
var sessionId string var sessionId string
var resourceDir string var resourceDir string
var size uint var size uint
var database string
var engineDebug bool var engineDebug bool
var host string var host string
var port uint var port uint
@ -71,7 +70,7 @@ func main() {
flag.Var(&langs, "language", "add symbol resolution for language") flag.Var(&langs, "language", "add symbol resolution for language")
flag.Parse() flag.Parse()
if connStr != "" { if connStr == "" {
connStr = config.DbConn connStr = config.DbConn
} }
connData, err := storage.ToConnData(connStr) connData, err := storage.ToConnData(connStr)
@ -83,7 +82,6 @@ func main() {
logg.Infof("start command", "conn", connData, "resourcedir", resourceDir, "outputsize", size, "sessionId", sessionId) logg.Infof("start command", "conn", connData, "resourcedir", resourceDir, "outputsize", size, "sessionId", sessionId)
ctx := context.Background() ctx := context.Background()
ctx = context.WithValue(ctx, "Database", database)
ln, err := lang.LanguageFromCode(config.DefaultLanguage) ln, err := lang.LanguageFromCode(config.DefaultLanguage)
if err != nil { if err != nil {
@ -117,7 +115,6 @@ func main() {
os.Exit(1) os.Exit(1)
} }
userdataStore, err := menuStorageService.GetUserdataDb(ctx) userdataStore, err := menuStorageService.GetUserdataDb(ctx)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, err.Error()) fmt.Fprintf(os.Stderr, err.Error())

View File

@ -12,17 +12,17 @@ import (
"syscall" "syscall"
"git.defalsify.org/vise.git/engine" "git.defalsify.org/vise.git/engine"
"git.defalsify.org/vise.git/lang"
"git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/logging"
"git.defalsify.org/vise.git/resource" "git.defalsify.org/vise.git/resource"
"git.defalsify.org/vise.git/lang"
"git.grassecon.net/urdt/ussd/config" "git.grassecon.net/urdt/ussd/config"
"git.grassecon.net/urdt/ussd/initializers" "git.grassecon.net/urdt/ussd/initializers"
"git.grassecon.net/urdt/ussd/internal/args"
"git.grassecon.net/urdt/ussd/internal/handlers" "git.grassecon.net/urdt/ussd/internal/handlers"
httpserver "git.grassecon.net/urdt/ussd/internal/http" httpserver "git.grassecon.net/urdt/ussd/internal/http"
"git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/internal/storage"
"git.grassecon.net/urdt/ussd/remote" "git.grassecon.net/urdt/ussd/remote"
"git.grassecon.net/urdt/ussd/internal/args"
) )
var ( var (
@ -41,7 +41,6 @@ func main() {
var connStr string var connStr string
var resourceDir string var resourceDir string
var size uint var size uint
var database string
var engineDebug bool var engineDebug bool
var host string var host string
var port uint var port uint
@ -59,7 +58,7 @@ func main() {
flag.Var(&langs, "language", "add symbol resolution for language") flag.Var(&langs, "language", "add symbol resolution for language")
flag.Parse() flag.Parse()
if connStr != "" { if connStr == "" {
connStr = config.DbConn connStr = config.DbConn
} }
connData, err := storage.ToConnData(connStr) connData, err := storage.ToConnData(connStr)
@ -71,7 +70,6 @@ func main() {
logg.Infof("start command", "conn", connData, "resourcedir", resourceDir, "outputsize", size) logg.Infof("start command", "conn", connData, "resourcedir", resourceDir, "outputsize", size)
ctx := context.Background() ctx := context.Background()
ctx = context.WithValue(ctx, "Database", database)
ln, err := lang.LanguageFromCode(config.DefaultLanguage) ln, err := lang.LanguageFromCode(config.DefaultLanguage)
if err != nil { if err != nil {

View File

@ -8,14 +8,14 @@ import (
"path" "path"
"git.defalsify.org/vise.git/engine" "git.defalsify.org/vise.git/engine"
"git.defalsify.org/vise.git/lang"
"git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/logging"
"git.defalsify.org/vise.git/resource" "git.defalsify.org/vise.git/resource"
"git.defalsify.org/vise.git/lang"
"git.grassecon.net/urdt/ussd/config" "git.grassecon.net/urdt/ussd/config"
"git.grassecon.net/urdt/ussd/initializers" "git.grassecon.net/urdt/ussd/initializers"
"git.grassecon.net/urdt/ussd/internal/args"
"git.grassecon.net/urdt/ussd/internal/handlers" "git.grassecon.net/urdt/ussd/internal/handlers"
"git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/internal/storage"
"git.grassecon.net/urdt/ussd/internal/args"
"git.grassecon.net/urdt/ussd/remote" "git.grassecon.net/urdt/ussd/remote"
) )
@ -36,7 +36,6 @@ func main() {
var connStr string var connStr string
var size uint var size uint
var sessionId string var sessionId string
var database string
var engineDebug bool var engineDebug bool
var resourceDir string var resourceDir string
var err error var err error
@ -52,7 +51,7 @@ func main() {
flag.Var(&langs, "language", "add symbol resolution for language") flag.Var(&langs, "language", "add symbol resolution for language")
flag.Parse() flag.Parse()
if connStr != "" { if connStr == "" {
connStr = config.DbConn connStr = config.DbConn
} }
connData, err := storage.ToConnData(connStr) connData, err := storage.ToConnData(connStr)
@ -69,7 +68,6 @@ func main() {
ctx := context.Background() ctx := context.Background()
ctx = context.WithValue(ctx, "SessionId", sessionId) ctx = context.WithValue(ctx, "SessionId", sessionId)
ctx = context.WithValue(ctx, "Database", database)
ln, err := lang.LanguageFromCode(config.DefaultLanguage) ln, err := lang.LanguageFromCode(config.DefaultLanguage)
if err != nil { if err != nil {

View File

@ -9,14 +9,16 @@ import (
"path" "path"
"git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/logging"
"git.grassecon.net/urdt/ussd/config"
"git.grassecon.net/urdt/ussd/internal/storage"
"git.grassecon.net/urdt/ussd/initializers"
"git.grassecon.net/urdt/ussd/common" "git.grassecon.net/urdt/ussd/common"
"git.grassecon.net/urdt/ussd/config"
"git.grassecon.net/urdt/ussd/initializers"
"git.grassecon.net/urdt/ussd/internal/storage"
testdataloader "github.com/peteole/testdata-loader"
) )
var ( var (
logg = logging.NewVanilla() logg = logging.NewVanilla()
baseDir = testdataloader.GetBasePath()
scriptDir = path.Join("services", "registration") scriptDir = path.Join("services", "registration")
) )
@ -24,7 +26,6 @@ func init() {
initializers.LoadEnvVariables() initializers.LoadEnvVariables()
} }
func main() { func main() {
config.LoadConfig() config.LoadConfig()
@ -86,5 +87,4 @@ func main() {
fmt.Fprintf(os.Stderr, err.Error()) fmt.Fprintf(os.Stderr, err.Error())
os.Exit(1) os.Exit(1)
} }
} }

View File

@ -3,15 +3,21 @@ package initializers
import ( import (
"log" "log"
"os" "os"
"path"
"strconv" "strconv"
"github.com/joho/godotenv" "github.com/joho/godotenv"
) )
func LoadEnvVariables() { func LoadEnvVariables() {
err := godotenv.Load() LoadEnvVariablesPath(".")
}
func LoadEnvVariablesPath(dir string) {
fp := path.Join(dir, ".env")
err := godotenv.Load(fp)
if err != nil { if err != nil {
log.Fatal("Error loading .env file") log.Fatal("Error loading .env file", err)
} }
} }

View File

@ -15,6 +15,7 @@ const (
type ConnData struct { type ConnData struct {
typ int typ int
str string str string
domain string
} }
func (cd *ConnData) DbType() int { func (cd *ConnData) DbType() int {
@ -25,23 +26,38 @@ func (cd *ConnData) String() string {
return cd.str return cd.str
} }
func probePostgres(s string) (string, bool) { func (cd *ConnData) Domain() string {
v, err := url.Parse(s) return cd.domain
if err != nil {
return "", false
}
if v.Scheme != "postgres" {
return "", false
}
return s, true
} }
func probeGdbm(s string) (string, bool) { func (cd *ConnData) Path() string {
v, _ := url.Parse(cd.str)
v.RawQuery = ""
return v.String()
}
func probePostgres(s string) (string, string, bool) {
domain := "public"
v, err := url.Parse(s)
if err != nil {
return "", "", false
}
if v.Scheme != "postgres" {
return "", "", false
}
vv := v.Query()
if vv.Has("search_path") {
domain = vv.Get("search_path")
}
return s, domain, true
}
func probeGdbm(s string) (string, string, bool) {
if !path.IsAbs(s) { if !path.IsAbs(s) {
return "", false return "", "", false
} }
s = path.Clean(s) s = path.Clean(s)
return s, true return s, "", true
} }
func ToConnData(connStr string) (ConnData, error) { func ToConnData(connStr string) (ConnData, error) {
@ -51,14 +67,15 @@ func ToConnData(connStr string) (ConnData, error) {
return o, nil return o, nil
} }
v, ok := probePostgres(connStr) v, domain, ok := probePostgres(connStr)
if ok { if ok {
o.typ = DBTYPE_POSTGRES o.typ = DBTYPE_POSTGRES
o.str = v o.str = v
o.domain = domain
return o, nil return o, nil
} }
v, ok = probeGdbm(connStr) v, _, ok = probeGdbm(connStr)
if ok { if ok {
o.typ = DBTYPE_GDBM o.typ = DBTYPE_GDBM
o.str = v o.str = v

View File

@ -14,6 +14,7 @@ import (
"git.defalsify.org/vise.git/persist" "git.defalsify.org/vise.git/persist"
"git.defalsify.org/vise.git/resource" "git.defalsify.org/vise.git/resource"
gdbmstorage "git.grassecon.net/urdt/ussd/internal/storage/db/gdbm" gdbmstorage "git.grassecon.net/urdt/ussd/internal/storage/db/gdbm"
"github.com/jackc/pgx/v5/pgxpool"
) )
var ( var (
@ -54,7 +55,12 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D
connStr := ms.conn.String() connStr := ms.conn.String()
dbTyp := ms.conn.DbType() dbTyp := ms.conn.DbType()
if dbTyp == DBTYPE_POSTGRES { if dbTyp == DBTYPE_POSTGRES {
newDb = postgres.NewPgDb() // TODO: move to vise
err = ensureSchemaExists(ctx, ms.conn)
if err != nil {
return nil, err
}
newDb = postgres.NewPgDb().WithSchema(ms.conn.Domain())
} else if dbTyp == DBTYPE_GDBM { } else if dbTyp == DBTYPE_GDBM {
err = ms.ensureDbDir() err = ms.ensureDbDir()
if err != nil { if err != nil {
@ -65,7 +71,7 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D
} else { } else {
return nil, fmt.Errorf("unsupported connection string: '%s'\n", ms.conn.String()) return nil, fmt.Errorf("unsupported connection string: '%s'\n", ms.conn.String())
} }
logg.DebugCtxf(ctx, "connecting to db", "conn", connStr) logg.DebugCtxf(ctx, "connecting to db", "conn", connStr, "conndata", ms.conn)
err = newDb.Connect(ctx, connStr) err = newDb.Connect(ctx, connStr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -96,6 +102,23 @@ func (ms *MenuStorageService) WithGettext(path string, lns []lang.Language) *Men
return ms return ms
} }
// ensureSchemaExists creates a new schema if it does not exist
func ensureSchemaExists(ctx context.Context, conn ConnData) error {
h, err := pgxpool.New(ctx, conn.Path())
if err != nil {
return fmt.Errorf("failed to connect to the database: %w", err)
}
defer h.Close()
query := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", conn.Domain())
_, err = h.Exec(ctx, query)
if err != nil {
return fmt.Errorf("failed to create schema: %w", err)
}
return nil
}
func (ms *MenuStorageService) GetPersister(ctx context.Context) (*persist.Persister, error) { func (ms *MenuStorageService) GetPersister(ctx context.Context) (*persist.Persister, error) {
stateStore, err := ms.GetStateStore(ctx) stateStore, err := ms.GetStateStore(ctx)
if err != nil { if err != nil {

View File

@ -3,6 +3,8 @@ package testutil
import ( import (
"context" "context"
"fmt" "fmt"
"log"
"net/url"
"os" "os"
"path" "path"
"path/filepath" "path/filepath"
@ -11,21 +13,90 @@ import (
"git.defalsify.org/vise.git/engine" "git.defalsify.org/vise.git/engine"
"git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/logging"
"git.defalsify.org/vise.git/resource" "git.defalsify.org/vise.git/resource"
"git.grassecon.net/urdt/ussd/config"
"git.grassecon.net/urdt/ussd/initializers"
"git.grassecon.net/urdt/ussd/internal/handlers" "git.grassecon.net/urdt/ussd/internal/handlers"
"git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/internal/storage"
"git.grassecon.net/urdt/ussd/internal/testutil/testservice" "git.grassecon.net/urdt/ussd/internal/testutil/testservice"
"git.grassecon.net/urdt/ussd/internal/testutil/testtag" "git.grassecon.net/urdt/ussd/internal/testutil/testtag"
testdataloader "github.com/peteole/testdata-loader"
"git.grassecon.net/urdt/ussd/remote" "git.grassecon.net/urdt/ussd/remote"
"github.com/jackc/pgx/v5/pgxpool"
testdataloader "github.com/peteole/testdata-loader"
) )
var ( var (
baseDir = testdataloader.GetBasePath()
logg = logging.NewVanilla() logg = logging.NewVanilla()
baseDir = testdataloader.GetBasePath()
scriptDir = path.Join(baseDir, "services", "registration") scriptDir = path.Join(baseDir, "services", "registration")
setDbType string
setConnStr string
setDbSchema string
) )
func init() {
initializers.LoadEnvVariablesPath(baseDir)
config.LoadConfig()
}
// SetDatabase updates the database used by TestEngine
func SetDatabase(database, connStr, dbSchema string) {
setDbType = database
setConnStr = connStr
setDbSchema = dbSchema
}
// CleanDatabase removes all test data from the database
func CleanDatabase() {
if setDbType == "postgres" {
ctx := context.Background()
// Update the connection string with the new search path
updatedConnStr, err := updateSearchPath(setConnStr, setDbSchema)
if err != nil {
log.Fatalf("Failed to update search path: %v", err)
}
dbConn, err := pgxpool.New(ctx, updatedConnStr)
if err != nil {
log.Fatalf("Failed to connect to database for cleanup: %v", err)
}
defer dbConn.Close()
query := fmt.Sprintf("DELETE FROM %s.kv_vise;", setDbSchema)
_, execErr := dbConn.Exec(ctx, query)
if execErr != nil {
log.Printf("Failed to cleanup table %s.kv_vise: %v", setDbSchema, execErr)
} else {
log.Printf("Successfully cleaned up table %s.kv_vise", setDbSchema)
}
} else {
setConnStr, _ := filepath.Abs(setConnStr)
if err := os.RemoveAll(setConnStr); err != nil {
log.Fatalf("Failed to delete state store %s: %v", setConnStr, err)
}
}
}
// updateSearchPath updates the search_path (schema) to be used in the connection
func updateSearchPath(connStr string, newSearchPath string) (string, error) {
u, err := url.Parse(connStr)
if err != nil {
return "", fmt.Errorf("invalid connection string: %w", err)
}
// Parse the query parameters
q := u.Query()
// Update or add the search_path parameter
q.Set("search_path", newSearchPath)
// Rebuild the connection string with updated parameters
u.RawQuery = q.Encode()
return u.String(), nil
}
func TestEngine(sessionId string) (engine.Engine, func(), chan bool) { func TestEngine(sessionId string) (engine.Engine, func(), chan bool) {
var err error
ctx := context.Background() ctx := context.Background()
ctx = context.WithValue(ctx, "SessionId", sessionId) ctx = context.WithValue(ctx, "SessionId", sessionId)
pfp := path.Join(scriptDir, "pp.csv") pfp := path.Join(scriptDir, "pp.csv")
@ -39,16 +110,27 @@ func TestEngine(sessionId string) (engine.Engine, func(), chan bool) {
FlagCount: uint32(128), FlagCount: uint32(128),
} }
connStr, err := filepath.Abs(".test_state/state.gdbm") if setDbType == "postgres" {
setConnStr = config.DbConn
setConnStr, err = updateSearchPath(setConnStr, setDbSchema)
if err != nil {
fmt.Println("Error:", err)
os.Exit(1)
}
} else {
setConnStr, err = filepath.Abs(setConnStr)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "connstr err: %v", err) fmt.Fprintf(os.Stderr, "connstr err: %v", err)
os.Exit(1) os.Exit(1)
} }
conn, err := storage.ToConnData(connStr) }
conn, err := storage.ToConnData(setConnStr)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "connstr parse err: %v", err) fmt.Fprintf(os.Stderr, "connstr parse err: %v", err)
os.Exit(1) os.Exit(1)
} }
resourceDir := scriptDir resourceDir := scriptDir
menuStorageService := storage.NewMenuStorageService(conn, resourceDir) menuStorageService := storage.NewMenuStorageService(conn, resourceDir)

View File

@ -6,8 +6,6 @@ import (
"flag" "flag"
"log" "log"
"math/rand" "math/rand"
"os"
"path/filepath"
"regexp" "regexp"
"testing" "testing"
@ -24,11 +22,9 @@ var (
) )
var groupTestFile = flag.String("test-file", "group_test.json", "The test file to use for running the group tests") var groupTestFile = flag.String("test-file", "group_test.json", "The test file to use for running the group tests")
var database = flag.String("db", "gdbm", "Specify the database (gdbm or postgres)")
func testStore() string { var connStr = flag.String("conn", ".test_state", "connection string")
v, _ := filepath.Abs(".test_state/state.gdbm") var dbSchema = flag.String("schema", "test", "Specify the database schema (default test)")
return v
}
func GenerateSessionId() string { func GenerateSessionId() string {
uu := uuid.NewGenWithOptions(uuid.WithRandomReader(g)) uu := uuid.NewGenWithOptions(uuid.WithRandomReader(g))
@ -84,12 +80,15 @@ func extractSendAmount(response []byte) string {
} }
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
// Parse the flags
flag.Parse()
sessionID = GenerateSessionId() sessionID = GenerateSessionId()
defer func() { // set the db
if err := os.RemoveAll(testStore()); err != nil { testutil.SetDatabase(*database, *connStr, *dbSchema)
log.Fatalf("Failed to delete state store %s: %v", testStore(), err)
} // Cleanup the db after tests
}() defer testutil.CleanDatabase()
m.Run() m.Run()
} }
@ -126,7 +125,6 @@ func TestAccountCreationSuccessful(t *testing.T) {
} }
} }
<-eventChannel <-eventChannel
} }
func TestAccountRegistrationRejectTerms(t *testing.T) { func TestAccountRegistrationRejectTerms(t *testing.T) {