forked from urdt/ussd
updated the ValidateAmount to also check the active symbol, updated tests
This commit is contained in:
parent
f5dbfe553d
commit
7df77a1343
@ -761,74 +761,60 @@ func (h *Handlers) ResetTransactionAmount(ctx context.Context, sym string, input
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// MaxAmount gets the current balance from the API and sets it as
|
||||
// the result content.
|
||||
func (h *Handlers) MaxAmount(ctx context.Context, sym string, input []byte) (resource.Result, error) {
|
||||
var res resource.Result
|
||||
var err error
|
||||
|
||||
sessionId, ok := ctx.Value("SessionId").(string)
|
||||
if !ok {
|
||||
return res, fmt.Errorf("missing session")
|
||||
}
|
||||
store := h.userdataStore
|
||||
publicKey, _ := store.ReadEntry(ctx, sessionId, utils.DATA_PUBLIC_KEY)
|
||||
|
||||
balance, err := h.accountService.CheckBalance(string(publicKey))
|
||||
if err != nil {
|
||||
return res, nil
|
||||
}
|
||||
|
||||
res.Content = balance
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// ValidateAmount ensures that the given input is a valid amount and that
|
||||
// it is not more than the current balance.
|
||||
func (h *Handlers) ValidateAmount(ctx context.Context, sym string, input []byte) (resource.Result, error) {
|
||||
var res resource.Result
|
||||
var err error
|
||||
|
||||
sessionId, ok := ctx.Value("SessionId").(string)
|
||||
if !ok {
|
||||
return res, fmt.Errorf("missing session")
|
||||
}
|
||||
|
||||
flag_invalid_amount, _ := h.flagManager.GetFlag("flag_invalid_amount")
|
||||
|
||||
store := h.userdataStore
|
||||
publicKey, _ := store.ReadEntry(ctx, sessionId, utils.DATA_PUBLIC_KEY)
|
||||
|
||||
amountStr := string(input)
|
||||
|
||||
balanceStr, err := h.accountService.CheckBalance(string(publicKey))
|
||||
|
||||
publicKey, err := store.ReadEntry(ctx, sessionId, utils.DATA_PUBLIC_KEY)
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
res.Content = balanceStr
|
||||
|
||||
// Parse the balance
|
||||
balanceParts := strings.Split(balanceStr, " ")
|
||||
if len(balanceParts) != 2 {
|
||||
return res, fmt.Errorf("unexpected balance format: %s", balanceStr)
|
||||
}
|
||||
balanceValue, err := strconv.ParseFloat(balanceParts[0], 64)
|
||||
if err != nil {
|
||||
return res, fmt.Errorf("failed to parse balance: %v", err)
|
||||
// retrieve the active symbol
|
||||
activeSym, err := store.ReadEntry(ctx, sessionId, utils.DATA_ACTIVE_SYM)
|
||||
useActiveSymbol := err == nil && len(activeSym) > 0
|
||||
|
||||
var balanceValue float64
|
||||
if useActiveSymbol {
|
||||
// If active symbol is set, retrieve its balance
|
||||
activeBal, err := store.ReadEntry(ctx, sessionId, utils.DATA_ACTIVE_BAL)
|
||||
if err != nil {
|
||||
return res, fmt.Errorf("failed to get active balance: %v", err)
|
||||
}
|
||||
balanceValue, err = strconv.ParseFloat(string(activeBal), 64)
|
||||
if err != nil {
|
||||
return res, fmt.Errorf("failed to parse active balance: %v", err)
|
||||
}
|
||||
} else {
|
||||
// If no active symbol, use the current balance from the API
|
||||
balanceStr, err := h.accountService.CheckBalance(string(publicKey))
|
||||
if err != nil {
|
||||
return res, fmt.Errorf("failed to check balance: %v", err)
|
||||
}
|
||||
res.Content = balanceStr
|
||||
|
||||
// Parse the balance string
|
||||
balanceParts := strings.Split(balanceStr, " ")
|
||||
if len(balanceParts) != 2 {
|
||||
return res, fmt.Errorf("unexpected balance format: %s", balanceStr)
|
||||
}
|
||||
balanceValue, err = strconv.ParseFloat(balanceParts[0], 64)
|
||||
if err != nil {
|
||||
return res, fmt.Errorf("failed to parse balance: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Extract numeric part from input
|
||||
re := regexp.MustCompile(`^(\d+(\.\d+)?)\s*(?:CELO)?$`)
|
||||
matches := re.FindStringSubmatch(strings.TrimSpace(amountStr))
|
||||
if len(matches) < 2 {
|
||||
res.FlagSet = append(res.FlagSet, flag_invalid_amount)
|
||||
res.Content = amountStr
|
||||
return res, nil
|
||||
}
|
||||
|
||||
inputAmount, err := strconv.ParseFloat(matches[1], 64)
|
||||
// Extract numeric part from the input amount
|
||||
amountStr := strings.TrimSpace(string(input))
|
||||
inputAmount, err := strconv.ParseFloat(amountStr, 64)
|
||||
if err != nil {
|
||||
res.FlagSet = append(res.FlagSet, flag_invalid_amount)
|
||||
res.Content = amountStr
|
||||
@ -841,12 +827,12 @@ func (h *Handlers) ValidateAmount(ctx context.Context, sym string, input []byte)
|
||||
return res, nil
|
||||
}
|
||||
|
||||
res.Content = fmt.Sprintf("%.3f", inputAmount) // Format to 3 decimal places
|
||||
err = store.WriteEntry(ctx, sessionId, utils.DATA_AMOUNT, []byte(amountStr))
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
|
||||
res.Content = fmt.Sprintf("%.3f", inputAmount)
|
||||
return res, nil
|
||||
}
|
||||
|
||||
|
@ -434,34 +434,6 @@ func TestCheckIdentifier(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxAmount(t *testing.T) {
|
||||
mockStore := new(mocks.MockUserDataStore)
|
||||
mockCreateAccountService := new(mocks.MockAccountService)
|
||||
|
||||
// Define test data
|
||||
sessionId := "session123"
|
||||
ctx := context.WithValue(context.Background(), "SessionId", sessionId)
|
||||
publicKey := "0xcasgatweksalw1018221"
|
||||
expectedBalance := "0.003CELO"
|
||||
|
||||
// Set up the expected behavior of the mock
|
||||
mockStore.On("ReadEntry", ctx, sessionId, utils.DATA_PUBLIC_KEY).Return([]byte(publicKey), nil)
|
||||
mockCreateAccountService.On("CheckBalance", publicKey).Return(expectedBalance, nil)
|
||||
|
||||
// Create the Handlers instance with the mock store
|
||||
h := &Handlers{
|
||||
userdataStore: mockStore,
|
||||
accountService: mockCreateAccountService,
|
||||
}
|
||||
|
||||
// Call the method
|
||||
res, _ := h.MaxAmount(ctx, "max_amount", []byte("check_balance"))
|
||||
|
||||
//Assert that the balance that was set as the result content is what was returned by Check Balance
|
||||
assert.Equal(t, expectedBalance, res.Content)
|
||||
|
||||
}
|
||||
|
||||
func TestGetSender(t *testing.T) {
|
||||
mockStore := new(mocks.MockUserDataStore)
|
||||
|
||||
@ -1444,59 +1416,81 @@ func TestValidateAmount(t *testing.T) {
|
||||
name string
|
||||
input []byte
|
||||
publicKey []byte
|
||||
activeSym []byte
|
||||
activeBal []byte
|
||||
balance string
|
||||
expectedResult resource.Result
|
||||
}{
|
||||
{
|
||||
name: "Test with valid amount",
|
||||
name: "Test with valid amount and active symbol",
|
||||
input: []byte("0.001"),
|
||||
balance: "0.003 CELO",
|
||||
publicKey: []byte("0xrqeqrequuq"),
|
||||
activeSym: []byte("CELO"),
|
||||
activeBal: []byte("0.003"),
|
||||
expectedResult: resource.Result{
|
||||
Content: "0.001",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Test with amount larger than balance",
|
||||
name: "Test with amount larger than active balance",
|
||||
input: []byte("0.02"),
|
||||
balance: "0.003 CELO",
|
||||
publicKey: []byte("0xrqeqrequuq"),
|
||||
activeSym: []byte("CELO"),
|
||||
activeBal: []byte("0.003"),
|
||||
expectedResult: resource.Result{
|
||||
FlagSet: []uint32{flag_invalid_amount},
|
||||
Content: "0.02",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Test with invalid amount",
|
||||
name: "Test with invalid amount format",
|
||||
input: []byte("0.02ms"),
|
||||
balance: "0.003 CELO",
|
||||
publicKey: []byte("0xrqeqrequuq"),
|
||||
balance: "0.003 CELO",
|
||||
expectedResult: resource.Result{
|
||||
FlagSet: []uint32{flag_invalid_amount},
|
||||
Content: "0.02ms",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Test fallback to current balance without active symbol",
|
||||
input: []byte("0.001"),
|
||||
publicKey: []byte("0xrqeqrequuq"),
|
||||
balance: "0.003 CELO",
|
||||
expectedResult: resource.Result{
|
||||
Content: "0.001",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
// Mock behavior for public key retrieval
|
||||
mockDataStore.On("ReadEntry", ctx, sessionId, utils.DATA_PUBLIC_KEY).Return(tt.publicKey, nil)
|
||||
mockCreateAccountService.On("CheckBalance", string(tt.publicKey)).Return(tt.balance, nil)
|
||||
|
||||
// Mock behavior for active symbol and balance retrieval (if present)
|
||||
if len(tt.activeSym) > 0 {
|
||||
mockDataStore.On("ReadEntry", ctx, sessionId, utils.DATA_ACTIVE_SYM).Return(tt.activeSym, nil)
|
||||
mockDataStore.On("ReadEntry", ctx, sessionId, utils.DATA_ACTIVE_BAL).Return(tt.activeBal, nil)
|
||||
} else {
|
||||
mockDataStore.On("ReadEntry", ctx, sessionId, utils.DATA_ACTIVE_SYM).Return(nil, fmt.Errorf("not found"))
|
||||
mockCreateAccountService.On("CheckBalance", string(tt.publicKey)).Return(tt.balance, nil)
|
||||
}
|
||||
|
||||
// Mock behavior for storing the amount (if valid)
|
||||
mockDataStore.On("WriteEntry", ctx, sessionId, utils.DATA_AMOUNT, tt.input).Return(nil).Maybe()
|
||||
|
||||
// Call the method under test
|
||||
res, _ := h.ValidateAmount(ctx, "test_validate_amount", tt.input)
|
||||
|
||||
// Assert that no errors occurred
|
||||
// Assert no errors occurred
|
||||
assert.NoError(t, err)
|
||||
|
||||
//Assert that the account created flag has been set to the result
|
||||
assert.Equal(t, res, tt.expectedResult, "Expected result should be equal to the actual result")
|
||||
// Assert the result matches the expected result
|
||||
assert.Equal(t, tt.expectedResult, res, "Expected result should match actual result")
|
||||
|
||||
// Assert that expectations were met
|
||||
// Assert all expectations were met
|
||||
mockDataStore.AssertExpectations(t)
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user