From a47e44f262f1d83593cabb69ccbc20110b537f44 Mon Sep 17 00:00:00 2001 From: Mohammed Sohail Date: Thu, 2 Mar 2023 15:46:02 +0000 Subject: [PATCH] refactor: use sigChan for shutdown, ctx fixes * refactor main entry point for starting services * minor fixes around ctx propagation * improve otx marker js subscriber --- cmd/service/api.go | 23 +++---- cmd/service/custodial.go | 12 ++-- cmd/service/init.go | 65 ++++++++++--------- cmd/service/main.go | 102 ++++++++++------------------- cmd/service/tasker.go | 2 - cmd/service/utils.go | 31 +++++++++ internal/api/sign.go | 2 +- internal/api/types.go | 6 ++ internal/events/events.go | 7 -- internal/events/jetstream.go | 106 +++++++++++++++++++++++++++---- internal/events/jetstream_sub.go | 73 --------------------- internal/store/otx.go | 1 + pkg/postgres/pool.go | 6 +- pkg/redis/pool.go | 4 +- 14 files changed, 223 insertions(+), 217 deletions(-) create mode 100644 cmd/service/utils.go delete mode 100644 internal/events/events.go delete mode 100644 internal/events/jetstream_sub.go diff --git a/cmd/service/api.go b/cmd/service/api.go index b6a9773..08122bc 100644 --- a/cmd/service/api.go +++ b/cmd/service/api.go @@ -13,7 +13,7 @@ import ( ) const ( - contextTimeout = 5 + contextTimeout = 5 * time.Second ) // Bootstrap API server. @@ -24,11 +24,9 @@ func initApiServer(custodialContainer *custodial.Custodial) *echo.Echo { server := echo.New() server.HideBanner = true server.HidePort = true - server.Validator = &api.Validator{ ValidatorProvider: customValidator, } - server.HTTPErrorHandler = customHTTPErrorHandler server.Use(func(next echo.HandlerFunc) echo.HandlerFunc { @@ -39,7 +37,7 @@ func initApiServer(custodialContainer *custodial.Custodial) *echo.Echo { }) server.Use(middleware.Recover()) server.Use(middleware.BodyLimit("1M")) - server.Use(middleware.ContextTimeout(time.Duration(contextTimeout * time.Second))) + server.Use(middleware.ContextTimeout(contextTimeout)) if ko.Bool("service.metrics") { server.GET("/metrics", func(c echo.Context) error { @@ -61,8 +59,7 @@ func customHTTPErrorHandler(err error, c echo.Context) { return } - he, ok := err.(*echo.HTTPError) - if ok { + if he, ok := err.(*echo.HTTPError); ok { var errorMsg string if m, ok := he.Message.(error); ok { @@ -75,12 +72,12 @@ func customHTTPErrorHandler(err error, c echo.Context) { Ok: false, Message: errorMsg, }) - } else { - lo.Error("api: echo error", "path", c.Path(), "err", err) - - c.JSON(http.StatusInternalServerError, api.ErrResp{ - Ok: false, - Message: "Internal server error.", - }) + return } + + lo.Error("api: echo error", "path", c.Path(), "err", err) + c.JSON(http.StatusInternalServerError, api.ErrResp{ + Ok: false, + Message: "Internal server error.", + }) } diff --git a/cmd/service/custodial.go b/cmd/service/custodial.go index c27c934..08e2787 100644 --- a/cmd/service/custodial.go +++ b/cmd/service/custodial.go @@ -31,7 +31,10 @@ func initAbis() map[string]*w3.Func { // Bootstrap the internal custodial system configs and system signer key. // This container is passed down to individual tasker and API handlers. -func initSystemContainer(ctx context.Context, noncestore nonce.Noncestore) (*tasker.SystemContainer, error) { +func initSystemContainer(ctx context.Context, noncestore nonce.Noncestore) *tasker.SystemContainer { + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + // Some custodial system defaults loaded from the config file. systemContainer := &tasker.SystemContainer{ Abis: initAbis(), @@ -48,6 +51,7 @@ func initSystemContainer(ctx context.Context, noncestore nonce.Noncestore) (*tas TokenDecimals: ko.MustInt("system.token_decimals"), TokenTransferGasLimit: uint64(ko.MustInt64("system.token_transfer_gas_limit")), } + // Check if system signer account nonce is present. // If not (first boot), we bootstrap it from the network. currentSystemNonce, err := noncestore.Peek(ctx, ko.MustString("system.public_key")) @@ -56,15 +60,15 @@ func initSystemContainer(ctx context.Context, noncestore nonce.Noncestore) (*tas nonce, err := noncestore.SyncNetworkNonce(ctx, ko.MustString("system.public_key")) lo.Info("custodial: syncing system nonce", "nonce", nonce) if err != nil { - return nil, err + lo.Fatal("custodial: critical error bootstrapping system container", "error", err) } } loadedPrivateKey, err := eth_crypto.HexToECDSA(ko.MustString("system.private_key")) if err != nil { - return nil, err + lo.Fatal("custodial: critical error bootstrapping system container", "error", err) } systemContainer.PrivateKey = loadedPrivateKey - return systemContainer, nil + return systemContainer } diff --git a/cmd/service/init.go b/cmd/service/init.go index ed550e2..4e8ecdc 100644 --- a/cmd/service/init.go +++ b/cmd/service/init.go @@ -1,6 +1,7 @@ package main import ( + "context" "strings" "time" @@ -25,10 +26,10 @@ import ( ) // Load logger. -func initLogger(debug bool) logf.Logger { +func initLogger() logf.Logger { loggOpts := logg.LoggOpts{} - if debug { + if debugFlag { loggOpts.Color = true loggOpts.Caller = true loggOpts.Debug = true @@ -38,12 +39,12 @@ func initLogger(debug bool) logf.Logger { } // Load config file. -func initConfig(configFilePath string) *koanf.Koanf { +func initConfig() *koanf.Koanf { var ( ko = koanf.New(".") ) - confFile := file.Provider(configFilePath) + confFile := file.Provider(confFlag) if err := ko.Load(confFile, toml.Parser()); err != nil { lo.Fatal("Could not load config file", "error", err) } @@ -55,13 +56,15 @@ func initConfig(configFilePath string) *koanf.Koanf { lo.Fatal("Could not override config from env vars", "error", err) } - ko.Print() + if debugFlag { + ko.Print() + } return ko } // Load Celo chain provider. -func initCeloProvider() (*celoutils.Provider, error) { +func initCeloProvider() *celoutils.Provider { providerOpts := celoutils.ProviderOpts{ RpcEndpoint: ko.MustString("chain.rpc_endpoint"), } @@ -74,80 +77,80 @@ func initCeloProvider() (*celoutils.Provider, error) { provider, err := celoutils.NewProvider(providerOpts) if err != nil { - return nil, err + lo.Fatal("init: critical error loading chain provider", "error", err) } - return provider, nil + return provider } // Load postgres pool. -func initPostgresPool() (*pgxpool.Pool, error) { +func initPostgresPool() *pgxpool.Pool { poolOpts := postgres.PostgresPoolOpts{ DSN: ko.MustString("postgres.dsn"), MigrationsFolderPath: migrationsFolderFlag, } - pool, err := postgres.NewPostgresPool(poolOpts) + pool, err := postgres.NewPostgresPool(context.Background(), poolOpts) if err != nil { - return nil, err + lo.Fatal("init: critical error connecting to postgres", "error", err) } - return pool, nil + return pool } // Load separate redis connection for the tasker on a reserved db namespace. -func initAsynqRedisPool() (*redis.RedisPool, error) { +func initAsynqRedisPool() *redis.RedisPool { poolOpts := redis.RedisPoolOpts{ DSN: ko.MustString("asynq.dsn"), MinIdleConns: ko.MustInt("redis.min_idle_conn"), } - pool, err := redis.NewRedisPool(poolOpts) + pool, err := redis.NewRedisPool(context.Background(), poolOpts) if err != nil { - return nil, err + lo.Fatal("init: critical error connecting to asynq redis db", "error", err) } - return pool, nil + return pool } // Common redis connection on a different db namespace from the takser. -func initCommonRedisPool() (*redis.RedisPool, error) { +func initCommonRedisPool() *redis.RedisPool { poolOpts := redis.RedisPoolOpts{ DSN: ko.MustString("redis.dsn"), MinIdleConns: ko.MustInt("redis.min_idle_conn"), } - pool, err := redis.NewRedisPool(poolOpts) + pool, err := redis.NewRedisPool(context.Background(), poolOpts) if err != nil { - return nil, err + lo.Fatal("init: critical error connecting to common redis db", "error", err) } - return pool, nil + return pool } // Load SQL statements into struct. -func initQueries(queriesPath string) (*queries.Queries, error) { - parsedQueries, err := goyesql.ParseFile(queriesPath) +func initQueries() *queries.Queries { + parsedQueries, err := goyesql.ParseFile(queriesFlag) if err != nil { - return nil, err + lo.Fatal("init: critical error loading SQL queries", "error", err) } loadedQueries, err := queries.LoadQueries(parsedQueries) if err != nil { - return nil, err + lo.Fatal("init: critical error loading SQL queries", "error", err) } - return loadedQueries, nil + return loadedQueries } // Load postgres based keystore. -func initPostgresKeystore(postgresPool *pgxpool.Pool, queries *queries.Queries) (keystore.Keystore, error) { +func initPostgresKeystore(postgresPool *pgxpool.Pool, queries *queries.Queries) keystore.Keystore { keystore := keystore.NewPostgresKeytore(keystore.Opts{ PostgresPool: postgresPool, Queries: queries, }) - return keystore, nil + return keystore } // Load redis backed noncestore. @@ -180,17 +183,19 @@ func initPostgresStore(postgresPool *pgxpool.Pool, queries *queries.Queries) sto } // Init JetStream context for tasker events. -func initJetStream() (*events.JetStream, error) { +func initJetStream(pgStore store.Store) *events.JetStream { jsEmitter, err := events.NewJetStreamEventEmitter(events.JetStreamOpts{ Logg: lo, + PgStore: pgStore, ServerUrl: ko.MustString("jetstream.endpoint"), PersistDuration: time.Duration(ko.MustInt("jetstream.persist_duration_hrs")) * time.Hour, DedupDuration: time.Duration(ko.MustInt("jetstream.dedup_duration_hrs")) * time.Hour, }) if err != nil { - return nil, err + lo.Fatal("main: critical error loading jetstream event emitter") + } - return jsEmitter, nil + return jsEmitter } diff --git a/cmd/service/main.go b/cmd/service/main.go index ee586b6..1f180b6 100644 --- a/cmd/service/main.go +++ b/cmd/service/main.go @@ -3,19 +3,25 @@ package main import ( "context" "flag" - "os" - "os/signal" "strings" "sync" - "syscall" "github.com/grassrootseconomics/cic-custodial/internal/custodial" + "github.com/grassrootseconomics/cic-custodial/internal/events" "github.com/grassrootseconomics/cic-custodial/internal/tasker" "github.com/knadh/koanf/v2" "github.com/labstack/echo/v4" "github.com/zerodha/logf" ) +type ( + internalServiceContainer struct { + apiService *echo.Echo + jetstreamSub *events.JetStream + taskerService *tasker.TaskerServer + } +) + var ( confFlag string debugFlag bool @@ -33,63 +39,25 @@ func init() { flag.StringVar(&queriesFlag, "queries", "queries.sql", "Queries file location") flag.Parse() - lo = initLogger(debugFlag) - ko = initConfig(confFlag) + lo = initLogger() + ko = initConfig() } func main() { - var ( - tasker *tasker.TaskerServer - apiServer *echo.Echo - ) + parsedQueries := initQueries() + celoProvider := initCeloProvider() + postgresPool := initPostgresPool() + asynqRedisPool := initAsynqRedisPool() + redisPool := initCommonRedisPool() - ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) - defer stop() - - queries, err := initQueries(queriesFlag) - if err != nil { - lo.Fatal("main: critical error loading SQL queries", "error", err) - } - - celoProvider, err := initCeloProvider() - if err != nil { - lo.Fatal("main: critical error loading chain provider", "error", err) - } - - postgresPool, err := initPostgresPool() - if err != nil { - lo.Fatal("main: critical error connecting to postgres", "error", err) - } - - asynqRedisPool, err := initAsynqRedisPool() - if err != nil { - lo.Fatal("main: critical error connecting to asynq redis db", "error", err) - } - - redisPool, err := initCommonRedisPool() - if err != nil { - lo.Fatal("main: critical error connecting to common redis db", "error", err) - } - - postgresKeystore, err := initPostgresKeystore(postgresPool, queries) - if err != nil { - lo.Fatal("main: critical error loading keystore") - } - - jsEventEmitter, err := initJetStream() - if err != nil { - lo.Fatal("main: critical error loading jetstream event emitter") - } - - pgStore := initPostgresStore(postgresPool, queries) + postgresKeystore := initPostgresKeystore(postgresPool, parsedQueries) + pgStore := initPostgresStore(postgresPool, parsedQueries) redisNoncestore := initRedisNoncestore(redisPool, celoProvider) lockProvider := initLockProvider(redisPool.Client) taskerClient := initTaskerClient(asynqRedisPool) + systemContainer := initSystemContainer(context.Background(), redisNoncestore) - systemContainer, err := initSystemContainer(context.Background(), redisNoncestore) - if err != nil { - lo.Fatal("main: critical error bootstrapping system container", "error", err) - } + jsEventEmitter := initJetStream(pgStore) custodial := &custodial.Custodial{ CeloProvider: celoProvider, @@ -102,14 +70,18 @@ func main() { TaskerClient: taskerClient, } + internalServices := &internalServiceContainer{} wg := &sync.WaitGroup{} - apiServer = initApiServer(custodial) + signalCh, closeCh := createSigChannel() + defer closeCh() + + internalServices.apiService = initApiServer(custodial) wg.Add(1) go func() { defer wg.Done() lo.Info("main: starting API server") - if err := apiServer.Start(ko.MustString("service.address")); err != nil { + if err := internalServices.apiService.Start(ko.MustString("service.address")); err != nil { if strings.Contains(err.Error(), "Server closed") { lo.Info("main: shutting down server") } else { @@ -118,34 +90,28 @@ func main() { } }() - tasker = initTasker(custodial, asynqRedisPool) + internalServices.taskerService = initTasker(custodial, asynqRedisPool) wg.Add(1) go func() { defer wg.Done() lo.Info("Starting tasker") - if err := tasker.Start(); err != nil { + if err := internalServices.taskerService.Start(); err != nil { lo.Fatal("main: could not start task server", "err", err) } }() + internalServices.jetstreamSub = jsEventEmitter wg.Add(1) go func() { defer wg.Done() - lo.Info("Starting jetstream subscriber") - if err := jsEventEmitter.ChainSubscription(ctx, pgStore); err != nil { - lo.Fatal("main: jetstream subscriber", "err", err) + lo.Info("Starting jetstream sub") + if err := internalServices.jetstreamSub.Subscriber(); err != nil { + lo.Fatal("main: error running jetstream sub", "err", err) } }() - <-ctx.Done() - - lo.Info("main: stopping tasker") - tasker.Stop() - - lo.Info("main: stopping api server") - if err := apiServer.Shutdown(ctx); err != nil { - lo.Error("Could not gracefully shutdown api server", "err", err) - } + <-signalCh + startGracefulShutdown(context.Background(), internalServices) wg.Wait() } diff --git a/cmd/service/tasker.go b/cmd/service/tasker.go index c8236b1..25b29de 100644 --- a/cmd/service/tasker.go +++ b/cmd/service/tasker.go @@ -10,8 +10,6 @@ import ( // Load tasker handlers, injecting any necessary handler dependencies from the system container. func initTasker(custodialContainer *custodial.Custodial, redisPool *redis.RedisPool) *tasker.TaskerServer { - lo.Debug("Bootstrapping tasker") - taskerServerOpts := tasker.TaskerServerOpts{ Concurrency: ko.MustInt("asynq.worker_count"), Logg: lo, diff --git a/cmd/service/utils.go b/cmd/service/utils.go new file mode 100644 index 0000000..8cac34e --- /dev/null +++ b/cmd/service/utils.go @@ -0,0 +1,31 @@ +package main + +import ( + "context" + "os" + "os/signal" + "syscall" + "time" +) + +func createSigChannel() (chan os.Signal, func()) { + signalCh := make(chan os.Signal, 1) + signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM, syscall.SIGINT) + + return signalCh, func() { + close(signalCh) + } +} + +func startGracefulShutdown(ctx context.Context, internalServices *internalServiceContainer) { + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + internalServices.jetstreamSub.Close() + + if err := internalServices.apiService.Shutdown(ctx); err != nil { + lo.Fatal("Could not gracefully shutdown api server", "err", err) + } + + internalServices.taskerService.Stop() +} diff --git a/internal/api/sign.go b/internal/api/sign.go index ddb41d5..7f12b0a 100644 --- a/internal/api/sign.go +++ b/internal/api/sign.go @@ -33,7 +33,7 @@ func HandleSignTransfer(c echo.Context) error { ) if err := c.Bind(&req); err != nil { - return NewBadRequestError(err) + return NewBadRequestError(ErrInvalidJSON) } if err := c.Validate(req); err != nil { diff --git a/internal/api/types.go b/internal/api/types.go index 6d59b1a..011e2d8 100644 --- a/internal/api/types.go +++ b/internal/api/types.go @@ -1,5 +1,11 @@ package api +import "errors" + +var ( + ErrInvalidJSON = errors.New("Invalid JSON structure.") +) + type H map[string]any type OkResp struct { diff --git a/internal/events/events.go b/internal/events/events.go deleted file mode 100644 index e51f73c..0000000 --- a/internal/events/events.go +++ /dev/null @@ -1,7 +0,0 @@ -package events - -type EventPayload struct { - OtxId uint `json:"otxId"` - TrackingId string `json:"trackingId"` - TxHash string `json:"txHash"` -} diff --git a/internal/events/jetstream.go b/internal/events/jetstream.go index e6afbad..c28dc3a 100644 --- a/internal/events/jetstream.go +++ b/internal/events/jetstream.go @@ -1,17 +1,20 @@ package events import ( + "context" "encoding/json" + "errors" "time" + "github.com/grassrootseconomics/cic-custodial/internal/store" "github.com/nats-io/nats.go" "github.com/zerodha/logf" ) const ( - StreamName string = "CUSTODIAL" - StreamSubjects string = "CUSTODIAL.*" - // Subjects + // Pub + StreamName string = "CUSTODIAL" + StreamSubjects string = "CUSTODIAL.*" AccountNewNonce string = "CUSTODIAL.accountNewNonce" AccountRegister string = "CUSTODIAL.accountRegister" AccountGiftGas string = "CUSTODIAL.systemNewAccountGas" @@ -20,20 +23,36 @@ const ( DispatchFail string = "CUSTODIAL.dispatchFail" DispatchSuccess string = "CUSTODIAL.dispatchSuccess" SignTransfer string = "CUSTODIAL.signTransfer" + + // Sub + durableId = "cic-custodial" + pullStream = "CHAIN" + pullSubject = "CHAIN.*" + actionTimeout = 5 * time.Second ) -type JetStreamOpts struct { - Logg logf.Logger - ServerUrl string - PersistDuration time.Duration - DedupDuration time.Duration -} +type ( + JetStreamOpts struct { + Logg logf.Logger + ServerUrl string + PersistDuration time.Duration + PgStore store.Store + DedupDuration time.Duration + } -type JetStream struct { - logg logf.Logger - jsCtx nats.JetStreamContext - natsConn *nats.Conn -} + JetStream struct { + logg logf.Logger + jsCtx nats.JetStreamContext + pgStore store.Store + natsConn *nats.Conn + } + + EventPayload struct { + OtxId uint `json:"otxId"` + TrackingId string `json:"trackingId"` + TxHash string `json:"txHash"` + } +) func NewJetStreamEventEmitter(o JetStreamOpts) (*JetStream, error) { natsConn, err := nats.Connect(o.ServerUrl) @@ -61,9 +80,20 @@ func NewJetStreamEventEmitter(o JetStreamOpts) (*JetStream, error) { } } + // Add a durable consumer + _, err = js.AddConsumer(pullStream, &nats.ConsumerConfig{ + Durable: durableId, + AckPolicy: nats.AckExplicitPolicy, + FilterSubject: pullSubject, + }) + if err != nil { + return nil, err + } + return &JetStream{ logg: o.Logg, jsCtx: js, + pgStore: o.PgStore, natsConn: natsConn, }, nil } @@ -89,3 +119,51 @@ func (js *JetStream) Publish(subject string, dedupId string, eventPayload interf return nil } + +func (js *JetStream) Subscriber() error { + subOpts := []nats.SubOpt{ + nats.ManualAck(), + nats.Bind(pullStream, durableId), + } + + natsSub, err := js.jsCtx.PullSubscribe(pullSubject, durableId, subOpts...) + if err != nil { + return err + } + + for { + events, err := natsSub.Fetch(1) + if err != nil { + if errors.Is(err, nats.ErrTimeout) { + continue + } else if errors.Is(err, nats.ErrConnectionClosed) { + return nil + } else { + return err + } + } + if len(events) > 0 { + var ( + chainEvent store.MinimalTxInfo + + msg = events[0] + ) + + if err := json.Unmarshal(msg.Data, &chainEvent); err != nil { + msg.Nak() + js.logg.Error("jetstream sub: json unmarshal fail", "error", err) + } else { + ctx, cancel := context.WithTimeout(context.Background(), actionTimeout) + + if err := js.pgStore.UpdateOtxStatusFromChainEvent(ctx, chainEvent); err != nil { + msg.Nak() + js.logg.Error("jetstream sub: otx marker failed to update state", "error", err) + } else { + msg.Ack() + } + cancel() + } + + } + } +} diff --git a/internal/events/jetstream_sub.go b/internal/events/jetstream_sub.go deleted file mode 100644 index 3af5405..0000000 --- a/internal/events/jetstream_sub.go +++ /dev/null @@ -1,73 +0,0 @@ -package events - -import ( - "context" - "encoding/json" - "errors" - - "github.com/grassrootseconomics/cic-custodial/internal/store" - "github.com/nats-io/nats.go" -) - -const ( - durableId = "cic-custodial" - pullStream = "CHAIN" - pullSubject = "CHAIN.*" -) - -func (js *JetStream) ChainSubscription(ctx context.Context, pgStore store.Store) error { - _, err := js.jsCtx.AddConsumer(pullStream, &nats.ConsumerConfig{ - Durable: durableId, - AckPolicy: nats.AckExplicitPolicy, - FilterSubject: pullSubject, - }) - if err != nil { - return err - } - - subOpts := []nats.SubOpt{ - nats.ManualAck(), - nats.Bind(pullStream, durableId), - } - - natsSub, err := js.jsCtx.PullSubscribe(pullSubject, durableId, subOpts...) - if err != nil { - return err - } - - for { - select { - case <-ctx.Done(): - js.logg.Info("jetstream chain sub: shutdown signal received") - js.Close() - return nil - default: - events, err := natsSub.Fetch(1) - if err != nil { - if errors.Is(err, nats.ErrTimeout) { - continue - } else { - js.logg.Error("jetstream chain sub: fetch other error", "error", err) - } - } - if len(events) == 0 { - continue - } - var ( - chainEvent store.MinimalTxInfo - ) - - if err := json.Unmarshal(events[0].Data, &chainEvent); err != nil { - js.logg.Error("jetstream chain sub: json unmarshal fail", "error", err) - } - - if err := pgStore.UpdateOtxStatusFromChainEvent(context.Background(), chainEvent); err != nil { - events[0].Nak() - js.logg.Error("jetstream chain sub: otx marker failed to update state", "error", err) - } - events[0].Ack() - js.logg.Debug("jetstream chain sub: successfully updated status", "tx", chainEvent.TxHash) - } - - } -} diff --git a/internal/store/otx.go b/internal/store/otx.go index 00e0b2a..9d07b19 100644 --- a/internal/store/otx.go +++ b/internal/store/otx.go @@ -60,6 +60,7 @@ func (s *PostgresStore) GetTxStatusByTrackingId(ctx context.Context, trackingId } func (s *PostgresStore) UpdateOtxStatusFromChainEvent(ctx context.Context, chainEvent MinimalTxInfo) error { + var ( status = enum.SUCCESS ) diff --git a/pkg/postgres/pool.go b/pkg/postgres/pool.go index 3915644..d9d05ee 100644 --- a/pkg/postgres/pool.go +++ b/pkg/postgres/pool.go @@ -19,18 +19,18 @@ type PostgresPoolOpts struct { } // NewPostgresPool creates a reusbale connection pool across the cic-custodial component. -func NewPostgresPool(o PostgresPoolOpts) (*pgxpool.Pool, error) { +func NewPostgresPool(ctx context.Context, o PostgresPoolOpts) (*pgxpool.Pool, error) { parsedConfig, err := pgxpool.ParseConfig(o.DSN) if err != nil { return nil, err } - dbPool, err := pgxpool.NewWithConfig(context.Background(), parsedConfig) + dbPool, err := pgxpool.NewWithConfig(ctx, parsedConfig) if err != nil { return nil, err } - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() conn, err := dbPool.Acquire(ctx) diff --git a/pkg/redis/pool.go b/pkg/redis/pool.go index f65a9d7..8797bf4 100644 --- a/pkg/redis/pool.go +++ b/pkg/redis/pool.go @@ -18,7 +18,7 @@ type RedisPool struct { // NewRedisPool creates a reusable connection across the cic-custodial componenent. // Note: Each db namespace requires its own connection pool. -func NewRedisPool(o RedisPoolOpts) (*RedisPool, error) { +func NewRedisPool(ctx context.Context, o RedisPoolOpts) (*RedisPool, error) { redisOpts, err := redis.ParseURL(o.DSN) if err != nil { return nil, err @@ -28,7 +28,7 @@ func NewRedisPool(o RedisPoolOpts) (*RedisPool, error) { redisClient := redis.NewClient(redisOpts) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() _, err = redisClient.Ping(ctx).Result()