Compare commits

..

No commits in common. "7df77a134307823e058ca05541a004283189d0e1" and "7fe8f0b7d578d1541207a8604a909c08b8124a68" have entirely different histories.

6 changed files with 101 additions and 80 deletions

View File

@ -69,6 +69,7 @@ func (ls *LocalHandlerService) GetHandler() (*ussd.Handlers, error) {
ls.DbRs.AddLocalFunc("check_balance", ussdHandlers.CheckBalance) ls.DbRs.AddLocalFunc("check_balance", ussdHandlers.CheckBalance)
ls.DbRs.AddLocalFunc("validate_recipient", ussdHandlers.ValidateRecipient) ls.DbRs.AddLocalFunc("validate_recipient", ussdHandlers.ValidateRecipient)
ls.DbRs.AddLocalFunc("transaction_reset", ussdHandlers.TransactionReset) ls.DbRs.AddLocalFunc("transaction_reset", ussdHandlers.TransactionReset)
ls.DbRs.AddLocalFunc("max_amount", ussdHandlers.MaxAmount)
ls.DbRs.AddLocalFunc("validate_amount", ussdHandlers.ValidateAmount) ls.DbRs.AddLocalFunc("validate_amount", ussdHandlers.ValidateAmount)
ls.DbRs.AddLocalFunc("reset_transaction_amount", ussdHandlers.ResetTransactionAmount) ls.DbRs.AddLocalFunc("reset_transaction_amount", ussdHandlers.ResetTransactionAmount)
ls.DbRs.AddLocalFunc("get_recipient", ussdHandlers.GetRecipient) ls.DbRs.AddLocalFunc("get_recipient", ussdHandlers.GetRecipient)

View File

@ -761,60 +761,74 @@ func (h *Handlers) ResetTransactionAmount(ctx context.Context, sym string, input
return res, nil return res, nil
} }
// ValidateAmount ensures that the given input is a valid amount and that // MaxAmount gets the current balance from the API and sets it as
// it is not more than the current balance. // the result content.
func (h *Handlers) ValidateAmount(ctx context.Context, sym string, input []byte) (resource.Result, error) { func (h *Handlers) MaxAmount(ctx context.Context, sym string, input []byte) (resource.Result, error) {
var res resource.Result var res resource.Result
var err error
sessionId, ok := ctx.Value("SessionId").(string) sessionId, ok := ctx.Value("SessionId").(string)
if !ok { if !ok {
return res, fmt.Errorf("missing session") return res, fmt.Errorf("missing session")
} }
flag_invalid_amount, _ := h.flagManager.GetFlag("flag_invalid_amount")
store := h.userdataStore store := h.userdataStore
publicKey, _ := store.ReadEntry(ctx, sessionId, utils.DATA_PUBLIC_KEY)
balance, err := h.accountService.CheckBalance(string(publicKey))
if err != nil {
return res, nil
}
res.Content = balance
return res, nil
}
// ValidateAmount ensures that the given input is a valid amount and that
// it is not more than the current balance.
func (h *Handlers) ValidateAmount(ctx context.Context, sym string, input []byte) (resource.Result, error) {
var res resource.Result
var err error
sessionId, ok := ctx.Value("SessionId").(string)
if !ok {
return res, fmt.Errorf("missing session")
}
flag_invalid_amount, _ := h.flagManager.GetFlag("flag_invalid_amount")
store := h.userdataStore
publicKey, _ := store.ReadEntry(ctx, sessionId, utils.DATA_PUBLIC_KEY)
amountStr := string(input)
balanceStr, err := h.accountService.CheckBalance(string(publicKey))
publicKey, err := store.ReadEntry(ctx, sessionId, utils.DATA_PUBLIC_KEY)
if err != nil { if err != nil {
return res, err return res, err
} }
res.Content = balanceStr
// retrieve the active symbol // Parse the balance
activeSym, err := store.ReadEntry(ctx, sessionId, utils.DATA_ACTIVE_SYM) balanceParts := strings.Split(balanceStr, " ")
useActiveSymbol := err == nil && len(activeSym) > 0 if len(balanceParts) != 2 {
return res, fmt.Errorf("unexpected balance format: %s", balanceStr)
var balanceValue float64 }
if useActiveSymbol { balanceValue, err := strconv.ParseFloat(balanceParts[0], 64)
// If active symbol is set, retrieve its balance if err != nil {
activeBal, err := store.ReadEntry(ctx, sessionId, utils.DATA_ACTIVE_BAL) return res, fmt.Errorf("failed to parse balance: %v", err)
if err != nil {
return res, fmt.Errorf("failed to get active balance: %v", err)
}
balanceValue, err = strconv.ParseFloat(string(activeBal), 64)
if err != nil {
return res, fmt.Errorf("failed to parse active balance: %v", err)
}
} else {
// If no active symbol, use the current balance from the API
balanceStr, err := h.accountService.CheckBalance(string(publicKey))
if err != nil {
return res, fmt.Errorf("failed to check balance: %v", err)
}
res.Content = balanceStr
// Parse the balance string
balanceParts := strings.Split(balanceStr, " ")
if len(balanceParts) != 2 {
return res, fmt.Errorf("unexpected balance format: %s", balanceStr)
}
balanceValue, err = strconv.ParseFloat(balanceParts[0], 64)
if err != nil {
return res, fmt.Errorf("failed to parse balance: %v", err)
}
} }
// Extract numeric part from the input amount // Extract numeric part from input
amountStr := strings.TrimSpace(string(input)) re := regexp.MustCompile(`^(\d+(\.\d+)?)\s*(?:CELO)?$`)
inputAmount, err := strconv.ParseFloat(amountStr, 64) matches := re.FindStringSubmatch(strings.TrimSpace(amountStr))
if len(matches) < 2 {
res.FlagSet = append(res.FlagSet, flag_invalid_amount)
res.Content = amountStr
return res, nil
}
inputAmount, err := strconv.ParseFloat(matches[1], 64)
if err != nil { if err != nil {
res.FlagSet = append(res.FlagSet, flag_invalid_amount) res.FlagSet = append(res.FlagSet, flag_invalid_amount)
res.Content = amountStr res.Content = amountStr
@ -827,12 +841,12 @@ func (h *Handlers) ValidateAmount(ctx context.Context, sym string, input []byte)
return res, nil return res, nil
} }
res.Content = fmt.Sprintf("%.3f", inputAmount) // Format to 3 decimal places
err = store.WriteEntry(ctx, sessionId, utils.DATA_AMOUNT, []byte(amountStr)) err = store.WriteEntry(ctx, sessionId, utils.DATA_AMOUNT, []byte(amountStr))
if err != nil { if err != nil {
return res, err return res, err
} }
res.Content = fmt.Sprintf("%.3f", inputAmount)
return res, nil return res, nil
} }

View File

@ -434,6 +434,34 @@ func TestCheckIdentifier(t *testing.T) {
} }
} }
func TestMaxAmount(t *testing.T) {
mockStore := new(mocks.MockUserDataStore)
mockCreateAccountService := new(mocks.MockAccountService)
// Define test data
sessionId := "session123"
ctx := context.WithValue(context.Background(), "SessionId", sessionId)
publicKey := "0xcasgatweksalw1018221"
expectedBalance := "0.003CELO"
// Set up the expected behavior of the mock
mockStore.On("ReadEntry", ctx, sessionId, utils.DATA_PUBLIC_KEY).Return([]byte(publicKey), nil)
mockCreateAccountService.On("CheckBalance", publicKey).Return(expectedBalance, nil)
// Create the Handlers instance with the mock store
h := &Handlers{
userdataStore: mockStore,
accountService: mockCreateAccountService,
}
// Call the method
res, _ := h.MaxAmount(ctx, "max_amount", []byte("check_balance"))
//Assert that the balance that was set as the result content is what was returned by Check Balance
assert.Equal(t, expectedBalance, res.Content)
}
func TestGetSender(t *testing.T) { func TestGetSender(t *testing.T) {
mockStore := new(mocks.MockUserDataStore) mockStore := new(mocks.MockUserDataStore)
@ -1416,81 +1444,59 @@ func TestValidateAmount(t *testing.T) {
name string name string
input []byte input []byte
publicKey []byte publicKey []byte
activeSym []byte
activeBal []byte
balance string balance string
expectedResult resource.Result expectedResult resource.Result
}{ }{
{ {
name: "Test with valid amount and active symbol", name: "Test with valid amount",
input: []byte("0.001"), input: []byte("0.001"),
balance: "0.003 CELO",
publicKey: []byte("0xrqeqrequuq"), publicKey: []byte("0xrqeqrequuq"),
activeSym: []byte("CELO"),
activeBal: []byte("0.003"),
expectedResult: resource.Result{ expectedResult: resource.Result{
Content: "0.001", Content: "0.001",
}, },
}, },
{ {
name: "Test with amount larger than active balance", name: "Test with amount larger than balance",
input: []byte("0.02"), input: []byte("0.02"),
balance: "0.003 CELO",
publicKey: []byte("0xrqeqrequuq"), publicKey: []byte("0xrqeqrequuq"),
activeSym: []byte("CELO"),
activeBal: []byte("0.003"),
expectedResult: resource.Result{ expectedResult: resource.Result{
FlagSet: []uint32{flag_invalid_amount}, FlagSet: []uint32{flag_invalid_amount},
Content: "0.02", Content: "0.02",
}, },
}, },
{ {
name: "Test with invalid amount format", name: "Test with invalid amount",
input: []byte("0.02ms"), input: []byte("0.02ms"),
publicKey: []byte("0xrqeqrequuq"),
balance: "0.003 CELO", balance: "0.003 CELO",
publicKey: []byte("0xrqeqrequuq"),
expectedResult: resource.Result{ expectedResult: resource.Result{
FlagSet: []uint32{flag_invalid_amount}, FlagSet: []uint32{flag_invalid_amount},
Content: "0.02ms", Content: "0.02ms",
}, },
}, },
{
name: "Test fallback to current balance without active symbol",
input: []byte("0.001"),
publicKey: []byte("0xrqeqrequuq"),
balance: "0.003 CELO",
expectedResult: resource.Result{
Content: "0.001",
},
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
// Mock behavior for public key retrieval
mockDataStore.On("ReadEntry", ctx, sessionId, utils.DATA_PUBLIC_KEY).Return(tt.publicKey, nil) mockDataStore.On("ReadEntry", ctx, sessionId, utils.DATA_PUBLIC_KEY).Return(tt.publicKey, nil)
mockCreateAccountService.On("CheckBalance", string(tt.publicKey)).Return(tt.balance, nil)
// Mock behavior for active symbol and balance retrieval (if present)
if len(tt.activeSym) > 0 {
mockDataStore.On("ReadEntry", ctx, sessionId, utils.DATA_ACTIVE_SYM).Return(tt.activeSym, nil)
mockDataStore.On("ReadEntry", ctx, sessionId, utils.DATA_ACTIVE_BAL).Return(tt.activeBal, nil)
} else {
mockDataStore.On("ReadEntry", ctx, sessionId, utils.DATA_ACTIVE_SYM).Return(nil, fmt.Errorf("not found"))
mockCreateAccountService.On("CheckBalance", string(tt.publicKey)).Return(tt.balance, nil)
}
// Mock behavior for storing the amount (if valid)
mockDataStore.On("WriteEntry", ctx, sessionId, utils.DATA_AMOUNT, tt.input).Return(nil).Maybe() mockDataStore.On("WriteEntry", ctx, sessionId, utils.DATA_AMOUNT, tt.input).Return(nil).Maybe()
// Call the method under test // Call the method under test
res, _ := h.ValidateAmount(ctx, "test_validate_amount", tt.input) res, _ := h.ValidateAmount(ctx, "test_validate_amount", tt.input)
// Assert no errors occurred // Assert that no errors occurred
assert.NoError(t, err) assert.NoError(t, err)
// Assert the result matches the expected result //Assert that the account created flag has been set to the result
assert.Equal(t, tt.expectedResult, res, "Expected result should match actual result") assert.Equal(t, res, tt.expectedResult, "Expected result should be equal to the actual result")
// Assert all expectations were met // Assert that expectations were met
mockDataStore.AssertExpectations(t) mockDataStore.AssertExpectations(t)
}) })
} }
} }

View File

@ -1,2 +1,2 @@
Maximum amount: {{.check_balance}} Maximum amount: {{.max_amount}}
Enter amount: Enter amount:

View File

@ -1,6 +1,6 @@
LOAD reset_transaction_amount 0 LOAD reset_transaction_amount 0
LOAD check_balance 48 LOAD max_amount 10
MAP check_balance MAP max_amount
MOUT back 0 MOUT back 0
HALT HALT
LOAD validate_amount 64 LOAD validate_amount 64

View File

@ -1,2 +1,2 @@
Kiwango cha juu: {{.check_balance}} Kiwango cha juu: {{.max_amount}}
Weka kiwango: Weka kiwango: