refactor: ctx propagation, api handlers

* use context timeout middleware for correct ctx propagation
* Fix bind error handling
* Fix validation error handling
* Fix HTTP error handling (4XX)
* tasker client now  accepts ctx
* add recovery and body size middleware
This commit is contained in:
Mohamed Sohail 2023-02-24 16:46:46 +00:00
parent ce6bdbf4ed
commit add7f2a442
Signed by: kamikazechaser
GPG Key ID: 7DD45520C01CD85D
16 changed files with 201 additions and 161 deletions

View File

@ -1,53 +1,45 @@
package main
import (
"errors"
"net/http"
"time"
"github.com/VictoriaMetrics/metrics"
"github.com/go-playground/validator/v10"
"github.com/grassrootseconomics/cic-custodial/internal/api"
"github.com/grassrootseconomics/cic-custodial/internal/custodial"
"github.com/hibiken/asynq"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
)
const (
contextTimeout = 5
)
// Bootstrap API server.
func initApiServer(custodialContainer *custodial.Custodial) *echo.Echo {
lo.Debug("api: bootstrapping api server")
customValidator := validator.New()
customValidator.RegisterValidation("eth_checksum", api.EthChecksumValidator)
server := echo.New()
server.HideBanner = true
server.HidePort = true
server.HTTPErrorHandler = func(err error, c echo.Context) {
// Handle asynq duplication errors across all api handlers.
if errors.Is(err, asynq.ErrTaskIDConflict) {
c.JSON(http.StatusForbidden, api.ErrResp{
Ok: false,
Code: api.DUPLICATE_ERROR,
Message: "Request with duplicate tracking id submitted.",
})
return
server.Validator = &api.Validator{
ValidatorProvider: customValidator,
}
if _, ok := err.(validator.ValidationErrors); ok {
c.JSON(http.StatusForbidden, api.ErrResp{
Ok: false,
Code: api.VALIDATION_ERROR,
Message: err.(validator.ValidationErrors).Error(),
})
return
}
server.HTTPErrorHandler = customHTTPErrorHandler
// Log internal server error for further investigation.
lo.Error("api:", "path", c.Path(), "err", err)
c.JSON(http.StatusInternalServerError, api.ErrResp{
Ok: false,
Code: api.INTERNAL_ERROR,
Message: "Internal server error.",
})
server.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
c.Set("cu", custodialContainer)
return next(c)
}
})
server.Use(middleware.Recover())
server.Use(middleware.BodyLimit("1M"))
server.Use(middleware.ContextTimeout(time.Duration(contextTimeout * time.Second)))
if ko.Bool("service.metrics") {
server.GET("/metrics", func(c echo.Context) error {
@ -56,17 +48,39 @@ func initApiServer(custodialContainer *custodial.Custodial) *echo.Echo {
})
}
customValidator := validator.New()
customValidator.RegisterValidation("eth_checksum", api.EthChecksumValidator)
server.Validator = &api.Validator{
ValidatorProvider: customValidator,
}
apiRoute := server.Group("/api")
apiRoute.POST("/account/create", api.CreateAccountHandler(custodialContainer))
apiRoute.POST("/sign/transfer", api.SignTransferHandler(custodialContainer))
apiRoute.GET("/track/:trackingId", api.TxStatus(custodialContainer.PgStore))
apiRoute.POST("/account/create", api.HandleAccountCreate)
apiRoute.POST("/sign/transfer", api.HandleSignTransfer)
apiRoute.GET("/track/:trackingId", api.HandleTrackTx)
return server
}
func customHTTPErrorHandler(err error, c echo.Context) {
if c.Response().Committed {
return
}
he, ok := err.(*echo.HTTPError)
if ok {
var errorMsg string
if m, ok := he.Message.(error); ok {
errorMsg = m.Error()
} else if m, ok := he.Message.(string); ok {
errorMsg = m
}
c.JSON(he.Code, api.ErrResp{
Ok: false,
Message: errorMsg,
})
} else {
lo.Error("api: echo error", "path", c.Path(), "err", err)
c.JSON(http.StatusInternalServerError, api.ErrResp{
Ok: false,
Message: "Internal server error.",
})
}
}

1
go.mod
View File

@ -40,6 +40,7 @@ require (
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-stack/stack v1.8.1 // indirect
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
github.com/golang/protobuf v1.5.2 // indirect
github.com/golang/snappy v0.0.4 // indirect
github.com/google/go-cmp v0.5.9 // indirect

2
go.sum
View File

@ -226,6 +226,8 @@ github.com/gofrs/uuid v3.3.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRx
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
github.com/golang/geo v0.0.0-20190916061304-5b978397cfec/go.mod h1:QZ0nwyI2jOfgRAoBvP+ab5aRr7c9x7lhGEJrKvBwjWI=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=

View File

@ -15,9 +15,10 @@ import (
// CreateAccountHandler route.
// POST: /api/account/create
// Returns the public key.
func CreateAccountHandler(cu *custodial.Custodial) func(echo.Context) error {
return func(c echo.Context) error {
trackingId := uuid.NewString()
func HandleAccountCreate(c echo.Context) error {
var (
cu = c.Get("cu").(*custodial.Custodial)
)
generatedKeyPair, err := keypair.Generate()
if err != nil {
@ -29,6 +30,7 @@ func CreateAccountHandler(cu *custodial.Custodial) func(echo.Context) error {
return err
}
trackingId := uuid.NewString()
taskPayload, err := json.Marshal(task.AccountPayload{
PublicKey: generatedKeyPair.Public,
TrackingId: trackingId,
@ -38,6 +40,7 @@ func CreateAccountHandler(cu *custodial.Custodial) func(echo.Context) error {
}
_, err = cu.TaskerClient.CreateTask(
c.Request().Context(),
tasker.AccountPrepareTask,
tasker.DefaultPriority,
&tasker.Task{
@ -58,4 +61,3 @@ func CreateAccountHandler(cu *custodial.Custodial) func(echo.Context) error {
},
})
}
}

11
internal/api/errors.go Normal file
View File

@ -0,0 +1,11 @@
package api
import (
"net/http"
"github.com/labstack/echo/v4"
)
func NewBadRequestError(message ...interface{}) *echo.HTTPError {
return echo.NewHTTPError(http.StatusBadRequest, message...)
}

View File

@ -12,7 +12,7 @@ import (
"github.com/labstack/echo/v4"
)
// SignTxHandler route.
// HandleSignTransfer route.
// POST: /api/sign/transfer
// JSON Body:
// from -> ETH address
@ -21,38 +21,39 @@ import (
// amount -> int (6 d.p. precision)
// e.g. 1000000 = 1 VOUCHER
// Returns the task id.
func SignTransferHandler(cu *custodial.Custodial) func(echo.Context) error {
return func(c echo.Context) error {
trackingId := uuid.NewString()
var transferRequest struct {
func HandleSignTransfer(c echo.Context) error {
var (
cu = c.Get("cu").(*custodial.Custodial)
req struct {
From string `json:"from" validate:"required,eth_checksum"`
To string `json:"to" validate:"required,eth_checksum"`
VoucherAddress string `json:"voucherAddress" validate:"required,eth_checksum"`
Amount uint64 `json:"amount" validate:"required,numeric"`
Amount uint64 `json:"amount" validate:"required"`
}
)
if err := c.Bind(&req); err != nil {
return NewBadRequestError(err)
}
if err := c.Bind(&transferRequest); err != nil {
if err := c.Validate(req); err != nil {
return err
}
if err := c.Validate(transferRequest); err != nil {
return err
}
// TODO: Checksum addresses
trackingId := uuid.NewString()
taskPayload, err := json.Marshal(task.TransferPayload{
TrackingId: trackingId,
From: transferRequest.From,
To: transferRequest.To,
VoucherAddress: transferRequest.VoucherAddress,
Amount: transferRequest.Amount,
From: req.From,
To: req.To,
VoucherAddress: req.VoucherAddress,
Amount: req.Amount,
})
if err != nil {
return err
}
_, err = cu.TaskerClient.CreateTask(
c.Request().Context(),
tasker.SignTransferTask,
tasker.HighPriority,
&tasker.Task{
@ -71,4 +72,3 @@ func SignTransferHandler(cu *custodial.Custodial) func(echo.Context) error {
},
})
}
}

View File

@ -3,26 +3,32 @@ package api
import (
"net/http"
"github.com/grassrootseconomics/cic-custodial/internal/store"
"github.com/grassrootseconomics/cic-custodial/internal/custodial"
"github.com/labstack/echo/v4"
)
func TxStatus(store store.Store) func(echo.Context) error {
return func(c echo.Context) error {
var txStatusRequest struct {
// HandleTxStatus route.
// GET: /api/track/:trackingId
// Route param:
// trackingId -> tracking UUID
// Returns array of tx status.
func HandleTrackTx(c echo.Context) error {
var (
cu = c.Get("cu").(*custodial.Custodial)
txStatusRequest struct {
TrackingId string `param:"trackingId" validate:"required,uuid"`
}
)
if err := c.Bind(&txStatusRequest); err != nil {
return err
return NewBadRequestError(err)
}
if err := c.Validate(txStatusRequest); err != nil {
return err
}
// TODO: handle potential timeouts
txs, err := store.GetTxStatusByTrackingId(c.Request().Context(), txStatusRequest.TrackingId)
txs, err := cu.PgStore.GetTxStatusByTrackingId(c.Request().Context(), txStatusRequest.TrackingId)
if err != nil {
return err
}
@ -34,4 +40,3 @@ func TxStatus(store store.Store) func(echo.Context) error {
},
})
}
}

View File

@ -1,11 +1,5 @@
package api
const (
INTERNAL_ERROR = "ERR_INTERNAL"
VALIDATION_ERROR = "ERR_VALIDATE"
DUPLICATE_ERROR = "ERR_DUPLICATE"
)
type H map[string]any
type OkResp struct {
@ -15,6 +9,5 @@ type OkResp struct {
type ErrResp struct {
Ok bool `json:"ok"`
Code string `json:"errorCode"`
Message string `json:"message"`
}

View File

@ -11,7 +11,9 @@ type Validator struct {
func (v *Validator) Validate(i interface{}) error {
if err := v.ValidatorProvider.Struct(i); err != nil {
return err
if _, ok := err.(validator.ValidationErrors); ok {
return NewBadRequestError(err.(validator.ValidationErrors).Error())
}
}
return nil
}

View File

@ -1,6 +1,7 @@
package tasker
import (
"context"
"time"
"github.com/google/uuid"
@ -28,7 +29,7 @@ func NewTaskerClient(o TaskerClientOpts) *TaskerClient {
}
}
func (c *TaskerClient) CreateTask(taskName TaskName, queueName QueueName, task *Task) (*asynq.TaskInfo, error) {
func (c *TaskerClient) CreateTask(ctx context.Context, taskName TaskName, queueName QueueName, task *Task) (*asynq.TaskInfo, error) {
if task.Id == "" {
task.Id = uuid.NewString()
}
@ -42,7 +43,7 @@ func (c *TaskerClient) CreateTask(taskName TaskName, queueName QueueName, task *
asynq.Timeout(taskTimeout*time.Second),
)
taskInfo, err := c.Client.Enqueue(qTask)
taskInfo, err := c.Client.EnqueueContext(ctx, qTask)
if err != nil {
return nil, err
}

View File

@ -93,6 +93,7 @@ func AccountGiftGasProcessor(cu *custodial.Custodial) func(context.Context, *asy
}
_, err = cu.TaskerClient.CreateTask(
ctx,
tasker.DispatchTxTask,
tasker.HighPriority,
&tasker.Task{

View File

@ -103,6 +103,7 @@ func GiftVoucherProcessor(cu *custodial.Custodial) func(context.Context, *asynq.
}
_, err = cu.TaskerClient.CreateTask(
ctx,
tasker.DispatchTxTask,
tasker.HighPriority,
&tasker.Task{

View File

@ -29,6 +29,7 @@ func AccountPrepare(cu *custodial.Custodial) func(context.Context, *asynq.Task)
}
_, err := cu.TaskerClient.CreateTask(
ctx,
tasker.AccountRegisterTask,
tasker.DefaultPriority,
&tasker.Task{
@ -40,6 +41,7 @@ func AccountPrepare(cu *custodial.Custodial) func(context.Context, *asynq.Task)
}
_, err = cu.TaskerClient.CreateTask(
ctx,
tasker.AccountGiftGasTask,
tasker.DefaultPriority,
&tasker.Task{
@ -51,6 +53,7 @@ func AccountPrepare(cu *custodial.Custodial) func(context.Context, *asynq.Task)
}
_, err = cu.TaskerClient.CreateTask(
ctx,
tasker.AccountGiftVoucherTask,
tasker.DefaultPriority,
&tasker.Task{

View File

@ -108,6 +108,7 @@ func AccountRefillGasProcessor(cu *custodial.Custodial) func(context.Context, *a
}
_, err = cu.TaskerClient.CreateTask(
ctx,
tasker.DispatchTxTask,
tasker.HighPriority,
&tasker.Task{

View File

@ -101,6 +101,7 @@ func AccountRegisterOnChainProcessor(cu *custodial.Custodial) func(context.Conte
}
_, err = cu.TaskerClient.CreateTask(
ctx,
tasker.DispatchTxTask,
tasker.HighPriority,
&tasker.Task{

View File

@ -123,6 +123,7 @@ func SignTransfer(cu *custodial.Custodial) func(context.Context, *asynq.Task) er
}
_, err = cu.TaskerClient.CreateTask(
ctx,
tasker.DispatchTxTask,
tasker.HighPriority,
&tasker.Task{
@ -141,6 +142,7 @@ func SignTransfer(cu *custodial.Custodial) func(context.Context, *asynq.Task) er
}
_, err = cu.TaskerClient.CreateTask(
ctx,
tasker.AccountRefillGasTask,
tasker.DefaultPriority,
&tasker.Task{