From add7f2a4428ff90d17fb8f54c5373a0e9bd9fe20 Mon Sep 17 00:00:00 2001 From: Mohammed Sohail Date: Fri, 24 Feb 2023 16:46:46 +0000 Subject: [PATCH] refactor: ctx propagation, api handlers * use context timeout middleware for correct ctx propagation * Fix bind error handling * Fix validation error handling * Fix HTTP error handling (4XX) * tasker client now accepts ctx * add recovery and body size middleware --- cmd/service/api.go | 96 +++++++++++-------- go.mod | 1 + go.sum | 2 + internal/api/account.go | 84 ++++++++-------- internal/api/errors.go | 11 +++ internal/api/sign.go | 90 ++++++++--------- internal/api/track.go | 53 +++++----- internal/api/types.go | 7 -- internal/api/validator.go | 4 +- internal/tasker/client.go | 5 +- internal/tasker/task/account_gift_gas.go | 1 + internal/tasker/task/account_gift_voucher.go | 1 + internal/tasker/task/account_prepare.go | 3 + internal/tasker/task/account_refill_gas.go | 1 + .../tasker/task/account_register_onchain.go | 1 + internal/tasker/task/sign_transfer.go | 2 + 16 files changed, 201 insertions(+), 161 deletions(-) create mode 100644 internal/api/errors.go diff --git a/cmd/service/api.go b/cmd/service/api.go index 57f4f87..b6a9773 100644 --- a/cmd/service/api.go +++ b/cmd/service/api.go @@ -1,54 +1,46 @@ package main import ( - "errors" "net/http" + "time" "github.com/VictoriaMetrics/metrics" "github.com/go-playground/validator/v10" "github.com/grassrootseconomics/cic-custodial/internal/api" "github.com/grassrootseconomics/cic-custodial/internal/custodial" - "github.com/hibiken/asynq" "github.com/labstack/echo/v4" + "github.com/labstack/echo/v4/middleware" +) + +const ( + contextTimeout = 5 ) // Bootstrap API server. func initApiServer(custodialContainer *custodial.Custodial) *echo.Echo { - lo.Debug("api: bootstrapping api server") + customValidator := validator.New() + customValidator.RegisterValidation("eth_checksum", api.EthChecksumValidator) + server := echo.New() server.HideBanner = true server.HidePort = true - server.HTTPErrorHandler = func(err error, c echo.Context) { - // Handle asynq duplication errors across all api handlers. - if errors.Is(err, asynq.ErrTaskIDConflict) { - c.JSON(http.StatusForbidden, api.ErrResp{ - Ok: false, - Code: api.DUPLICATE_ERROR, - Message: "Request with duplicate tracking id submitted.", - }) - return - } - - if _, ok := err.(validator.ValidationErrors); ok { - c.JSON(http.StatusForbidden, api.ErrResp{ - Ok: false, - Code: api.VALIDATION_ERROR, - Message: err.(validator.ValidationErrors).Error(), - }) - return - } - - // Log internal server error for further investigation. - lo.Error("api:", "path", c.Path(), "err", err) - - c.JSON(http.StatusInternalServerError, api.ErrResp{ - Ok: false, - Code: api.INTERNAL_ERROR, - Message: "Internal server error.", - }) + server.Validator = &api.Validator{ + ValidatorProvider: customValidator, } + server.HTTPErrorHandler = customHTTPErrorHandler + + server.Use(func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + c.Set("cu", custodialContainer) + return next(c) + } + }) + server.Use(middleware.Recover()) + server.Use(middleware.BodyLimit("1M")) + server.Use(middleware.ContextTimeout(time.Duration(contextTimeout * time.Second))) + if ko.Bool("service.metrics") { server.GET("/metrics", func(c echo.Context) error { metrics.WritePrometheus(c.Response(), true) @@ -56,17 +48,39 @@ func initApiServer(custodialContainer *custodial.Custodial) *echo.Echo { }) } - customValidator := validator.New() - customValidator.RegisterValidation("eth_checksum", api.EthChecksumValidator) - - server.Validator = &api.Validator{ - ValidatorProvider: customValidator, - } - apiRoute := server.Group("/api") - apiRoute.POST("/account/create", api.CreateAccountHandler(custodialContainer)) - apiRoute.POST("/sign/transfer", api.SignTransferHandler(custodialContainer)) - apiRoute.GET("/track/:trackingId", api.TxStatus(custodialContainer.PgStore)) + apiRoute.POST("/account/create", api.HandleAccountCreate) + apiRoute.POST("/sign/transfer", api.HandleSignTransfer) + apiRoute.GET("/track/:trackingId", api.HandleTrackTx) return server } + +func customHTTPErrorHandler(err error, c echo.Context) { + if c.Response().Committed { + return + } + + he, ok := err.(*echo.HTTPError) + if ok { + var errorMsg string + + if m, ok := he.Message.(error); ok { + errorMsg = m.Error() + } else if m, ok := he.Message.(string); ok { + errorMsg = m + } + + c.JSON(he.Code, api.ErrResp{ + Ok: false, + Message: errorMsg, + }) + } else { + lo.Error("api: echo error", "path", c.Path(), "err", err) + + c.JSON(http.StatusInternalServerError, api.ErrResp{ + Ok: false, + Message: "Internal server error.", + }) + } +} diff --git a/go.mod b/go.mod index fa09467..a0082f7 100644 --- a/go.mod +++ b/go.mod @@ -40,6 +40,7 @@ require ( github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-stack/stack v1.8.1 // indirect + github.com/golang-jwt/jwt v3.2.2+incompatible // indirect github.com/golang/protobuf v1.5.2 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/google/go-cmp v0.5.9 // indirect diff --git a/go.sum b/go.sum index 568e28d..d188f77 100644 --- a/go.sum +++ b/go.sum @@ -226,6 +226,8 @@ github.com/gofrs/uuid v3.3.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRx github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= +github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= github.com/golang/geo v0.0.0-20190916061304-5b978397cfec/go.mod h1:QZ0nwyI2jOfgRAoBvP+ab5aRr7c9x7lhGEJrKvBwjWI= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= diff --git a/internal/api/account.go b/internal/api/account.go index d72451b..c57f446 100644 --- a/internal/api/account.go +++ b/internal/api/account.go @@ -15,47 +15,49 @@ import ( // CreateAccountHandler route. // POST: /api/account/create // Returns the public key. -func CreateAccountHandler(cu *custodial.Custodial) func(echo.Context) error { - return func(c echo.Context) error { - trackingId := uuid.NewString() +func HandleAccountCreate(c echo.Context) error { + var ( + cu = c.Get("cu").(*custodial.Custodial) + ) - generatedKeyPair, err := keypair.Generate() - if err != nil { - return err - } - - id, err := cu.Keystore.WriteKeyPair(c.Request().Context(), generatedKeyPair) - if err != nil { - return err - } - - taskPayload, err := json.Marshal(task.AccountPayload{ - PublicKey: generatedKeyPair.Public, - TrackingId: trackingId, - }) - if err != nil { - return err - } - - _, err = cu.TaskerClient.CreateTask( - tasker.AccountPrepareTask, - tasker.DefaultPriority, - &tasker.Task{ - Id: trackingId, - Payload: taskPayload, - }, - ) - if err != nil { - return err - } - - return c.JSON(http.StatusOK, OkResp{ - Ok: true, - Result: H{ - "publicKey": generatedKeyPair.Public, - "custodialId": id, - "trackingId": trackingId, - }, - }) + generatedKeyPair, err := keypair.Generate() + if err != nil { + return err } + + id, err := cu.Keystore.WriteKeyPair(c.Request().Context(), generatedKeyPair) + if err != nil { + return err + } + + trackingId := uuid.NewString() + taskPayload, err := json.Marshal(task.AccountPayload{ + PublicKey: generatedKeyPair.Public, + TrackingId: trackingId, + }) + if err != nil { + return err + } + + _, err = cu.TaskerClient.CreateTask( + c.Request().Context(), + tasker.AccountPrepareTask, + tasker.DefaultPriority, + &tasker.Task{ + Id: trackingId, + Payload: taskPayload, + }, + ) + if err != nil { + return err + } + + return c.JSON(http.StatusOK, OkResp{ + Ok: true, + Result: H{ + "publicKey": generatedKeyPair.Public, + "custodialId": id, + "trackingId": trackingId, + }, + }) } diff --git a/internal/api/errors.go b/internal/api/errors.go new file mode 100644 index 0000000..001a776 --- /dev/null +++ b/internal/api/errors.go @@ -0,0 +1,11 @@ +package api + +import ( + "net/http" + + "github.com/labstack/echo/v4" +) + +func NewBadRequestError(message ...interface{}) *echo.HTTPError { + return echo.NewHTTPError(http.StatusBadRequest, message...) +} diff --git a/internal/api/sign.go b/internal/api/sign.go index 9b128bb..ddb41d5 100644 --- a/internal/api/sign.go +++ b/internal/api/sign.go @@ -12,7 +12,7 @@ import ( "github.com/labstack/echo/v4" ) -// SignTxHandler route. +// HandleSignTransfer route. // POST: /api/sign/transfer // JSON Body: // from -> ETH address @@ -21,54 +21,54 @@ import ( // amount -> int (6 d.p. precision) // e.g. 1000000 = 1 VOUCHER // Returns the task id. -func SignTransferHandler(cu *custodial.Custodial) func(echo.Context) error { - return func(c echo.Context) error { - trackingId := uuid.NewString() - - var transferRequest struct { +func HandleSignTransfer(c echo.Context) error { + var ( + cu = c.Get("cu").(*custodial.Custodial) + req struct { From string `json:"from" validate:"required,eth_checksum"` To string `json:"to" validate:"required,eth_checksum"` VoucherAddress string `json:"voucherAddress" validate:"required,eth_checksum"` - Amount uint64 `json:"amount" validate:"required,numeric"` + Amount uint64 `json:"amount" validate:"required"` } + ) - if err := c.Bind(&transferRequest); err != nil { - return err - } - - if err := c.Validate(transferRequest); err != nil { - return err - } - - // TODO: Checksum addresses - taskPayload, err := json.Marshal(task.TransferPayload{ - TrackingId: trackingId, - From: transferRequest.From, - To: transferRequest.To, - VoucherAddress: transferRequest.VoucherAddress, - Amount: transferRequest.Amount, - }) - if err != nil { - return err - } - - _, err = cu.TaskerClient.CreateTask( - tasker.SignTransferTask, - tasker.HighPriority, - &tasker.Task{ - Id: trackingId, - Payload: taskPayload, - }, - ) - if err != nil { - return err - } - - return c.JSON(http.StatusOK, OkResp{ - Ok: true, - Result: H{ - "trackingId": trackingId, - }, - }) + if err := c.Bind(&req); err != nil { + return NewBadRequestError(err) } + + if err := c.Validate(req); err != nil { + return err + } + + trackingId := uuid.NewString() + taskPayload, err := json.Marshal(task.TransferPayload{ + TrackingId: trackingId, + From: req.From, + To: req.To, + VoucherAddress: req.VoucherAddress, + Amount: req.Amount, + }) + if err != nil { + return err + } + + _, err = cu.TaskerClient.CreateTask( + c.Request().Context(), + tasker.SignTransferTask, + tasker.HighPriority, + &tasker.Task{ + Id: trackingId, + Payload: taskPayload, + }, + ) + if err != nil { + return err + } + + return c.JSON(http.StatusOK, OkResp{ + Ok: true, + Result: H{ + "trackingId": trackingId, + }, + }) } diff --git a/internal/api/track.go b/internal/api/track.go index 398f938..26f8adb 100644 --- a/internal/api/track.go +++ b/internal/api/track.go @@ -3,35 +3,40 @@ package api import ( "net/http" - "github.com/grassrootseconomics/cic-custodial/internal/store" + "github.com/grassrootseconomics/cic-custodial/internal/custodial" "github.com/labstack/echo/v4" ) -func TxStatus(store store.Store) func(echo.Context) error { - return func(c echo.Context) error { - var txStatusRequest struct { +// HandleTxStatus route. +// GET: /api/track/:trackingId +// Route param: +// trackingId -> tracking UUID +// Returns array of tx status. +func HandleTrackTx(c echo.Context) error { + var ( + cu = c.Get("cu").(*custodial.Custodial) + txStatusRequest struct { TrackingId string `param:"trackingId" validate:"required,uuid"` } + ) - if err := c.Bind(&txStatusRequest); err != nil { - return err - } - - if err := c.Validate(txStatusRequest); err != nil { - return err - } - - // TODO: handle potential timeouts - txs, err := store.GetTxStatusByTrackingId(c.Request().Context(), txStatusRequest.TrackingId) - if err != nil { - return err - } - - return c.JSON(http.StatusOK, OkResp{ - Ok: true, - Result: H{ - "transactions": txs, - }, - }) + if err := c.Bind(&txStatusRequest); err != nil { + return NewBadRequestError(err) } + + if err := c.Validate(txStatusRequest); err != nil { + return err + } + + txs, err := cu.PgStore.GetTxStatusByTrackingId(c.Request().Context(), txStatusRequest.TrackingId) + if err != nil { + return err + } + + return c.JSON(http.StatusOK, OkResp{ + Ok: true, + Result: H{ + "transactions": txs, + }, + }) } diff --git a/internal/api/types.go b/internal/api/types.go index cc236d5..6d59b1a 100644 --- a/internal/api/types.go +++ b/internal/api/types.go @@ -1,11 +1,5 @@ package api -const ( - INTERNAL_ERROR = "ERR_INTERNAL" - VALIDATION_ERROR = "ERR_VALIDATE" - DUPLICATE_ERROR = "ERR_DUPLICATE" -) - type H map[string]any type OkResp struct { @@ -15,6 +9,5 @@ type OkResp struct { type ErrResp struct { Ok bool `json:"ok"` - Code string `json:"errorCode"` Message string `json:"message"` } diff --git a/internal/api/validator.go b/internal/api/validator.go index 6e5edfb..713baf4 100644 --- a/internal/api/validator.go +++ b/internal/api/validator.go @@ -11,7 +11,9 @@ type Validator struct { func (v *Validator) Validate(i interface{}) error { if err := v.ValidatorProvider.Struct(i); err != nil { - return err + if _, ok := err.(validator.ValidationErrors); ok { + return NewBadRequestError(err.(validator.ValidationErrors).Error()) + } } return nil } diff --git a/internal/tasker/client.go b/internal/tasker/client.go index 6a44236..8510bf3 100644 --- a/internal/tasker/client.go +++ b/internal/tasker/client.go @@ -1,6 +1,7 @@ package tasker import ( + "context" "time" "github.com/google/uuid" @@ -28,7 +29,7 @@ func NewTaskerClient(o TaskerClientOpts) *TaskerClient { } } -func (c *TaskerClient) CreateTask(taskName TaskName, queueName QueueName, task *Task) (*asynq.TaskInfo, error) { +func (c *TaskerClient) CreateTask(ctx context.Context, taskName TaskName, queueName QueueName, task *Task) (*asynq.TaskInfo, error) { if task.Id == "" { task.Id = uuid.NewString() } @@ -42,7 +43,7 @@ func (c *TaskerClient) CreateTask(taskName TaskName, queueName QueueName, task * asynq.Timeout(taskTimeout*time.Second), ) - taskInfo, err := c.Client.Enqueue(qTask) + taskInfo, err := c.Client.EnqueueContext(ctx, qTask) if err != nil { return nil, err } diff --git a/internal/tasker/task/account_gift_gas.go b/internal/tasker/task/account_gift_gas.go index 6a1350a..1fc7239 100644 --- a/internal/tasker/task/account_gift_gas.go +++ b/internal/tasker/task/account_gift_gas.go @@ -93,6 +93,7 @@ func AccountGiftGasProcessor(cu *custodial.Custodial) func(context.Context, *asy } _, err = cu.TaskerClient.CreateTask( + ctx, tasker.DispatchTxTask, tasker.HighPriority, &tasker.Task{ diff --git a/internal/tasker/task/account_gift_voucher.go b/internal/tasker/task/account_gift_voucher.go index e718f95..cca5a01 100644 --- a/internal/tasker/task/account_gift_voucher.go +++ b/internal/tasker/task/account_gift_voucher.go @@ -103,6 +103,7 @@ func GiftVoucherProcessor(cu *custodial.Custodial) func(context.Context, *asynq. } _, err = cu.TaskerClient.CreateTask( + ctx, tasker.DispatchTxTask, tasker.HighPriority, &tasker.Task{ diff --git a/internal/tasker/task/account_prepare.go b/internal/tasker/task/account_prepare.go index 728cf56..e3b90db 100644 --- a/internal/tasker/task/account_prepare.go +++ b/internal/tasker/task/account_prepare.go @@ -29,6 +29,7 @@ func AccountPrepare(cu *custodial.Custodial) func(context.Context, *asynq.Task) } _, err := cu.TaskerClient.CreateTask( + ctx, tasker.AccountRegisterTask, tasker.DefaultPriority, &tasker.Task{ @@ -40,6 +41,7 @@ func AccountPrepare(cu *custodial.Custodial) func(context.Context, *asynq.Task) } _, err = cu.TaskerClient.CreateTask( + ctx, tasker.AccountGiftGasTask, tasker.DefaultPriority, &tasker.Task{ @@ -51,6 +53,7 @@ func AccountPrepare(cu *custodial.Custodial) func(context.Context, *asynq.Task) } _, err = cu.TaskerClient.CreateTask( + ctx, tasker.AccountGiftVoucherTask, tasker.DefaultPriority, &tasker.Task{ diff --git a/internal/tasker/task/account_refill_gas.go b/internal/tasker/task/account_refill_gas.go index 9bd6caf..5276d9a 100644 --- a/internal/tasker/task/account_refill_gas.go +++ b/internal/tasker/task/account_refill_gas.go @@ -108,6 +108,7 @@ func AccountRefillGasProcessor(cu *custodial.Custodial) func(context.Context, *a } _, err = cu.TaskerClient.CreateTask( + ctx, tasker.DispatchTxTask, tasker.HighPriority, &tasker.Task{ diff --git a/internal/tasker/task/account_register_onchain.go b/internal/tasker/task/account_register_onchain.go index 206d34f..883bf25 100644 --- a/internal/tasker/task/account_register_onchain.go +++ b/internal/tasker/task/account_register_onchain.go @@ -101,6 +101,7 @@ func AccountRegisterOnChainProcessor(cu *custodial.Custodial) func(context.Conte } _, err = cu.TaskerClient.CreateTask( + ctx, tasker.DispatchTxTask, tasker.HighPriority, &tasker.Task{ diff --git a/internal/tasker/task/sign_transfer.go b/internal/tasker/task/sign_transfer.go index efe104e..b90488e 100644 --- a/internal/tasker/task/sign_transfer.go +++ b/internal/tasker/task/sign_transfer.go @@ -123,6 +123,7 @@ func SignTransfer(cu *custodial.Custodial) func(context.Context, *asynq.Task) er } _, err = cu.TaskerClient.CreateTask( + ctx, tasker.DispatchTxTask, tasker.HighPriority, &tasker.Task{ @@ -141,6 +142,7 @@ func SignTransfer(cu *custodial.Custodial) func(context.Context, *asynq.Task) er } _, err = cu.TaskerClient.CreateTask( + ctx, tasker.AccountRefillGasTask, tasker.DefaultPriority, &tasker.Task{