refactor: ctx propagation, api handlers

* use context timeout middleware for correct ctx propagation
* Fix bind error handling
* Fix validation error handling
* Fix HTTP error handling (4XX)
* tasker client now  accepts ctx
* add recovery and body size middleware
This commit is contained in:
Mohamed Sohail 2023-02-24 16:46:46 +00:00
parent ce6bdbf4ed
commit add7f2a442
Signed by: kamikazechaser
GPG Key ID: 7DD45520C01CD85D
16 changed files with 201 additions and 161 deletions

View File

@ -1,53 +1,45 @@
package main package main
import ( import (
"errors"
"net/http" "net/http"
"time"
"github.com/VictoriaMetrics/metrics" "github.com/VictoriaMetrics/metrics"
"github.com/go-playground/validator/v10" "github.com/go-playground/validator/v10"
"github.com/grassrootseconomics/cic-custodial/internal/api" "github.com/grassrootseconomics/cic-custodial/internal/api"
"github.com/grassrootseconomics/cic-custodial/internal/custodial" "github.com/grassrootseconomics/cic-custodial/internal/custodial"
"github.com/hibiken/asynq"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
)
const (
contextTimeout = 5
) )
// Bootstrap API server. // Bootstrap API server.
func initApiServer(custodialContainer *custodial.Custodial) *echo.Echo { func initApiServer(custodialContainer *custodial.Custodial) *echo.Echo {
lo.Debug("api: bootstrapping api server") customValidator := validator.New()
customValidator.RegisterValidation("eth_checksum", api.EthChecksumValidator)
server := echo.New() server := echo.New()
server.HideBanner = true server.HideBanner = true
server.HidePort = true server.HidePort = true
server.HTTPErrorHandler = func(err error, c echo.Context) { server.Validator = &api.Validator{
// Handle asynq duplication errors across all api handlers. ValidatorProvider: customValidator,
if errors.Is(err, asynq.ErrTaskIDConflict) {
c.JSON(http.StatusForbidden, api.ErrResp{
Ok: false,
Code: api.DUPLICATE_ERROR,
Message: "Request with duplicate tracking id submitted.",
})
return
} }
if _, ok := err.(validator.ValidationErrors); ok { server.HTTPErrorHandler = customHTTPErrorHandler
c.JSON(http.StatusForbidden, api.ErrResp{
Ok: false,
Code: api.VALIDATION_ERROR,
Message: err.(validator.ValidationErrors).Error(),
})
return
}
// Log internal server error for further investigation. server.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
lo.Error("api:", "path", c.Path(), "err", err) return func(c echo.Context) error {
c.Set("cu", custodialContainer)
c.JSON(http.StatusInternalServerError, api.ErrResp{ return next(c)
Ok: false,
Code: api.INTERNAL_ERROR,
Message: "Internal server error.",
})
} }
})
server.Use(middleware.Recover())
server.Use(middleware.BodyLimit("1M"))
server.Use(middleware.ContextTimeout(time.Duration(contextTimeout * time.Second)))
if ko.Bool("service.metrics") { if ko.Bool("service.metrics") {
server.GET("/metrics", func(c echo.Context) error { server.GET("/metrics", func(c echo.Context) error {
@ -56,17 +48,39 @@ func initApiServer(custodialContainer *custodial.Custodial) *echo.Echo {
}) })
} }
customValidator := validator.New()
customValidator.RegisterValidation("eth_checksum", api.EthChecksumValidator)
server.Validator = &api.Validator{
ValidatorProvider: customValidator,
}
apiRoute := server.Group("/api") apiRoute := server.Group("/api")
apiRoute.POST("/account/create", api.CreateAccountHandler(custodialContainer)) apiRoute.POST("/account/create", api.HandleAccountCreate)
apiRoute.POST("/sign/transfer", api.SignTransferHandler(custodialContainer)) apiRoute.POST("/sign/transfer", api.HandleSignTransfer)
apiRoute.GET("/track/:trackingId", api.TxStatus(custodialContainer.PgStore)) apiRoute.GET("/track/:trackingId", api.HandleTrackTx)
return server return server
} }
func customHTTPErrorHandler(err error, c echo.Context) {
if c.Response().Committed {
return
}
he, ok := err.(*echo.HTTPError)
if ok {
var errorMsg string
if m, ok := he.Message.(error); ok {
errorMsg = m.Error()
} else if m, ok := he.Message.(string); ok {
errorMsg = m
}
c.JSON(he.Code, api.ErrResp{
Ok: false,
Message: errorMsg,
})
} else {
lo.Error("api: echo error", "path", c.Path(), "err", err)
c.JSON(http.StatusInternalServerError, api.ErrResp{
Ok: false,
Message: "Internal server error.",
})
}
}

1
go.mod
View File

@ -40,6 +40,7 @@ require (
github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-stack/stack v1.8.1 // indirect github.com/go-stack/stack v1.8.1 // indirect
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
github.com/golang/protobuf v1.5.2 // indirect github.com/golang/protobuf v1.5.2 // indirect
github.com/golang/snappy v0.0.4 // indirect github.com/golang/snappy v0.0.4 // indirect
github.com/google/go-cmp v0.5.9 // indirect github.com/google/go-cmp v0.5.9 // indirect

2
go.sum
View File

@ -226,6 +226,8 @@ github.com/gofrs/uuid v3.3.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRx
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
github.com/golang/geo v0.0.0-20190916061304-5b978397cfec/go.mod h1:QZ0nwyI2jOfgRAoBvP+ab5aRr7c9x7lhGEJrKvBwjWI= github.com/golang/geo v0.0.0-20190916061304-5b978397cfec/go.mod h1:QZ0nwyI2jOfgRAoBvP+ab5aRr7c9x7lhGEJrKvBwjWI=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=

View File

@ -15,9 +15,10 @@ import (
// CreateAccountHandler route. // CreateAccountHandler route.
// POST: /api/account/create // POST: /api/account/create
// Returns the public key. // Returns the public key.
func CreateAccountHandler(cu *custodial.Custodial) func(echo.Context) error { func HandleAccountCreate(c echo.Context) error {
return func(c echo.Context) error { var (
trackingId := uuid.NewString() cu = c.Get("cu").(*custodial.Custodial)
)
generatedKeyPair, err := keypair.Generate() generatedKeyPair, err := keypair.Generate()
if err != nil { if err != nil {
@ -29,6 +30,7 @@ func CreateAccountHandler(cu *custodial.Custodial) func(echo.Context) error {
return err return err
} }
trackingId := uuid.NewString()
taskPayload, err := json.Marshal(task.AccountPayload{ taskPayload, err := json.Marshal(task.AccountPayload{
PublicKey: generatedKeyPair.Public, PublicKey: generatedKeyPair.Public,
TrackingId: trackingId, TrackingId: trackingId,
@ -38,6 +40,7 @@ func CreateAccountHandler(cu *custodial.Custodial) func(echo.Context) error {
} }
_, err = cu.TaskerClient.CreateTask( _, err = cu.TaskerClient.CreateTask(
c.Request().Context(),
tasker.AccountPrepareTask, tasker.AccountPrepareTask,
tasker.DefaultPriority, tasker.DefaultPriority,
&tasker.Task{ &tasker.Task{
@ -58,4 +61,3 @@ func CreateAccountHandler(cu *custodial.Custodial) func(echo.Context) error {
}, },
}) })
} }
}

11
internal/api/errors.go Normal file
View File

@ -0,0 +1,11 @@
package api
import (
"net/http"
"github.com/labstack/echo/v4"
)
func NewBadRequestError(message ...interface{}) *echo.HTTPError {
return echo.NewHTTPError(http.StatusBadRequest, message...)
}

View File

@ -12,7 +12,7 @@ import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
// SignTxHandler route. // HandleSignTransfer route.
// POST: /api/sign/transfer // POST: /api/sign/transfer
// JSON Body: // JSON Body:
// from -> ETH address // from -> ETH address
@ -21,38 +21,39 @@ import (
// amount -> int (6 d.p. precision) // amount -> int (6 d.p. precision)
// e.g. 1000000 = 1 VOUCHER // e.g. 1000000 = 1 VOUCHER
// Returns the task id. // Returns the task id.
func SignTransferHandler(cu *custodial.Custodial) func(echo.Context) error { func HandleSignTransfer(c echo.Context) error {
return func(c echo.Context) error { var (
trackingId := uuid.NewString() cu = c.Get("cu").(*custodial.Custodial)
req struct {
var transferRequest struct {
From string `json:"from" validate:"required,eth_checksum"` From string `json:"from" validate:"required,eth_checksum"`
To string `json:"to" validate:"required,eth_checksum"` To string `json:"to" validate:"required,eth_checksum"`
VoucherAddress string `json:"voucherAddress" validate:"required,eth_checksum"` VoucherAddress string `json:"voucherAddress" validate:"required,eth_checksum"`
Amount uint64 `json:"amount" validate:"required,numeric"` Amount uint64 `json:"amount" validate:"required"`
}
)
if err := c.Bind(&req); err != nil {
return NewBadRequestError(err)
} }
if err := c.Bind(&transferRequest); err != nil { if err := c.Validate(req); err != nil {
return err return err
} }
if err := c.Validate(transferRequest); err != nil { trackingId := uuid.NewString()
return err
}
// TODO: Checksum addresses
taskPayload, err := json.Marshal(task.TransferPayload{ taskPayload, err := json.Marshal(task.TransferPayload{
TrackingId: trackingId, TrackingId: trackingId,
From: transferRequest.From, From: req.From,
To: transferRequest.To, To: req.To,
VoucherAddress: transferRequest.VoucherAddress, VoucherAddress: req.VoucherAddress,
Amount: transferRequest.Amount, Amount: req.Amount,
}) })
if err != nil { if err != nil {
return err return err
} }
_, err = cu.TaskerClient.CreateTask( _, err = cu.TaskerClient.CreateTask(
c.Request().Context(),
tasker.SignTransferTask, tasker.SignTransferTask,
tasker.HighPriority, tasker.HighPriority,
&tasker.Task{ &tasker.Task{
@ -71,4 +72,3 @@ func SignTransferHandler(cu *custodial.Custodial) func(echo.Context) error {
}, },
}) })
} }
}

View File

@ -3,26 +3,32 @@ package api
import ( import (
"net/http" "net/http"
"github.com/grassrootseconomics/cic-custodial/internal/store" "github.com/grassrootseconomics/cic-custodial/internal/custodial"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
) )
func TxStatus(store store.Store) func(echo.Context) error { // HandleTxStatus route.
return func(c echo.Context) error { // GET: /api/track/:trackingId
var txStatusRequest struct { // Route param:
// trackingId -> tracking UUID
// Returns array of tx status.
func HandleTrackTx(c echo.Context) error {
var (
cu = c.Get("cu").(*custodial.Custodial)
txStatusRequest struct {
TrackingId string `param:"trackingId" validate:"required,uuid"` TrackingId string `param:"trackingId" validate:"required,uuid"`
} }
)
if err := c.Bind(&txStatusRequest); err != nil { if err := c.Bind(&txStatusRequest); err != nil {
return err return NewBadRequestError(err)
} }
if err := c.Validate(txStatusRequest); err != nil { if err := c.Validate(txStatusRequest); err != nil {
return err return err
} }
// TODO: handle potential timeouts txs, err := cu.PgStore.GetTxStatusByTrackingId(c.Request().Context(), txStatusRequest.TrackingId)
txs, err := store.GetTxStatusByTrackingId(c.Request().Context(), txStatusRequest.TrackingId)
if err != nil { if err != nil {
return err return err
} }
@ -34,4 +40,3 @@ func TxStatus(store store.Store) func(echo.Context) error {
}, },
}) })
} }
}

View File

@ -1,11 +1,5 @@
package api package api
const (
INTERNAL_ERROR = "ERR_INTERNAL"
VALIDATION_ERROR = "ERR_VALIDATE"
DUPLICATE_ERROR = "ERR_DUPLICATE"
)
type H map[string]any type H map[string]any
type OkResp struct { type OkResp struct {
@ -15,6 +9,5 @@ type OkResp struct {
type ErrResp struct { type ErrResp struct {
Ok bool `json:"ok"` Ok bool `json:"ok"`
Code string `json:"errorCode"`
Message string `json:"message"` Message string `json:"message"`
} }

View File

@ -11,7 +11,9 @@ type Validator struct {
func (v *Validator) Validate(i interface{}) error { func (v *Validator) Validate(i interface{}) error {
if err := v.ValidatorProvider.Struct(i); err != nil { if err := v.ValidatorProvider.Struct(i); err != nil {
return err if _, ok := err.(validator.ValidationErrors); ok {
return NewBadRequestError(err.(validator.ValidationErrors).Error())
}
} }
return nil return nil
} }

View File

@ -1,6 +1,7 @@
package tasker package tasker
import ( import (
"context"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
@ -28,7 +29,7 @@ func NewTaskerClient(o TaskerClientOpts) *TaskerClient {
} }
} }
func (c *TaskerClient) CreateTask(taskName TaskName, queueName QueueName, task *Task) (*asynq.TaskInfo, error) { func (c *TaskerClient) CreateTask(ctx context.Context, taskName TaskName, queueName QueueName, task *Task) (*asynq.TaskInfo, error) {
if task.Id == "" { if task.Id == "" {
task.Id = uuid.NewString() task.Id = uuid.NewString()
} }
@ -42,7 +43,7 @@ func (c *TaskerClient) CreateTask(taskName TaskName, queueName QueueName, task *
asynq.Timeout(taskTimeout*time.Second), asynq.Timeout(taskTimeout*time.Second),
) )
taskInfo, err := c.Client.Enqueue(qTask) taskInfo, err := c.Client.EnqueueContext(ctx, qTask)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -93,6 +93,7 @@ func AccountGiftGasProcessor(cu *custodial.Custodial) func(context.Context, *asy
} }
_, err = cu.TaskerClient.CreateTask( _, err = cu.TaskerClient.CreateTask(
ctx,
tasker.DispatchTxTask, tasker.DispatchTxTask,
tasker.HighPriority, tasker.HighPriority,
&tasker.Task{ &tasker.Task{

View File

@ -103,6 +103,7 @@ func GiftVoucherProcessor(cu *custodial.Custodial) func(context.Context, *asynq.
} }
_, err = cu.TaskerClient.CreateTask( _, err = cu.TaskerClient.CreateTask(
ctx,
tasker.DispatchTxTask, tasker.DispatchTxTask,
tasker.HighPriority, tasker.HighPriority,
&tasker.Task{ &tasker.Task{

View File

@ -29,6 +29,7 @@ func AccountPrepare(cu *custodial.Custodial) func(context.Context, *asynq.Task)
} }
_, err := cu.TaskerClient.CreateTask( _, err := cu.TaskerClient.CreateTask(
ctx,
tasker.AccountRegisterTask, tasker.AccountRegisterTask,
tasker.DefaultPriority, tasker.DefaultPriority,
&tasker.Task{ &tasker.Task{
@ -40,6 +41,7 @@ func AccountPrepare(cu *custodial.Custodial) func(context.Context, *asynq.Task)
} }
_, err = cu.TaskerClient.CreateTask( _, err = cu.TaskerClient.CreateTask(
ctx,
tasker.AccountGiftGasTask, tasker.AccountGiftGasTask,
tasker.DefaultPriority, tasker.DefaultPriority,
&tasker.Task{ &tasker.Task{
@ -51,6 +53,7 @@ func AccountPrepare(cu *custodial.Custodial) func(context.Context, *asynq.Task)
} }
_, err = cu.TaskerClient.CreateTask( _, err = cu.TaskerClient.CreateTask(
ctx,
tasker.AccountGiftVoucherTask, tasker.AccountGiftVoucherTask,
tasker.DefaultPriority, tasker.DefaultPriority,
&tasker.Task{ &tasker.Task{

View File

@ -108,6 +108,7 @@ func AccountRefillGasProcessor(cu *custodial.Custodial) func(context.Context, *a
} }
_, err = cu.TaskerClient.CreateTask( _, err = cu.TaskerClient.CreateTask(
ctx,
tasker.DispatchTxTask, tasker.DispatchTxTask,
tasker.HighPriority, tasker.HighPriority,
&tasker.Task{ &tasker.Task{

View File

@ -101,6 +101,7 @@ func AccountRegisterOnChainProcessor(cu *custodial.Custodial) func(context.Conte
} }
_, err = cu.TaskerClient.CreateTask( _, err = cu.TaskerClient.CreateTask(
ctx,
tasker.DispatchTxTask, tasker.DispatchTxTask,
tasker.HighPriority, tasker.HighPriority,
&tasker.Task{ &tasker.Task{

View File

@ -123,6 +123,7 @@ func SignTransfer(cu *custodial.Custodial) func(context.Context, *asynq.Task) er
} }
_, err = cu.TaskerClient.CreateTask( _, err = cu.TaskerClient.CreateTask(
ctx,
tasker.DispatchTxTask, tasker.DispatchTxTask,
tasker.HighPriority, tasker.HighPriority,
&tasker.Task{ &tasker.Task{
@ -141,6 +142,7 @@ func SignTransfer(cu *custodial.Custodial) func(context.Context, *asynq.Task) er
} }
_, err = cu.TaskerClient.CreateTask( _, err = cu.TaskerClient.CreateTask(
ctx,
tasker.AccountRefillGasTask, tasker.AccountRefillGasTask,
tasker.DefaultPriority, tasker.DefaultPriority,
&tasker.Task{ &tasker.Task{