Implement postgres schema

This commit is contained in:
lash 2025-01-09 07:42:09 +00:00
parent df8c9aab0c
commit b50a51df9b
Signed by untrusted user who does not match committer: lash
GPG Key ID: 21D2E7BB88C2A746
2 changed files with 43 additions and 21 deletions

View File

@ -15,6 +15,7 @@ const (
type ConnData struct { type ConnData struct {
typ int typ int
str string str string
domain string
} }
func (cd *ConnData) DbType() int { func (cd *ConnData) DbType() int {
@ -25,23 +26,38 @@ func (cd *ConnData) String() string {
return cd.str return cd.str
} }
func probePostgres(s string) (string, bool) { func (cd *ConnData) Domain() string {
v, err := url.Parse(s) return cd.domain
if err != nil {
return "", false
}
if v.Scheme != "postgres" {
return "", false
}
return s, true
} }
func probeGdbm(s string) (string, bool) { func (cd *ConnData) Path() string {
v, _ := url.Parse(cd.str)
v.RawQuery = ""
return v.String()
}
func probePostgres(s string) (string, string, bool) {
domain := "public"
v, err := url.Parse(s)
if err != nil {
return "", "", false
}
if v.Scheme != "postgres" {
return "", "", false
}
vv := v.Query()
if vv.Has("search_path") {
domain = vv.Get("search_path")
}
return s, domain, true
}
func probeGdbm(s string) (string, string, bool) {
if !path.IsAbs(s) { if !path.IsAbs(s) {
return "", false return "", "", false
} }
s = path.Clean(s) s = path.Clean(s)
return s, true return s, "", true
} }
func ToConnData(connStr string) (ConnData, error) { func ToConnData(connStr string) (ConnData, error) {
@ -51,14 +67,15 @@ func ToConnData(connStr string) (ConnData, error) {
return o, nil return o, nil
} }
v, ok := probePostgres(connStr) v, domain, ok := probePostgres(connStr)
if ok { if ok {
o.typ = DBTYPE_POSTGRES o.typ = DBTYPE_POSTGRES
o.str = v o.str = v
o.domain = domain
return o, nil return o, nil
} }
v, ok = probeGdbm(connStr) v, _, ok = probeGdbm(connStr)
if ok { if ok {
o.typ = DBTYPE_GDBM o.typ = DBTYPE_GDBM
o.str = v o.str = v

View File

@ -55,7 +55,12 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D
connStr := ms.conn.String() connStr := ms.conn.String()
dbTyp := ms.conn.DbType() dbTyp := ms.conn.DbType()
if dbTyp == DBTYPE_POSTGRES { if dbTyp == DBTYPE_POSTGRES {
newDb = postgres.NewPgDb() // TODO: move to vise
err = ensureSchemaExists(ctx, ms.conn)
if err != nil {
return nil, err
}
newDb = postgres.NewPgDb().WithSchema(ms.conn.Domain())
} else if dbTyp == DBTYPE_GDBM { } else if dbTyp == DBTYPE_GDBM {
err = ms.ensureDbDir() err = ms.ensureDbDir()
if err != nil { if err != nil {
@ -66,7 +71,7 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D
} else { } else {
return nil, fmt.Errorf("unsupported connection string: '%s'\n", ms.conn.String()) return nil, fmt.Errorf("unsupported connection string: '%s'\n", ms.conn.String())
} }
logg.DebugCtxf(ctx, "connecting to db", "conn", connStr) logg.DebugCtxf(ctx, "connecting to db", "conn", connStr, "conndata", ms.conn)
err = newDb.Connect(ctx, connStr) err = newDb.Connect(ctx, connStr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -98,15 +103,15 @@ func (ms *MenuStorageService) WithGettext(path string, lns []lang.Language) *Men
} }
// ensureSchemaExists creates a new schema if it does not exist // ensureSchemaExists creates a new schema if it does not exist
func ensureSchemaExists(ctx context.Context, connStr, schema string) error { func ensureSchemaExists(ctx context.Context, conn ConnData) error {
conn, err := pgxpool.New(ctx, connStr) h, err := pgxpool.New(ctx, conn.Path())
if err != nil { if err != nil {
return fmt.Errorf("failed to connect to the database: %w", err) return fmt.Errorf("failed to connect to the database: %w", err)
} }
defer conn.Close() defer h.Close()
query := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema) query := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", conn.Domain())
_, err = conn.Exec(ctx, query) _, err = h.Exec(ctx, query)
if err != nil { if err != nil {
return fmt.Errorf("failed to create schema: %w", err) return fmt.Errorf("failed to create schema: %w", err)
} }