refactor: use sigChan for shutdown, ctx fixes

* refactor main entry point for starting services
* minor fixes around ctx propagation
* improve otx marker js subscriber
This commit is contained in:
Mohamed Sohail 2023-03-02 15:46:02 +00:00
parent a1b6cb08d8
commit a47e44f262
Signed by: kamikazechaser
GPG Key ID: 7DD45520C01CD85D
14 changed files with 223 additions and 217 deletions

View File

@ -13,7 +13,7 @@ import (
)
const (
contextTimeout = 5
contextTimeout = 5 * time.Second
)
// Bootstrap API server.
@ -24,11 +24,9 @@ func initApiServer(custodialContainer *custodial.Custodial) *echo.Echo {
server := echo.New()
server.HideBanner = true
server.HidePort = true
server.Validator = &api.Validator{
ValidatorProvider: customValidator,
}
server.HTTPErrorHandler = customHTTPErrorHandler
server.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
@ -39,7 +37,7 @@ func initApiServer(custodialContainer *custodial.Custodial) *echo.Echo {
})
server.Use(middleware.Recover())
server.Use(middleware.BodyLimit("1M"))
server.Use(middleware.ContextTimeout(time.Duration(contextTimeout * time.Second)))
server.Use(middleware.ContextTimeout(contextTimeout))
if ko.Bool("service.metrics") {
server.GET("/metrics", func(c echo.Context) error {
@ -61,8 +59,7 @@ func customHTTPErrorHandler(err error, c echo.Context) {
return
}
he, ok := err.(*echo.HTTPError)
if ok {
if he, ok := err.(*echo.HTTPError); ok {
var errorMsg string
if m, ok := he.Message.(error); ok {
@ -75,12 +72,12 @@ func customHTTPErrorHandler(err error, c echo.Context) {
Ok: false,
Message: errorMsg,
})
} else {
lo.Error("api: echo error", "path", c.Path(), "err", err)
c.JSON(http.StatusInternalServerError, api.ErrResp{
Ok: false,
Message: "Internal server error.",
})
return
}
lo.Error("api: echo error", "path", c.Path(), "err", err)
c.JSON(http.StatusInternalServerError, api.ErrResp{
Ok: false,
Message: "Internal server error.",
})
}

View File

@ -31,7 +31,10 @@ func initAbis() map[string]*w3.Func {
// Bootstrap the internal custodial system configs and system signer key.
// This container is passed down to individual tasker and API handlers.
func initSystemContainer(ctx context.Context, noncestore nonce.Noncestore) (*tasker.SystemContainer, error) {
func initSystemContainer(ctx context.Context, noncestore nonce.Noncestore) *tasker.SystemContainer {
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
// Some custodial system defaults loaded from the config file.
systemContainer := &tasker.SystemContainer{
Abis: initAbis(),
@ -48,6 +51,7 @@ func initSystemContainer(ctx context.Context, noncestore nonce.Noncestore) (*tas
TokenDecimals: ko.MustInt("system.token_decimals"),
TokenTransferGasLimit: uint64(ko.MustInt64("system.token_transfer_gas_limit")),
}
// Check if system signer account nonce is present.
// If not (first boot), we bootstrap it from the network.
currentSystemNonce, err := noncestore.Peek(ctx, ko.MustString("system.public_key"))
@ -56,15 +60,15 @@ func initSystemContainer(ctx context.Context, noncestore nonce.Noncestore) (*tas
nonce, err := noncestore.SyncNetworkNonce(ctx, ko.MustString("system.public_key"))
lo.Info("custodial: syncing system nonce", "nonce", nonce)
if err != nil {
return nil, err
lo.Fatal("custodial: critical error bootstrapping system container", "error", err)
}
}
loadedPrivateKey, err := eth_crypto.HexToECDSA(ko.MustString("system.private_key"))
if err != nil {
return nil, err
lo.Fatal("custodial: critical error bootstrapping system container", "error", err)
}
systemContainer.PrivateKey = loadedPrivateKey
return systemContainer, nil
return systemContainer
}

View File

@ -1,6 +1,7 @@
package main
import (
"context"
"strings"
"time"
@ -25,10 +26,10 @@ import (
)
// Load logger.
func initLogger(debug bool) logf.Logger {
func initLogger() logf.Logger {
loggOpts := logg.LoggOpts{}
if debug {
if debugFlag {
loggOpts.Color = true
loggOpts.Caller = true
loggOpts.Debug = true
@ -38,12 +39,12 @@ func initLogger(debug bool) logf.Logger {
}
// Load config file.
func initConfig(configFilePath string) *koanf.Koanf {
func initConfig() *koanf.Koanf {
var (
ko = koanf.New(".")
)
confFile := file.Provider(configFilePath)
confFile := file.Provider(confFlag)
if err := ko.Load(confFile, toml.Parser()); err != nil {
lo.Fatal("Could not load config file", "error", err)
}
@ -55,13 +56,15 @@ func initConfig(configFilePath string) *koanf.Koanf {
lo.Fatal("Could not override config from env vars", "error", err)
}
ko.Print()
if debugFlag {
ko.Print()
}
return ko
}
// Load Celo chain provider.
func initCeloProvider() (*celoutils.Provider, error) {
func initCeloProvider() *celoutils.Provider {
providerOpts := celoutils.ProviderOpts{
RpcEndpoint: ko.MustString("chain.rpc_endpoint"),
}
@ -74,80 +77,80 @@ func initCeloProvider() (*celoutils.Provider, error) {
provider, err := celoutils.NewProvider(providerOpts)
if err != nil {
return nil, err
lo.Fatal("init: critical error loading chain provider", "error", err)
}
return provider, nil
return provider
}
// Load postgres pool.
func initPostgresPool() (*pgxpool.Pool, error) {
func initPostgresPool() *pgxpool.Pool {
poolOpts := postgres.PostgresPoolOpts{
DSN: ko.MustString("postgres.dsn"),
MigrationsFolderPath: migrationsFolderFlag,
}
pool, err := postgres.NewPostgresPool(poolOpts)
pool, err := postgres.NewPostgresPool(context.Background(), poolOpts)
if err != nil {
return nil, err
lo.Fatal("init: critical error connecting to postgres", "error", err)
}
return pool, nil
return pool
}
// Load separate redis connection for the tasker on a reserved db namespace.
func initAsynqRedisPool() (*redis.RedisPool, error) {
func initAsynqRedisPool() *redis.RedisPool {
poolOpts := redis.RedisPoolOpts{
DSN: ko.MustString("asynq.dsn"),
MinIdleConns: ko.MustInt("redis.min_idle_conn"),
}
pool, err := redis.NewRedisPool(poolOpts)
pool, err := redis.NewRedisPool(context.Background(), poolOpts)
if err != nil {
return nil, err
lo.Fatal("init: critical error connecting to asynq redis db", "error", err)
}
return pool, nil
return pool
}
// Common redis connection on a different db namespace from the takser.
func initCommonRedisPool() (*redis.RedisPool, error) {
func initCommonRedisPool() *redis.RedisPool {
poolOpts := redis.RedisPoolOpts{
DSN: ko.MustString("redis.dsn"),
MinIdleConns: ko.MustInt("redis.min_idle_conn"),
}
pool, err := redis.NewRedisPool(poolOpts)
pool, err := redis.NewRedisPool(context.Background(), poolOpts)
if err != nil {
return nil, err
lo.Fatal("init: critical error connecting to common redis db", "error", err)
}
return pool, nil
return pool
}
// Load SQL statements into struct.
func initQueries(queriesPath string) (*queries.Queries, error) {
parsedQueries, err := goyesql.ParseFile(queriesPath)
func initQueries() *queries.Queries {
parsedQueries, err := goyesql.ParseFile(queriesFlag)
if err != nil {
return nil, err
lo.Fatal("init: critical error loading SQL queries", "error", err)
}
loadedQueries, err := queries.LoadQueries(parsedQueries)
if err != nil {
return nil, err
lo.Fatal("init: critical error loading SQL queries", "error", err)
}
return loadedQueries, nil
return loadedQueries
}
// Load postgres based keystore.
func initPostgresKeystore(postgresPool *pgxpool.Pool, queries *queries.Queries) (keystore.Keystore, error) {
func initPostgresKeystore(postgresPool *pgxpool.Pool, queries *queries.Queries) keystore.Keystore {
keystore := keystore.NewPostgresKeytore(keystore.Opts{
PostgresPool: postgresPool,
Queries: queries,
})
return keystore, nil
return keystore
}
// Load redis backed noncestore.
@ -180,17 +183,19 @@ func initPostgresStore(postgresPool *pgxpool.Pool, queries *queries.Queries) sto
}
// Init JetStream context for tasker events.
func initJetStream() (*events.JetStream, error) {
func initJetStream(pgStore store.Store) *events.JetStream {
jsEmitter, err := events.NewJetStreamEventEmitter(events.JetStreamOpts{
Logg: lo,
PgStore: pgStore,
ServerUrl: ko.MustString("jetstream.endpoint"),
PersistDuration: time.Duration(ko.MustInt("jetstream.persist_duration_hrs")) * time.Hour,
DedupDuration: time.Duration(ko.MustInt("jetstream.dedup_duration_hrs")) * time.Hour,
})
if err != nil {
return nil, err
lo.Fatal("main: critical error loading jetstream event emitter")
}
return jsEmitter, nil
return jsEmitter
}

View File

@ -3,19 +3,25 @@ package main
import (
"context"
"flag"
"os"
"os/signal"
"strings"
"sync"
"syscall"
"github.com/grassrootseconomics/cic-custodial/internal/custodial"
"github.com/grassrootseconomics/cic-custodial/internal/events"
"github.com/grassrootseconomics/cic-custodial/internal/tasker"
"github.com/knadh/koanf/v2"
"github.com/labstack/echo/v4"
"github.com/zerodha/logf"
)
type (
internalServiceContainer struct {
apiService *echo.Echo
jetstreamSub *events.JetStream
taskerService *tasker.TaskerServer
}
)
var (
confFlag string
debugFlag bool
@ -33,63 +39,25 @@ func init() {
flag.StringVar(&queriesFlag, "queries", "queries.sql", "Queries file location")
flag.Parse()
lo = initLogger(debugFlag)
ko = initConfig(confFlag)
lo = initLogger()
ko = initConfig()
}
func main() {
var (
tasker *tasker.TaskerServer
apiServer *echo.Echo
)
parsedQueries := initQueries()
celoProvider := initCeloProvider()
postgresPool := initPostgresPool()
asynqRedisPool := initAsynqRedisPool()
redisPool := initCommonRedisPool()
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
queries, err := initQueries(queriesFlag)
if err != nil {
lo.Fatal("main: critical error loading SQL queries", "error", err)
}
celoProvider, err := initCeloProvider()
if err != nil {
lo.Fatal("main: critical error loading chain provider", "error", err)
}
postgresPool, err := initPostgresPool()
if err != nil {
lo.Fatal("main: critical error connecting to postgres", "error", err)
}
asynqRedisPool, err := initAsynqRedisPool()
if err != nil {
lo.Fatal("main: critical error connecting to asynq redis db", "error", err)
}
redisPool, err := initCommonRedisPool()
if err != nil {
lo.Fatal("main: critical error connecting to common redis db", "error", err)
}
postgresKeystore, err := initPostgresKeystore(postgresPool, queries)
if err != nil {
lo.Fatal("main: critical error loading keystore")
}
jsEventEmitter, err := initJetStream()
if err != nil {
lo.Fatal("main: critical error loading jetstream event emitter")
}
pgStore := initPostgresStore(postgresPool, queries)
postgresKeystore := initPostgresKeystore(postgresPool, parsedQueries)
pgStore := initPostgresStore(postgresPool, parsedQueries)
redisNoncestore := initRedisNoncestore(redisPool, celoProvider)
lockProvider := initLockProvider(redisPool.Client)
taskerClient := initTaskerClient(asynqRedisPool)
systemContainer := initSystemContainer(context.Background(), redisNoncestore)
systemContainer, err := initSystemContainer(context.Background(), redisNoncestore)
if err != nil {
lo.Fatal("main: critical error bootstrapping system container", "error", err)
}
jsEventEmitter := initJetStream(pgStore)
custodial := &custodial.Custodial{
CeloProvider: celoProvider,
@ -102,14 +70,18 @@ func main() {
TaskerClient: taskerClient,
}
internalServices := &internalServiceContainer{}
wg := &sync.WaitGroup{}
apiServer = initApiServer(custodial)
signalCh, closeCh := createSigChannel()
defer closeCh()
internalServices.apiService = initApiServer(custodial)
wg.Add(1)
go func() {
defer wg.Done()
lo.Info("main: starting API server")
if err := apiServer.Start(ko.MustString("service.address")); err != nil {
if err := internalServices.apiService.Start(ko.MustString("service.address")); err != nil {
if strings.Contains(err.Error(), "Server closed") {
lo.Info("main: shutting down server")
} else {
@ -118,34 +90,28 @@ func main() {
}
}()
tasker = initTasker(custodial, asynqRedisPool)
internalServices.taskerService = initTasker(custodial, asynqRedisPool)
wg.Add(1)
go func() {
defer wg.Done()
lo.Info("Starting tasker")
if err := tasker.Start(); err != nil {
if err := internalServices.taskerService.Start(); err != nil {
lo.Fatal("main: could not start task server", "err", err)
}
}()
internalServices.jetstreamSub = jsEventEmitter
wg.Add(1)
go func() {
defer wg.Done()
lo.Info("Starting jetstream subscriber")
if err := jsEventEmitter.ChainSubscription(ctx, pgStore); err != nil {
lo.Fatal("main: jetstream subscriber", "err", err)
lo.Info("Starting jetstream sub")
if err := internalServices.jetstreamSub.Subscriber(); err != nil {
lo.Fatal("main: error running jetstream sub", "err", err)
}
}()
<-ctx.Done()
lo.Info("main: stopping tasker")
tasker.Stop()
lo.Info("main: stopping api server")
if err := apiServer.Shutdown(ctx); err != nil {
lo.Error("Could not gracefully shutdown api server", "err", err)
}
<-signalCh
startGracefulShutdown(context.Background(), internalServices)
wg.Wait()
}

View File

@ -10,8 +10,6 @@ import (
// Load tasker handlers, injecting any necessary handler dependencies from the system container.
func initTasker(custodialContainer *custodial.Custodial, redisPool *redis.RedisPool) *tasker.TaskerServer {
lo.Debug("Bootstrapping tasker")
taskerServerOpts := tasker.TaskerServerOpts{
Concurrency: ko.MustInt("asynq.worker_count"),
Logg: lo,

31
cmd/service/utils.go Normal file
View File

@ -0,0 +1,31 @@
package main
import (
"context"
"os"
"os/signal"
"syscall"
"time"
)
func createSigChannel() (chan os.Signal, func()) {
signalCh := make(chan os.Signal, 1)
signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM, syscall.SIGINT)
return signalCh, func() {
close(signalCh)
}
}
func startGracefulShutdown(ctx context.Context, internalServices *internalServiceContainer) {
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
internalServices.jetstreamSub.Close()
if err := internalServices.apiService.Shutdown(ctx); err != nil {
lo.Fatal("Could not gracefully shutdown api server", "err", err)
}
internalServices.taskerService.Stop()
}

View File

@ -33,7 +33,7 @@ func HandleSignTransfer(c echo.Context) error {
)
if err := c.Bind(&req); err != nil {
return NewBadRequestError(err)
return NewBadRequestError(ErrInvalidJSON)
}
if err := c.Validate(req); err != nil {

View File

@ -1,5 +1,11 @@
package api
import "errors"
var (
ErrInvalidJSON = errors.New("Invalid JSON structure.")
)
type H map[string]any
type OkResp struct {

View File

@ -1,7 +0,0 @@
package events
type EventPayload struct {
OtxId uint `json:"otxId"`
TrackingId string `json:"trackingId"`
TxHash string `json:"txHash"`
}

View File

@ -1,17 +1,20 @@
package events
import (
"context"
"encoding/json"
"errors"
"time"
"github.com/grassrootseconomics/cic-custodial/internal/store"
"github.com/nats-io/nats.go"
"github.com/zerodha/logf"
)
const (
StreamName string = "CUSTODIAL"
StreamSubjects string = "CUSTODIAL.*"
// Subjects
// Pub
StreamName string = "CUSTODIAL"
StreamSubjects string = "CUSTODIAL.*"
AccountNewNonce string = "CUSTODIAL.accountNewNonce"
AccountRegister string = "CUSTODIAL.accountRegister"
AccountGiftGas string = "CUSTODIAL.systemNewAccountGas"
@ -20,20 +23,36 @@ const (
DispatchFail string = "CUSTODIAL.dispatchFail"
DispatchSuccess string = "CUSTODIAL.dispatchSuccess"
SignTransfer string = "CUSTODIAL.signTransfer"
// Sub
durableId = "cic-custodial"
pullStream = "CHAIN"
pullSubject = "CHAIN.*"
actionTimeout = 5 * time.Second
)
type JetStreamOpts struct {
Logg logf.Logger
ServerUrl string
PersistDuration time.Duration
DedupDuration time.Duration
}
type (
JetStreamOpts struct {
Logg logf.Logger
ServerUrl string
PersistDuration time.Duration
PgStore store.Store
DedupDuration time.Duration
}
type JetStream struct {
logg logf.Logger
jsCtx nats.JetStreamContext
natsConn *nats.Conn
}
JetStream struct {
logg logf.Logger
jsCtx nats.JetStreamContext
pgStore store.Store
natsConn *nats.Conn
}
EventPayload struct {
OtxId uint `json:"otxId"`
TrackingId string `json:"trackingId"`
TxHash string `json:"txHash"`
}
)
func NewJetStreamEventEmitter(o JetStreamOpts) (*JetStream, error) {
natsConn, err := nats.Connect(o.ServerUrl)
@ -61,9 +80,20 @@ func NewJetStreamEventEmitter(o JetStreamOpts) (*JetStream, error) {
}
}
// Add a durable consumer
_, err = js.AddConsumer(pullStream, &nats.ConsumerConfig{
Durable: durableId,
AckPolicy: nats.AckExplicitPolicy,
FilterSubject: pullSubject,
})
if err != nil {
return nil, err
}
return &JetStream{
logg: o.Logg,
jsCtx: js,
pgStore: o.PgStore,
natsConn: natsConn,
}, nil
}
@ -89,3 +119,51 @@ func (js *JetStream) Publish(subject string, dedupId string, eventPayload interf
return nil
}
func (js *JetStream) Subscriber() error {
subOpts := []nats.SubOpt{
nats.ManualAck(),
nats.Bind(pullStream, durableId),
}
natsSub, err := js.jsCtx.PullSubscribe(pullSubject, durableId, subOpts...)
if err != nil {
return err
}
for {
events, err := natsSub.Fetch(1)
if err != nil {
if errors.Is(err, nats.ErrTimeout) {
continue
} else if errors.Is(err, nats.ErrConnectionClosed) {
return nil
} else {
return err
}
}
if len(events) > 0 {
var (
chainEvent store.MinimalTxInfo
msg = events[0]
)
if err := json.Unmarshal(msg.Data, &chainEvent); err != nil {
msg.Nak()
js.logg.Error("jetstream sub: json unmarshal fail", "error", err)
} else {
ctx, cancel := context.WithTimeout(context.Background(), actionTimeout)
if err := js.pgStore.UpdateOtxStatusFromChainEvent(ctx, chainEvent); err != nil {
msg.Nak()
js.logg.Error("jetstream sub: otx marker failed to update state", "error", err)
} else {
msg.Ack()
}
cancel()
}
}
}
}

View File

@ -1,73 +0,0 @@
package events
import (
"context"
"encoding/json"
"errors"
"github.com/grassrootseconomics/cic-custodial/internal/store"
"github.com/nats-io/nats.go"
)
const (
durableId = "cic-custodial"
pullStream = "CHAIN"
pullSubject = "CHAIN.*"
)
func (js *JetStream) ChainSubscription(ctx context.Context, pgStore store.Store) error {
_, err := js.jsCtx.AddConsumer(pullStream, &nats.ConsumerConfig{
Durable: durableId,
AckPolicy: nats.AckExplicitPolicy,
FilterSubject: pullSubject,
})
if err != nil {
return err
}
subOpts := []nats.SubOpt{
nats.ManualAck(),
nats.Bind(pullStream, durableId),
}
natsSub, err := js.jsCtx.PullSubscribe(pullSubject, durableId, subOpts...)
if err != nil {
return err
}
for {
select {
case <-ctx.Done():
js.logg.Info("jetstream chain sub: shutdown signal received")
js.Close()
return nil
default:
events, err := natsSub.Fetch(1)
if err != nil {
if errors.Is(err, nats.ErrTimeout) {
continue
} else {
js.logg.Error("jetstream chain sub: fetch other error", "error", err)
}
}
if len(events) == 0 {
continue
}
var (
chainEvent store.MinimalTxInfo
)
if err := json.Unmarshal(events[0].Data, &chainEvent); err != nil {
js.logg.Error("jetstream chain sub: json unmarshal fail", "error", err)
}
if err := pgStore.UpdateOtxStatusFromChainEvent(context.Background(), chainEvent); err != nil {
events[0].Nak()
js.logg.Error("jetstream chain sub: otx marker failed to update state", "error", err)
}
events[0].Ack()
js.logg.Debug("jetstream chain sub: successfully updated status", "tx", chainEvent.TxHash)
}
}
}

View File

@ -60,6 +60,7 @@ func (s *PostgresStore) GetTxStatusByTrackingId(ctx context.Context, trackingId
}
func (s *PostgresStore) UpdateOtxStatusFromChainEvent(ctx context.Context, chainEvent MinimalTxInfo) error {
var (
status = enum.SUCCESS
)

View File

@ -19,18 +19,18 @@ type PostgresPoolOpts struct {
}
// NewPostgresPool creates a reusbale connection pool across the cic-custodial component.
func NewPostgresPool(o PostgresPoolOpts) (*pgxpool.Pool, error) {
func NewPostgresPool(ctx context.Context, o PostgresPoolOpts) (*pgxpool.Pool, error) {
parsedConfig, err := pgxpool.ParseConfig(o.DSN)
if err != nil {
return nil, err
}
dbPool, err := pgxpool.NewWithConfig(context.Background(), parsedConfig)
dbPool, err := pgxpool.NewWithConfig(ctx, parsedConfig)
if err != nil {
return nil, err
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
conn, err := dbPool.Acquire(ctx)

View File

@ -18,7 +18,7 @@ type RedisPool struct {
// NewRedisPool creates a reusable connection across the cic-custodial componenent.
// Note: Each db namespace requires its own connection pool.
func NewRedisPool(o RedisPoolOpts) (*RedisPool, error) {
func NewRedisPool(ctx context.Context, o RedisPoolOpts) (*RedisPool, error) {
redisOpts, err := redis.ParseURL(o.DSN)
if err != nil {
return nil, err
@ -28,7 +28,7 @@ func NewRedisPool(o RedisPoolOpts) (*RedisPool, error) {
redisClient := redis.NewClient(redisOpts)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
_, err = redisClient.Ping(ctx).Result()