diff --git a/internal/storage/parse.go b/internal/storage/parse.go index bb25627..41dac6b 100644 --- a/internal/storage/parse.go +++ b/internal/storage/parse.go @@ -15,6 +15,7 @@ const ( type ConnData struct { typ int str string + domain string } func (cd *ConnData) DbType() int { @@ -25,23 +26,38 @@ func (cd *ConnData) String() string { return cd.str } -func probePostgres(s string) (string, bool) { - v, err := url.Parse(s) - if err != nil { - return "", false - } - if v.Scheme != "postgres" { - return "", false - } - return s, true +func (cd *ConnData) Domain() string { + return cd.domain } -func probeGdbm(s string) (string, bool) { +func (cd *ConnData) Path() string { + v, _ := url.Parse(cd.str) + v.RawQuery = "" + return v.String() +} + +func probePostgres(s string) (string, string, bool) { + domain := "public" + v, err := url.Parse(s) + if err != nil { + return "", "", false + } + if v.Scheme != "postgres" { + return "", "", false + } + vv := v.Query() + if vv.Has("search_path") { + domain = vv.Get("search_path") + } + return s, domain, true +} + +func probeGdbm(s string) (string, string, bool) { if !path.IsAbs(s) { - return "", false + return "", "", false } s = path.Clean(s) - return s, true + return s, "", true } func ToConnData(connStr string) (ConnData, error) { @@ -51,14 +67,15 @@ func ToConnData(connStr string) (ConnData, error) { return o, nil } - v, ok := probePostgres(connStr) + v, domain, ok := probePostgres(connStr) if ok { o.typ = DBTYPE_POSTGRES o.str = v + o.domain = domain return o, nil } - v, ok = probeGdbm(connStr) + v, _, ok = probeGdbm(connStr) if ok { o.typ = DBTYPE_GDBM o.str = v diff --git a/internal/storage/storage_service.go b/internal/storage/storage_service.go index c94f17c..d72a2eb 100644 --- a/internal/storage/storage_service.go +++ b/internal/storage/storage_service.go @@ -76,7 +76,12 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D connStr := ms.conn.String() dbTyp := ms.conn.DbType() if dbTyp == DBTYPE_POSTGRES { - newDb = postgres.NewPgDb() + // TODO: move to vise + err = ensureSchemaExists(ctx, ms.conn) + if err != nil { + return nil, err + } + newDb = postgres.NewPgDb().WithSchema(ms.conn.Domain()) } else if dbTyp == DBTYPE_GDBM { err = ms.ensureDbDir() if err != nil { @@ -90,7 +95,7 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D } else { return nil, fmt.Errorf("unsupported connection string: '%s'\n", ms.conn.String()) } - logg.DebugCtxf(ctx, "connecting to db", "conn", connStr) + logg.DebugCtxf(ctx, "connecting to db", "conn", connStr, "conndata", ms.conn) err = newDb.Connect(ctx, connStr) if err != nil { return nil, err @@ -100,15 +105,15 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D } // ensureSchemaExists creates a new schema if it does not exist -func ensureSchemaExists(ctx context.Context, connStr, schema string) error { - conn, err := pgxpool.New(ctx, connStr) +func ensureSchemaExists(ctx context.Context, conn ConnData) error { + h, err := pgxpool.New(ctx, conn.Path()) if err != nil { return fmt.Errorf("failed to connect to the database: %w", err) } - defer conn.Close() + defer h.Close() - query := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema) - _, err = conn.Exec(ctx, query) + query := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", conn.Domain()) + _, err = h.Exec(ctx, query) if err != nil { return fmt.Errorf("failed to create schema: %w", err) }