wip-flag-migration #28

Merged
Alfred-mk merged 44 commits from wip-flag-migration into master 2024-09-04 11:25:34 +02:00
2 changed files with 236 additions and 77 deletions
Showing only changes of commit 31be1fa221 - Show all commits

View File

@ -15,7 +15,6 @@ import (
"git.defalsify.org/vise.git/resource" "git.defalsify.org/vise.git/resource"
"git.defalsify.org/vise.git/state" "git.defalsify.org/vise.git/state"
"git.grassecon.net/urdt/ussd/internal/handlers/server" "git.grassecon.net/urdt/ussd/internal/handlers/server"
"git.grassecon.net/urdt/ussd/internal/models"
"git.grassecon.net/urdt/ussd/internal/utils" "git.grassecon.net/urdt/ussd/internal/utils"
"gopkg.in/leonelquinteros/gotext.v1" "gopkg.in/leonelquinteros/gotext.v1"
) )
@ -30,9 +29,13 @@ type FSData struct {
St *state.State St *state.State
} }
type FlagParserInterface interface {
GetFlag(key string) (uint32, error)
}
type Handlers struct { type Handlers struct {
fs *FSData fs *FSData
parser *asm.FlagParser parser FlagParserInterface
accountFileHandler utils.AccountFileHandlerInterface accountFileHandler utils.AccountFileHandlerInterface
accountService server.AccountServiceInterface accountService server.AccountServiceInterface
} }
@ -78,8 +81,16 @@ func (h *Handlers) PreloadFlags(flagKeys []string) (map[string]uint32, error) {
// SetLanguage sets the language across the menu // SetLanguage sets the language across the menu
func (h *Handlers) SetLanguage(ctx context.Context, sym string, input []byte) (resource.Result, error) { func (h *Handlers) SetLanguage(ctx context.Context, sym string, input []byte) (resource.Result, error) {
inputStr := string(input)
res := resource.Result{} res := resource.Result{}
// Preload the required flag
flagKeys := []string{"flag_language_set"}
flags, err := h.PreloadFlags(flagKeys)
if err != nil {
return res, err
}
inputStr := string(input)
switch inputStr { switch inputStr {
case "0": case "0":
res.FlagSet = []uint32{state.FLAG_LANG} res.FlagSet = []uint32{state.FLAG_LANG}
@ -90,7 +101,7 @@ func (h *Handlers) SetLanguage(ctx context.Context, sym string, input []byte) (r
default: default:
Alfred-mk marked this conversation as resolved Outdated
Outdated
Review

It shouldn't be necessary to process the flags twice.

It shouldn't be necessary to process the flags twice.
} }
res.FlagSet = append(res.FlagSet, models.USERFLAG_LANGUAGE_SET) res.FlagSet = append(res.FlagSet, flags["flag_language_set"])
return res, nil return res, nil
} }
@ -101,7 +112,14 @@ func (h *Handlers) SetLanguage(ctx context.Context, sym string, input []byte) (r
func (h *Handlers) CreateAccount(ctx context.Context, sym string, input []byte) (resource.Result, error) { func (h *Handlers) CreateAccount(ctx context.Context, sym string, input []byte) (resource.Result, error) {
res := resource.Result{} res := resource.Result{}
err := h.accountFileHandler.EnsureFileExists() // Preload the required flags
flagKeys := []string{"flag_account_created", "flag_account_creation_failed"}
flags, err := h.PreloadFlags(flagKeys)
if err != nil {
return res, err
}
err = h.accountFileHandler.EnsureFileExists()
if err != nil { if err != nil {
return res, err return res, err
} }
@ -114,7 +132,7 @@ func (h *Handlers) CreateAccount(ctx context.Context, sym string, input []byte)
accountResp, err := h.accountService.CreateAccount() accountResp, err := h.accountService.CreateAccount()
if err != nil { if err != nil {
res.FlagSet = append(res.FlagSet, models.USERFLAG_ACCOUNT_CREATION_FAILED) res.FlagSet = append(res.FlagSet, flags["flag_account_creation_failed"])
return res, err return res, err
} }
@ -135,13 +153,21 @@ func (h *Handlers) CreateAccount(ctx context.Context, sym string, input []byte)
return res, err return res, err
} }
res.FlagSet = append(res.FlagSet, models.USERFLAG_ACCOUNT_CREATED) res.FlagSet = append(res.FlagSet, flags["flag_account_created"])
return res, err return res, err
} }
// SavePin persists the user's PIN choice into the filesystem // SavePin persists the user's PIN choice into the filesystem
func (h *Handlers) SavePin(ctx context.Context, sym string, input []byte) (resource.Result, error) { func (h *Handlers) SavePin(ctx context.Context, sym string, input []byte) (resource.Result, error) {
res := resource.Result{} res := resource.Result{}
// Preload the required flags
flagKeys := []string{"flag_incorrect_pin"}
flags, err := h.PreloadFlags(flagKeys)
if err != nil {
return res, err
}
accountPIN := string(input) accountPIN := string(input)
accountData, err := h.accountFileHandler.ReadAccountData() accountData, err := h.accountFileHandler.ReadAccountData()
@ -151,11 +177,11 @@ func (h *Handlers) SavePin(ctx context.Context, sym string, input []byte) (resou
// Validate that the PIN is a 4-digit number // Validate that the PIN is a 4-digit number
if !isValidPIN(accountPIN) { if !isValidPIN(accountPIN) {
res.FlagSet = append(res.FlagSet, models.USERFLAG_INCORRECTPIN) res.FlagSet = append(res.FlagSet, flags["flag_incorrect_pin"])
return res, nil return res, nil
} }
res.FlagReset = append(res.FlagReset, models.USERFLAG_INCORRECTPIN) res.FlagReset = append(res.FlagReset, flags["flag_incorrect_pin"])
accountData["AccountPIN"] = accountPIN accountData["AccountPIN"] = accountPIN
err = h.accountFileHandler.WriteAccountData(accountData) err = h.accountFileHandler.WriteAccountData(accountData)
@ -170,18 +196,26 @@ func (h *Handlers) SavePin(ctx context.Context, sym string, input []byte) (resou
func (h *Handlers) SetResetSingleEdit(ctx context.Context, sym string, input []byte) (resource.Result, error) { func (h *Handlers) SetResetSingleEdit(ctx context.Context, sym string, input []byte) (resource.Result, error) {
res := resource.Result{} res := resource.Result{}
menuOption := string(input) menuOption := string(input)
// Preload the required flags
flagKeys := []string{"flag_allow_update", "flag_single_edit"}
flags, err := h.PreloadFlags(flagKeys)
if err != nil {
return res, err
}
switch menuOption { switch menuOption {
case "2": case "2":
res.FlagReset = append(res.FlagSet, models.USERFLAG_ALLOW_UPDATE) res.FlagReset = append(res.FlagSet, flags["flag_allow_update"])
res.FlagSet = append(res.FlagSet, models.USERFLAG_SINGLE_EDIT) res.FlagSet = append(res.FlagSet, flags["flag_single_edit"])
case "3": case "3":
res.FlagReset = append(res.FlagSet, models.USERFLAG_ALLOW_UPDATE) res.FlagReset = append(res.FlagSet, flags["flag_allow_update"])
res.FlagSet = append(res.FlagSet, models.USERFLAG_SINGLE_EDIT) res.FlagSet = append(res.FlagSet, flags["flag_single_edit"])
case "4": case "4":
res.FlagReset = append(res.FlagSet, models.USERFLAG_ALLOW_UPDATE) res.FlagReset = append(res.FlagSet, flags["flag_allow_update"])
res.FlagSet = append(res.FlagSet, models.USERFLAG_SINGLE_EDIT) res.FlagSet = append(res.FlagSet, flags["flag_single_edit"])
default: default:
res.FlagReset = append(res.FlagReset, models.USERFLAG_SINGLE_EDIT) res.FlagReset = append(res.FlagReset, flags["flag_single_edit"])
} }
return res, nil return res, nil
@ -193,17 +227,24 @@ func (h *Handlers) SetResetSingleEdit(ctx context.Context, sym string, input []b
func (h *Handlers) VerifyPin(ctx context.Context, sym string, input []byte) (resource.Result, error) { func (h *Handlers) VerifyPin(ctx context.Context, sym string, input []byte) (resource.Result, error) {
res := resource.Result{} res := resource.Result{}
// Preload the required flags
flagKeys := []string{"flag_valid_pin", "flag_pin_mismatch", "flag_pin_set"}
flags, err := h.PreloadFlags(flagKeys)
if err != nil {
return res, err
}
accountData, err := h.accountFileHandler.ReadAccountData() accountData, err := h.accountFileHandler.ReadAccountData()
if err != nil { if err != nil {
return res, err return res, err
} }
if bytes.Equal(input, []byte(accountData["AccountPIN"])) { if bytes.Equal(input, []byte(accountData["AccountPIN"])) {
res.FlagSet = []uint32{models.USERFLAG_VALIDPIN} res.FlagSet = []uint32{flags["flag_valid_pin"]}
res.FlagReset = []uint32{models.USERFLAG_PINMISMATCH} res.FlagReset = []uint32{flags["flag_pin_mismatch"]}
res.FlagSet = append(res.FlagSet, models.USERFLAG_PIN_SET) res.FlagSet = append(res.FlagSet, flags["flag_pin_set"])
} else { } else {
res.FlagSet = []uint32{models.USERFLAG_PINMISMATCH} res.FlagSet = []uint32{flags["flag_pin_mismatch"]}
} }
return res, nil return res, nil
@ -361,14 +402,30 @@ func (h *Handlers) SaveOfferings(ctx context.Context, sym string, input []byte)
// ResetAllowUpdate resets the allowupdate flag that allows a user to update profile data. // ResetAllowUpdate resets the allowupdate flag that allows a user to update profile data.
func (h *Handlers) ResetAllowUpdate(ctx context.Context, sym string, input []byte) (resource.Result, error) { func (h *Handlers) ResetAllowUpdate(ctx context.Context, sym string, input []byte) (resource.Result, error) {
res := resource.Result{} res := resource.Result{}
res.FlagReset = append(res.FlagReset, models.USERFLAG_ALLOW_UPDATE)
// Preload the required flag
flagKeys := []string{"flag_allow_update"}
flags, err := h.PreloadFlags(flagKeys)
if err != nil {
return res, err
}
res.FlagReset = append(res.FlagReset, flags["flag_allow_update"])
return res, nil return res, nil
} }
// ResetAccountAuthorized resets the account authorization flag after a successful PIN entry. // ResetAccountAuthorized resets the account authorization flag after a successful PIN entry.
func (h *Handlers) ResetAccountAuthorized(ctx context.Context, sym string, input []byte) (resource.Result, error) { func (h *Handlers) ResetAccountAuthorized(ctx context.Context, sym string, input []byte) (resource.Result, error) {
res := resource.Result{} res := resource.Result{}
res.FlagReset = append(res.FlagReset, models.USERFLAG_ACCOUNT_AUTHORIZED)
// Preload the required flags
flagKeys := []string{"flag_account_authorized"}
flags, err := h.PreloadFlags(flagKeys)
if err != nil {
return res, err
}
res.FlagReset = append(res.FlagReset, flags["flag_account_authorized"])
return res, nil return res, nil
} }
@ -390,12 +447,6 @@ func (h *Handlers) CheckIdentifier(ctx context.Context, sym string, input []byte
// It sets the required flags that control the flow. // It sets the required flags that control the flow.
func (h *Handlers) Authorize(ctx context.Context, sym string, input []byte) (resource.Result, error) { func (h *Handlers) Authorize(ctx context.Context, sym string, input []byte) (resource.Result, error) {
res := resource.Result{} res := resource.Result{}
pin := string(input)
accountData, err := h.accountFileHandler.ReadAccountData()
if err != nil {
return res, err
}
// Preload the required flags // Preload the required flags
flagKeys := []string{"flag_incorrect_pin", "flag_account_authorized", "flag_allow_update"} flagKeys := []string{"flag_incorrect_pin", "flag_account_authorized", "flag_allow_update"}
@ -404,6 +455,13 @@ func (h *Handlers) Authorize(ctx context.Context, sym string, input []byte) (res
return res, err return res, err
} }
pin := string(input)
accountData, err := h.accountFileHandler.ReadAccountData()
if err != nil {
return res, err
}
if len(input) == 4 { if len(input) == 4 {
if pin != accountData["AccountPIN"] { if pin != accountData["AccountPIN"] {
res.FlagSet = append(res.FlagSet, flags["flag_incorrect_pin"]) res.FlagSet = append(res.FlagSet, flags["flag_incorrect_pin"])
@ -424,7 +482,15 @@ func (h *Handlers) Authorize(ctx context.Context, sym string, input []byte) (res
// ResetIncorrectPin resets the incorrect pin flag after a new PIN attempt. // ResetIncorrectPin resets the incorrect pin flag after a new PIN attempt.
func (h *Handlers) ResetIncorrectPin(ctx context.Context, sym string, input []byte) (resource.Result, error) { func (h *Handlers) ResetIncorrectPin(ctx context.Context, sym string, input []byte) (resource.Result, error) {
res := resource.Result{} res := resource.Result{}
res.FlagReset = append(res.FlagReset, models.USERFLAG_INCORRECTPIN)
// Preload the required flag
flagKeys := []string{"flag_incorrect_pin"}
flags, err := h.PreloadFlags(flagKeys)
if err != nil {
return res, err
}
res.FlagReset = append(res.FlagReset, flags["flag_incorrect_pin"])
return res, nil return res, nil
} }
@ -433,6 +499,13 @@ func (h *Handlers) ResetIncorrectPin(ctx context.Context, sym string, input []by
func (h *Handlers) CheckAccountStatus(ctx context.Context, sym string, input []byte) (resource.Result, error) { func (h *Handlers) CheckAccountStatus(ctx context.Context, sym string, input []byte) (resource.Result, error) {
res := resource.Result{} res := resource.Result{}
// Preload the required flags
flagKeys := []string{"flag_account_success", "flag_account_pending"}
flags, err := h.PreloadFlags(flagKeys)
if err != nil {
return res, err
}
accountData, err := h.accountFileHandler.ReadAccountData() accountData, err := h.accountFileHandler.ReadAccountData()
if err != nil { if err != nil {
return res, err return res, err
@ -448,11 +521,11 @@ func (h *Handlers) CheckAccountStatus(ctx context.Context, sym string, input []b
accountData["Status"] = status accountData["Status"] = status
if status == "SUCCESS" { if status == "SUCCESS" {
res.FlagSet = append(res.FlagSet, models.USERFLAG_ACCOUNT_SUCCESS) res.FlagSet = append(res.FlagSet, flags["flag_account_success"])
res.FlagReset = append(res.FlagReset, models.USERFLAG_ACCOUNT_PENDING) res.FlagReset = append(res.FlagReset, flags["flag_account_pending"])
} else { } else {
res.FlagReset = append(res.FlagSet, models.USERFLAG_ACCOUNT_SUCCESS) res.FlagReset = append(res.FlagSet, flags["flag_account_success"])
res.FlagSet = append(res.FlagReset, models.USERFLAG_ACCOUNT_PENDING) res.FlagSet = append(res.FlagReset, flags["flag_account_pending"])
} }
err = h.accountFileHandler.WriteAccountData(accountData) err = h.accountFileHandler.WriteAccountData(accountData)
@ -467,30 +540,45 @@ func (h *Handlers) CheckAccountStatus(ctx context.Context, sym string, input []b
func (h *Handlers) Quit(ctx context.Context, sym string, input []byte) (resource.Result, error) { func (h *Handlers) Quit(ctx context.Context, sym string, input []byte) (resource.Result, error) {
res := resource.Result{} res := resource.Result{}
// Preload the required flags
flagKeys := []string{"flag_account_authorized"}
flags, err := h.PreloadFlags(flagKeys)
if err != nil {
return res, err
}
code := codeFromCtx(ctx) code := codeFromCtx(ctx)
l := gotext.NewLocale(translationDir, code) l := gotext.NewLocale(translationDir, code)
l.AddDomain("default") l.AddDomain("default")
res.Content = l.Get("Thank you for using Sarafu. Goodbye!") res.Content = l.Get("Thank you for using Sarafu. Goodbye!")
res.FlagReset = append(res.FlagReset, models.USERFLAG_ACCOUNT_AUTHORIZED) res.FlagReset = append(res.FlagReset, flags["flag_account_authorized"])
return res, nil return res, nil
} }
// VerifyYob verifies the length of the given input // VerifyYob verifies the length of the given input
func (h *Handlers) VerifyYob(ctx context.Context, sym string, input []byte) (resource.Result, error) { func (h *Handlers) VerifyYob(ctx context.Context, sym string, input []byte) (resource.Result, error) {
res := resource.Result{} res := resource.Result{}
// Preload the required flag
flagKeys := []string{"flag_incorrect_date_format"}
flags, err := h.PreloadFlags(flagKeys)
if err != nil {
return res, err
}
date := string(input) date := string(input)
_, err := strconv.Atoi(date) _, err = strconv.Atoi(date)
if err != nil { if err != nil {
// If conversion fails, input is not numeric // If conversion fails, input is not numeric
res.FlagSet = append(res.FlagSet, models.USERFLAG_INCORRECTDATEFORMAT) res.FlagSet = append(res.FlagSet, flags["flag_incorrect_date_format"])
return res, nil return res, nil
} }
if len(date) == 4 { if len(date) == 4 {
res.FlagReset = append(res.FlagReset, models.USERFLAG_INCORRECTDATEFORMAT) res.FlagReset = append(res.FlagReset, flags["flag_incorrect_date_format"])
} else { } else {
res.FlagSet = append(res.FlagSet, models.USERFLAG_INCORRECTDATEFORMAT) res.FlagSet = append(res.FlagSet, flags["flag_incorrect_date_format"])
} }
return res, nil return res, nil
@ -499,7 +587,15 @@ func (h *Handlers) VerifyYob(ctx context.Context, sym string, input []byte) (res
// ResetIncorrectYob resets the incorrect date format after a new attempt // ResetIncorrectYob resets the incorrect date format after a new attempt
func (h *Handlers) ResetIncorrectYob(ctx context.Context, sym string, input []byte) (resource.Result, error) { func (h *Handlers) ResetIncorrectYob(ctx context.Context, sym string, input []byte) (resource.Result, error) {
res := resource.Result{} res := resource.Result{}
res.FlagReset = append(res.FlagReset, models.USERFLAG_INCORRECTDATEFORMAT)
// Preload the required flags
flagKeys := []string{"flag_incorrect_date_format"}
flags, err := h.PreloadFlags(flagKeys)
if err != nil {
return res, err
}
res.FlagReset = append(res.FlagReset, flags["flag_incorrect_date_format"])
return res, nil return res, nil
} }
@ -525,6 +621,14 @@ func (h *Handlers) CheckBalance(ctx context.Context, sym string, input []byte) (
// ValidateRecipient validates that the given input is a valid phone number. // ValidateRecipient validates that the given input is a valid phone number.
func (h *Handlers) ValidateRecipient(ctx context.Context, sym string, input []byte) (resource.Result, error) { func (h *Handlers) ValidateRecipient(ctx context.Context, sym string, input []byte) (resource.Result, error) {
res := resource.Result{} res := resource.Result{}
// Preload the required flags
flagKeys := []string{"flag_invalid_recipient"}
flags, err := h.PreloadFlags(flagKeys)
if err != nil {
return res, err
}
carlos marked this conversation as resolved Outdated
Outdated
Review

we cannot have panics anywhere in the vm execution! Just return error.

we cannot have panics anywhere in the vm execution! Just return error.
recipient := string(input) recipient := string(input)
accountData, err := h.accountFileHandler.ReadAccountData() accountData, err := h.accountFileHandler.ReadAccountData()
@ -535,7 +639,7 @@ func (h *Handlers) ValidateRecipient(ctx context.Context, sym string, input []by
if recipient != "0" { if recipient != "0" {
// mimic invalid number check // mimic invalid number check
if recipient == "000" { if recipient == "000" {
res.FlagSet = append(res.FlagSet, models.USERFLAG_INVALID_RECIPIENT) res.FlagSet = append(res.FlagSet, flags["flag_invalid_recipient"])
res.Content = recipient res.Content = recipient
return res, nil return res, nil
@ -556,6 +660,14 @@ func (h *Handlers) ValidateRecipient(ctx context.Context, sym string, input []by
// as well as the invalid flags // as well as the invalid flags
func (h *Handlers) TransactionReset(ctx context.Context, sym string, input []byte) (resource.Result, error) { func (h *Handlers) TransactionReset(ctx context.Context, sym string, input []byte) (resource.Result, error) {
res := resource.Result{} res := resource.Result{}
// Preload the required flags
flagKeys := []string{"flag_invalid_recipient", "flag_invalid_recipient_with_invite"}
flags, err := h.PreloadFlags(flagKeys)
if err != nil {
return res, err
}
accountData, err := h.accountFileHandler.ReadAccountData() accountData, err := h.accountFileHandler.ReadAccountData()
if err != nil { if err != nil {
return res, err return res, err
@ -570,7 +682,7 @@ func (h *Handlers) TransactionReset(ctx context.Context, sym string, input []byt
return res, err return res, err
} }
res.FlagReset = append(res.FlagReset, models.USERFLAG_INVALID_RECIPIENT, models.USERFLAG_INVALID_RECIPIENT_WITH_INVITE) res.FlagReset = append(res.FlagReset, flags["flag_invalid_recipient"], flags["flag_invalid_recipient_with_invite"])
return res, nil return res, nil
} }
@ -578,6 +690,14 @@ func (h *Handlers) TransactionReset(ctx context.Context, sym string, input []byt
// ResetTransactionAmount resets the transaction amount and invalid flag // ResetTransactionAmount resets the transaction amount and invalid flag
func (h *Handlers) ResetTransactionAmount(ctx context.Context, sym string, input []byte) (resource.Result, error) { func (h *Handlers) ResetTransactionAmount(ctx context.Context, sym string, input []byte) (resource.Result, error) {
res := resource.Result{} res := resource.Result{}
// Preload the required flag
flagKeys := []string{"flag_invalid_amount"}
flags, err := h.PreloadFlags(flagKeys)
if err != nil {
return res, err
}
accountData, err := h.accountFileHandler.ReadAccountData() accountData, err := h.accountFileHandler.ReadAccountData()
if err != nil { if err != nil {
return res, err return res, err
@ -591,7 +711,7 @@ func (h *Handlers) ResetTransactionAmount(ctx context.Context, sym string, input
return res, err return res, err
} }
res.FlagReset = append(res.FlagReset, models.USERFLAG_INVALID_AMOUNT) res.FlagReset = append(res.FlagReset, flags["flag_invalid_amount"])
return res, nil return res, nil
} }
@ -620,6 +740,14 @@ func (h *Handlers) MaxAmount(ctx context.Context, sym string, input []byte) (res
// it is not more than the current balance. // it is not more than the current balance.
func (h *Handlers) ValidateAmount(ctx context.Context, sym string, input []byte) (resource.Result, error) { func (h *Handlers) ValidateAmount(ctx context.Context, sym string, input []byte) (resource.Result, error) {
res := resource.Result{} res := resource.Result{}
// Preload the required flag
flagKeys := []string{"flag_invalid_amount"}
flags, err := h.PreloadFlags(flagKeys)
if err != nil {
return res, err
}
amountStr := string(input) amountStr := string(input)
accountData, err := h.accountFileHandler.ReadAccountData() accountData, err := h.accountFileHandler.ReadAccountData()
@ -647,20 +775,20 @@ func (h *Handlers) ValidateAmount(ctx context.Context, sym string, input []byte)
re := regexp.MustCompile(`^(\d+(\.\d+)?)\s*(?:CELO)?$`) re := regexp.MustCompile(`^(\d+(\.\d+)?)\s*(?:CELO)?$`)
matches := re.FindStringSubmatch(strings.TrimSpace(amountStr)) matches := re.FindStringSubmatch(strings.TrimSpace(amountStr))
if len(matches) < 2 { if len(matches) < 2 {
res.FlagSet = append(res.FlagSet, models.USERFLAG_INVALID_AMOUNT) res.FlagSet = append(res.FlagSet, flags["flag_invalid_amount"])
res.Content = amountStr res.Content = amountStr
return res, nil return res, nil
} }
inputAmount, err := strconv.ParseFloat(matches[1], 64) inputAmount, err := strconv.ParseFloat(matches[1], 64)
if err != nil { if err != nil {
res.FlagSet = append(res.FlagSet, models.USERFLAG_INVALID_AMOUNT) res.FlagSet = append(res.FlagSet, flags["flag_invalid_amount"])
res.Content = amountStr res.Content = amountStr
return res, nil return res, nil
} }
if inputAmount > balanceValue { if inputAmount > balanceValue {
res.FlagSet = append(res.FlagSet, models.USERFLAG_INVALID_AMOUNT) res.FlagSet = append(res.FlagSet, flags["flag_invalid_amount"])
res.Content = amountStr res.Content = amountStr
return res, nil return res, nil
} }
@ -755,6 +883,14 @@ func (h *Handlers) GetAmount(ctx context.Context, sym string, input []byte) (res
// gracefully exiting the session. // gracefully exiting the session.
func (h *Handlers) QuitWithBalance(ctx context.Context, sym string, input []byte) (resource.Result, error) { func (h *Handlers) QuitWithBalance(ctx context.Context, sym string, input []byte) (resource.Result, error) {
res := resource.Result{} res := resource.Result{}
// Preload the required flag
flagKeys := []string{"flag_account_authorized"}
flags, err := h.PreloadFlags(flagKeys)
if err != nil {
return res, err
}
code := codeFromCtx(ctx) code := codeFromCtx(ctx)
l := gotext.NewLocale(translationDir, code) l := gotext.NewLocale(translationDir, code)
l.AddDomain("default") l.AddDomain("default")
@ -767,7 +903,7 @@ func (h *Handlers) QuitWithBalance(ctx context.Context, sym string, input []byte
return res, nil return res, nil
} }
res.Content = l.Get("Your account balance is %s", balance) res.Content = l.Get("Your account balance is %s", balance)
res.FlagReset = append(res.FlagReset, models.USERFLAG_ACCOUNT_AUTHORIZED) res.FlagReset = append(res.FlagReset, flags["flag_account_authorized"])
return res, nil return res, nil
} }

View File

@ -21,6 +21,15 @@ type MockAccountService struct {
mock.Mock mock.Mock
} }
type MockFlagParser struct {
mock.Mock
}
func (m *MockFlagParser) GetFlag(key string) (uint32, error) {
args := m.Called(key)
return args.Get(0).(uint32), args.Error(1)
}
func (m *MockAccountService) CreateAccount() (*models.AccountResponse, error) { func (m *MockAccountService) CreateAccount() (*models.AccountResponse, error) {
args := m.Called() args := m.Called()
return args.Get(0).(*models.AccountResponse), args.Error(1) return args.Get(0).(*models.AccountResponse), args.Error(1)
@ -70,11 +79,20 @@ func TestCreateAccount(t *testing.T) {
// Set up expectations for the mock account service // Set up expectations for the mock account service
mockAccountService.On("CreateAccount").Return(mockAccountResponse, nil) mockAccountService.On("CreateAccount").Return(mockAccountResponse, nil)
mockParser := new(MockFlagParser)
flag_account_created := uint32(1)
flag_account_creation_failed := uint32(2)
mockParser.On("GetFlag", "flag_account_created").Return(flag_account_created, nil)
mockParser.On("GetFlag", "flag_account_creation_failed").Return(flag_account_creation_failed, nil)
// Initialize Handlers with mock account service // Initialize Handlers with mock account service
h := &Handlers{ h := &Handlers{
fs: &FSData{Path: accountFilePath}, fs: &FSData{Path: accountFilePath},
accountFileHandler: accountFileHandler, accountFileHandler: accountFileHandler,
accountService: mockAccountService, accountService: mockAccountService,
parser: mockParser,
} }
tests := []struct { tests := []struct {
@ -87,7 +105,7 @@ func TestCreateAccount(t *testing.T) {
name: "New account creation", name: "New account creation",
existingData: nil, existingData: nil,
expectedResult: resource.Result{ expectedResult: resource.Result{
FlagSet: []uint32{models.USERFLAG_ACCOUNT_CREATED}, FlagSet: []uint32{flag_account_created},
}, },
expectedData: map[string]string{ expectedData: map[string]string{
"TrackingId": "test-tracking-id", "TrackingId": "test-tracking-id",
@ -248,10 +266,16 @@ func TestSavePin(t *testing.T) {
// Create a new AccountFileHandler and set it in the Handlers struct // Create a new AccountFileHandler and set it in the Handlers struct
accountFileHandler := utils.NewAccountFileHandler(accountFilePath) accountFileHandler := utils.NewAccountFileHandler(accountFilePath)
mockParser := new(MockFlagParser)
h := &Handlers{ h := &Handlers{
accountFileHandler: accountFileHandler, accountFileHandler: accountFileHandler,
parser: mockParser,
} }
flag_incorrect_pin := uint32(1)
mockParser.On("GetFlag", "flag_incorrect_pin").Return(flag_incorrect_pin, nil)
tests := []struct { tests := []struct {
name string name string
input []byte input []byte
@ -272,21 +296,21 @@ func TestSavePin(t *testing.T) {
{ {
name: "Invalid PIN - non-numeric", name: "Invalid PIN - non-numeric",
input: []byte("12ab"), input: []byte("12ab"),
expectedFlags: []uint32{models.USERFLAG_INCORRECTPIN}, expectedFlags: []uint32{flag_incorrect_pin},
expectedData: initialAccountData, // No changes expected expectedData: initialAccountData, // No changes expected
expectedErrors: false, expectedErrors: false,
}, },
{ {
name: "Invalid PIN - less than 4 digits", name: "Invalid PIN - less than 4 digits",
input: []byte("123"), input: []byte("123"),
expectedFlags: []uint32{models.USERFLAG_INCORRECTPIN}, expectedFlags: []uint32{flag_incorrect_pin},
expectedData: initialAccountData, // No changes expected expectedData: initialAccountData, // No changes expected
expectedErrors: false, expectedErrors: false,
}, },
{ {
name: "Invalid PIN - more than 4 digits", name: "Invalid PIN - more than 4 digits",
input: []byte("12345"), input: []byte("12345"),
expectedFlags: []uint32{models.USERFLAG_INCORRECTPIN}, expectedFlags: []uint32{flag_incorrect_pin},
expectedData: initialAccountData, // No changes expected expectedData: initialAccountData, // No changes expected
expectedErrors: false, expectedErrors: false,
}, },
@ -294,7 +318,6 @@ func TestSavePin(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
// Ensure the file exists before running the test
err := accountFileHandler.EnsureFileExists() err := accountFileHandler.EnsureFileExists()
if err != nil { if err != nil {
t.Fatalf("Failed to ensure account file exists: %v", err) t.Fatalf("Failed to ensure account file exists: %v", err)