postgres-switch-for-tests #255
| @ -15,6 +15,7 @@ const ( | |||||||
| type ConnData struct { | type ConnData struct { | ||||||
| 	typ int | 	typ int | ||||||
| 	str string | 	str string | ||||||
|  | 	domain string | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (cd *ConnData) DbType() int { | func (cd *ConnData) DbType() int { | ||||||
| @ -25,23 +26,38 @@ func (cd *ConnData) String() string { | |||||||
| 	return cd.str | 	return cd.str | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func probePostgres(s string) (string, bool) { | func (cd *ConnData) Domain() string { | ||||||
| 	v, err := url.Parse(s) | 	return cd.domain | ||||||
| 	if err != nil { |  | ||||||
| 		return "", false |  | ||||||
| 	} |  | ||||||
| 	if v.Scheme != "postgres" { |  | ||||||
| 		return "", false |  | ||||||
| 	} |  | ||||||
| 	return s, true |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func probeGdbm(s string) (string, bool) { | func (cd *ConnData) Path() string { | ||||||
|  | 	v, _ := url.Parse(cd.str) | ||||||
|  | 	v.RawQuery = "" | ||||||
|  | 	return v.String() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func probePostgres(s string) (string, string, bool) { | ||||||
|  | 	domain := "public" | ||||||
|  | 	v, err := url.Parse(s) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", "", false | ||||||
|  | 	} | ||||||
|  | 	if v.Scheme != "postgres" { | ||||||
|  | 		return "", "", false | ||||||
|  | 	} | ||||||
|  | 	vv := v.Query() | ||||||
|  | 	if vv.Has("search_path") { | ||||||
|  | 		domain = vv.Get("search_path") | ||||||
|  | 	} | ||||||
|  | 	return s, domain, true | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func probeGdbm(s string) (string, string, bool) { | ||||||
| 	if !path.IsAbs(s) { | 	if !path.IsAbs(s) { | ||||||
| 		return "", false | 		return "", "", false | ||||||
| 	} | 	} | ||||||
| 	s = path.Clean(s) | 	s = path.Clean(s) | ||||||
| 	return s, true | 	return s, "", true | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func ToConnData(connStr string) (ConnData, error) { | func ToConnData(connStr string) (ConnData, error) { | ||||||
| @ -51,14 +67,15 @@ func ToConnData(connStr string) (ConnData, error) { | |||||||
| 		return o, nil | 		return o, nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	v, ok := probePostgres(connStr) | 	v, domain, ok := probePostgres(connStr) | ||||||
| 	if ok { | 	if ok { | ||||||
| 		o.typ = DBTYPE_POSTGRES | 		o.typ = DBTYPE_POSTGRES | ||||||
| 		o.str = v | 		o.str = v | ||||||
|  | 		o.domain = domain | ||||||
| 		return o, nil | 		return o, nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	v, ok = probeGdbm(connStr) | 	v, _, ok = probeGdbm(connStr) | ||||||
| 	if ok { | 	if ok { | ||||||
| 		o.typ = DBTYPE_GDBM | 		o.typ = DBTYPE_GDBM | ||||||
| 		o.str = v | 		o.str = v | ||||||
|  | |||||||
| @ -55,7 +55,12 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D | |||||||
| 	connStr := ms.conn.String() | 	connStr := ms.conn.String() | ||||||
| 	dbTyp := ms.conn.DbType() | 	dbTyp := ms.conn.DbType() | ||||||
| 	if dbTyp == DBTYPE_POSTGRES { | 	if dbTyp == DBTYPE_POSTGRES { | ||||||
| 		newDb = postgres.NewPgDb() | 		// TODO: move to vise
 | ||||||
|  | 		err = ensureSchemaExists(ctx, ms.conn) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 		newDb = postgres.NewPgDb().WithSchema(ms.conn.Domain()) | ||||||
| 	} else if dbTyp == DBTYPE_GDBM { | 	} else if dbTyp == DBTYPE_GDBM { | ||||||
| 		err = ms.ensureDbDir() | 		err = ms.ensureDbDir() | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| @ -66,7 +71,7 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D | |||||||
| 	} else { | 	} else { | ||||||
| 		return nil, fmt.Errorf("unsupported connection string: '%s'\n", ms.conn.String()) | 		return nil, fmt.Errorf("unsupported connection string: '%s'\n", ms.conn.String()) | ||||||
| 	} | 	} | ||||||
| 	logg.DebugCtxf(ctx, "connecting to db", "conn", connStr) | 	logg.DebugCtxf(ctx, "connecting to db", "conn", connStr, "conndata", ms.conn) | ||||||
| 	err = newDb.Connect(ctx, connStr) | 	err = newDb.Connect(ctx, connStr) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @ -98,15 +103,15 @@ func (ms *MenuStorageService) WithGettext(path string, lns []lang.Language) *Men | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ensureSchemaExists creates a new schema if it does not exist
 | // ensureSchemaExists creates a new schema if it does not exist
 | ||||||
| func ensureSchemaExists(ctx context.Context, connStr, schema string) error { | func ensureSchemaExists(ctx context.Context, conn ConnData) error { | ||||||
| 	conn, err := pgxpool.New(ctx, connStr) | 	h, err := pgxpool.New(ctx, conn.Path()) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 
					
					lash marked this conversation as resolved
					
				 | |||||||
| 		return fmt.Errorf("failed to connect to the database: %w", err) | 		return fmt.Errorf("failed to connect to the database: %w", err) | ||||||
| 	} | 	} | ||||||
| 	defer conn.Close() | 	defer h.Close() | ||||||
| 
 | 
 | ||||||
| 	query := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema) | 	query := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", conn.Domain()) | ||||||
| 	_, err = conn.Exec(ctx, query) | 	_, err = h.Exec(ctx, query) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return fmt.Errorf("failed to create schema: %w", err) | 		return fmt.Errorf("failed to create schema: %w", err) | ||||||
| 	} | 	} | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	
If we are using connection string, then this needs to be url parsed to get the query string in order to determine the schema.