postgres-switch-for-tests #255
| @ -15,6 +15,7 @@ const ( | ||||
| type ConnData struct { | ||||
| 	typ int | ||||
| 	str string | ||||
| 	domain string | ||||
| } | ||||
| 
 | ||||
| func (cd *ConnData) DbType() int { | ||||
| @ -25,23 +26,38 @@ func (cd *ConnData) String() string { | ||||
| 	return cd.str | ||||
| } | ||||
| 
 | ||||
| func probePostgres(s string) (string, bool) { | ||||
| 	v, err := url.Parse(s) | ||||
| 	if err != nil { | ||||
| 		return "", false | ||||
| 	} | ||||
| 	if v.Scheme != "postgres" { | ||||
| 		return "", false | ||||
| 	} | ||||
| 	return s, true | ||||
| func (cd *ConnData) Domain() string { | ||||
| 	return cd.domain | ||||
| } | ||||
| 
 | ||||
| func probeGdbm(s string) (string, bool) { | ||||
| func (cd *ConnData) Path() string { | ||||
| 	v, _ := url.Parse(cd.str) | ||||
| 	v.RawQuery = "" | ||||
| 	return v.String() | ||||
| } | ||||
| 
 | ||||
| func probePostgres(s string) (string, string, bool) { | ||||
| 	domain := "public" | ||||
| 	v, err := url.Parse(s) | ||||
| 	if err != nil { | ||||
| 		return "", "", false | ||||
| 	} | ||||
| 	if v.Scheme != "postgres" { | ||||
| 		return "", "", false | ||||
| 	} | ||||
| 	vv := v.Query() | ||||
| 	if vv.Has("search_path") { | ||||
| 		domain = vv.Get("search_path") | ||||
| 	} | ||||
| 	return s, domain, true | ||||
| } | ||||
| 
 | ||||
| func probeGdbm(s string) (string, string, bool) { | ||||
| 	if !path.IsAbs(s) { | ||||
| 		return "", false | ||||
| 		return "", "", false | ||||
| 	} | ||||
| 	s = path.Clean(s) | ||||
| 	return s, true | ||||
| 	return s, "", true | ||||
| } | ||||
| 
 | ||||
| func ToConnData(connStr string) (ConnData, error) { | ||||
| @ -51,14 +67,15 @@ func ToConnData(connStr string) (ConnData, error) { | ||||
| 		return o, nil | ||||
| 	} | ||||
| 
 | ||||
| 	v, ok := probePostgres(connStr) | ||||
| 	v, domain, ok := probePostgres(connStr) | ||||
| 	if ok { | ||||
| 		o.typ = DBTYPE_POSTGRES | ||||
| 		o.str = v | ||||
| 		o.domain = domain | ||||
| 		return o, nil | ||||
| 	} | ||||
| 
 | ||||
| 	v, ok = probeGdbm(connStr) | ||||
| 	v, _, ok = probeGdbm(connStr) | ||||
| 	if ok { | ||||
| 		o.typ = DBTYPE_GDBM | ||||
| 		o.str = v | ||||
|  | ||||
| @ -55,7 +55,12 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D | ||||
| 	connStr := ms.conn.String() | ||||
| 	dbTyp := ms.conn.DbType() | ||||
| 	if dbTyp == DBTYPE_POSTGRES { | ||||
| 		newDb = postgres.NewPgDb() | ||||
| 		// TODO: move to vise
 | ||||
| 		err = ensureSchemaExists(ctx, ms.conn) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		newDb = postgres.NewPgDb().WithSchema(ms.conn.Domain()) | ||||
| 	} else if dbTyp == DBTYPE_GDBM { | ||||
| 		err = ms.ensureDbDir() | ||||
| 		if err != nil { | ||||
| @ -66,7 +71,7 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D | ||||
| 	} else { | ||||
| 		return nil, fmt.Errorf("unsupported connection string: '%s'\n", ms.conn.String()) | ||||
| 	} | ||||
| 	logg.DebugCtxf(ctx, "connecting to db", "conn", connStr) | ||||
| 	logg.DebugCtxf(ctx, "connecting to db", "conn", connStr, "conndata", ms.conn) | ||||
| 	err = newDb.Connect(ctx, connStr) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| @ -98,15 +103,15 @@ func (ms *MenuStorageService) WithGettext(path string, lns []lang.Language) *Men | ||||
| } | ||||
| 
 | ||||
| // ensureSchemaExists creates a new schema if it does not exist
 | ||||
| func ensureSchemaExists(ctx context.Context, connStr, schema string) error { | ||||
| 	conn, err := pgxpool.New(ctx, connStr) | ||||
| func ensureSchemaExists(ctx context.Context, conn ConnData) error { | ||||
| 	h, err := pgxpool.New(ctx, conn.Path()) | ||||
| 	if err != nil { | ||||
| 
					
					lash marked this conversation as resolved
					
				 | ||||
| 		return fmt.Errorf("failed to connect to the database: %w", err) | ||||
| 	} | ||||
| 	defer conn.Close() | ||||
| 	defer h.Close() | ||||
| 
 | ||||
| 	query := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema) | ||||
| 	_, err = conn.Exec(ctx, query) | ||||
| 	query := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", conn.Domain()) | ||||
| 	_, err = h.Exec(ctx, query) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("failed to create schema: %w", err) | ||||
| 	} | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	
If we are using connection string, then this needs to be url parsed to get the query string in order to determine the schema.