From 46a6d2bc6e2e732d800750cb67901d1887b119bc Mon Sep 17 00:00:00 2001 From: alfred-mk Date: Wed, 8 Jan 2025 10:37:47 +0300 Subject: [PATCH] create a schema if it does not exist and use it in the connection --- internal/storage/storageservice.go | 32 +++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/internal/storage/storageservice.go b/internal/storage/storageservice.go index 04e75ce..d252383 100644 --- a/internal/storage/storageservice.go +++ b/internal/storage/storageservice.go @@ -14,6 +14,7 @@ import ( "git.defalsify.org/vise.git/resource" "git.grassecon.net/urdt/ussd/initializers" gdbmstorage "git.grassecon.net/urdt/ussd/internal/storage/db/gdbm" + "github.com/jackc/pgx/v5/pgxpool" ) var ( @@ -64,6 +65,11 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D return nil, fmt.Errorf("failed to select the database") } + schema, ok := ctx.Value("Schema").(string) + if !ok { + return nil, fmt.Errorf("failed to select the schema") + } + if existingDb != nil { return existingDb, nil } @@ -72,8 +78,15 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D var err error if database == "postgres" { - newDb = postgres.NewPgDb() connStr := buildConnStr() + + // Ensure the schema exists + err = ensureSchemaExists(ctx, connStr, schema) + if err != nil { + return nil, fmt.Errorf("failed to ensure schema exists: %w", err) + } + + newDb = postgres.NewPgDb().WithSchema(schema) err = newDb.Connect(ctx, connStr) } else { newDb = gdbmstorage.NewThreadGdbmDb() @@ -88,6 +101,23 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D return newDb, nil } +// ensureSchemaExists creates a new schema if it does not exist +func ensureSchemaExists(ctx context.Context, connStr, schema string) error { + conn, err := pgxpool.New(ctx, connStr) + if err != nil { + return fmt.Errorf("failed to connect to the database: %w", err) + } + defer conn.Close() + + query := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema) + _, err = conn.Exec(ctx, query) + if err != nil { + return fmt.Errorf("failed to create schema: %w", err) + } + + return nil +} + func (ms *MenuStorageService) GetPersister(ctx context.Context) (*persist.Persister, error) { stateStore, err := ms.GetStateStore(ctx) if err != nil {