forked from urdt/ussd
updated the ValidateAmount to also check the active symbol, updated tests
This commit is contained in:
parent
f5dbfe553d
commit
7df77a1343
@ -761,74 +761,60 @@ func (h *Handlers) ResetTransactionAmount(ctx context.Context, sym string, input
|
|||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MaxAmount gets the current balance from the API and sets it as
|
|
||||||
// the result content.
|
|
||||||
func (h *Handlers) MaxAmount(ctx context.Context, sym string, input []byte) (resource.Result, error) {
|
|
||||||
var res resource.Result
|
|
||||||
var err error
|
|
||||||
|
|
||||||
sessionId, ok := ctx.Value("SessionId").(string)
|
|
||||||
if !ok {
|
|
||||||
return res, fmt.Errorf("missing session")
|
|
||||||
}
|
|
||||||
store := h.userdataStore
|
|
||||||
publicKey, _ := store.ReadEntry(ctx, sessionId, utils.DATA_PUBLIC_KEY)
|
|
||||||
|
|
||||||
balance, err := h.accountService.CheckBalance(string(publicKey))
|
|
||||||
if err != nil {
|
|
||||||
return res, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
res.Content = balance
|
|
||||||
|
|
||||||
return res, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ValidateAmount ensures that the given input is a valid amount and that
|
// ValidateAmount ensures that the given input is a valid amount and that
|
||||||
// it is not more than the current balance.
|
// it is not more than the current balance.
|
||||||
func (h *Handlers) ValidateAmount(ctx context.Context, sym string, input []byte) (resource.Result, error) {
|
func (h *Handlers) ValidateAmount(ctx context.Context, sym string, input []byte) (resource.Result, error) {
|
||||||
var res resource.Result
|
var res resource.Result
|
||||||
var err error
|
|
||||||
|
|
||||||
sessionId, ok := ctx.Value("SessionId").(string)
|
sessionId, ok := ctx.Value("SessionId").(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return res, fmt.Errorf("missing session")
|
return res, fmt.Errorf("missing session")
|
||||||
}
|
}
|
||||||
|
|
||||||
flag_invalid_amount, _ := h.flagManager.GetFlag("flag_invalid_amount")
|
flag_invalid_amount, _ := h.flagManager.GetFlag("flag_invalid_amount")
|
||||||
|
|
||||||
store := h.userdataStore
|
store := h.userdataStore
|
||||||
publicKey, _ := store.ReadEntry(ctx, sessionId, utils.DATA_PUBLIC_KEY)
|
|
||||||
|
|
||||||
amountStr := string(input)
|
|
||||||
|
|
||||||
balanceStr, err := h.accountService.CheckBalance(string(publicKey))
|
|
||||||
|
|
||||||
|
publicKey, err := store.ReadEntry(ctx, sessionId, utils.DATA_PUBLIC_KEY)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return res, err
|
return res, err
|
||||||
}
|
}
|
||||||
res.Content = balanceStr
|
|
||||||
|
|
||||||
// Parse the balance
|
// retrieve the active symbol
|
||||||
balanceParts := strings.Split(balanceStr, " ")
|
activeSym, err := store.ReadEntry(ctx, sessionId, utils.DATA_ACTIVE_SYM)
|
||||||
if len(balanceParts) != 2 {
|
useActiveSymbol := err == nil && len(activeSym) > 0
|
||||||
return res, fmt.Errorf("unexpected balance format: %s", balanceStr)
|
|
||||||
}
|
var balanceValue float64
|
||||||
balanceValue, err := strconv.ParseFloat(balanceParts[0], 64)
|
if useActiveSymbol {
|
||||||
if err != nil {
|
// If active symbol is set, retrieve its balance
|
||||||
return res, fmt.Errorf("failed to parse balance: %v", err)
|
activeBal, err := store.ReadEntry(ctx, sessionId, utils.DATA_ACTIVE_BAL)
|
||||||
|
if err != nil {
|
||||||
|
return res, fmt.Errorf("failed to get active balance: %v", err)
|
||||||
|
}
|
||||||
|
balanceValue, err = strconv.ParseFloat(string(activeBal), 64)
|
||||||
|
if err != nil {
|
||||||
|
return res, fmt.Errorf("failed to parse active balance: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// If no active symbol, use the current balance from the API
|
||||||
|
balanceStr, err := h.accountService.CheckBalance(string(publicKey))
|
||||||
|
if err != nil {
|
||||||
|
return res, fmt.Errorf("failed to check balance: %v", err)
|
||||||
|
}
|
||||||
|
res.Content = balanceStr
|
||||||
|
|
||||||
|
// Parse the balance string
|
||||||
|
balanceParts := strings.Split(balanceStr, " ")
|
||||||
|
if len(balanceParts) != 2 {
|
||||||
|
return res, fmt.Errorf("unexpected balance format: %s", balanceStr)
|
||||||
|
}
|
||||||
|
balanceValue, err = strconv.ParseFloat(balanceParts[0], 64)
|
||||||
|
if err != nil {
|
||||||
|
return res, fmt.Errorf("failed to parse balance: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract numeric part from input
|
// Extract numeric part from the input amount
|
||||||
re := regexp.MustCompile(`^(\d+(\.\d+)?)\s*(?:CELO)?$`)
|
amountStr := strings.TrimSpace(string(input))
|
||||||
matches := re.FindStringSubmatch(strings.TrimSpace(amountStr))
|
inputAmount, err := strconv.ParseFloat(amountStr, 64)
|
||||||
if len(matches) < 2 {
|
|
||||||
res.FlagSet = append(res.FlagSet, flag_invalid_amount)
|
|
||||||
res.Content = amountStr
|
|
||||||
return res, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
inputAmount, err := strconv.ParseFloat(matches[1], 64)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
res.FlagSet = append(res.FlagSet, flag_invalid_amount)
|
res.FlagSet = append(res.FlagSet, flag_invalid_amount)
|
||||||
res.Content = amountStr
|
res.Content = amountStr
|
||||||
@ -841,12 +827,12 @@ func (h *Handlers) ValidateAmount(ctx context.Context, sym string, input []byte)
|
|||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
res.Content = fmt.Sprintf("%.3f", inputAmount) // Format to 3 decimal places
|
|
||||||
err = store.WriteEntry(ctx, sessionId, utils.DATA_AMOUNT, []byte(amountStr))
|
err = store.WriteEntry(ctx, sessionId, utils.DATA_AMOUNT, []byte(amountStr))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return res, err
|
return res, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
res.Content = fmt.Sprintf("%.3f", inputAmount)
|
||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -434,34 +434,6 @@ func TestCheckIdentifier(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMaxAmount(t *testing.T) {
|
|
||||||
mockStore := new(mocks.MockUserDataStore)
|
|
||||||
mockCreateAccountService := new(mocks.MockAccountService)
|
|
||||||
|
|
||||||
// Define test data
|
|
||||||
sessionId := "session123"
|
|
||||||
ctx := context.WithValue(context.Background(), "SessionId", sessionId)
|
|
||||||
publicKey := "0xcasgatweksalw1018221"
|
|
||||||
expectedBalance := "0.003CELO"
|
|
||||||
|
|
||||||
// Set up the expected behavior of the mock
|
|
||||||
mockStore.On("ReadEntry", ctx, sessionId, utils.DATA_PUBLIC_KEY).Return([]byte(publicKey), nil)
|
|
||||||
mockCreateAccountService.On("CheckBalance", publicKey).Return(expectedBalance, nil)
|
|
||||||
|
|
||||||
// Create the Handlers instance with the mock store
|
|
||||||
h := &Handlers{
|
|
||||||
userdataStore: mockStore,
|
|
||||||
accountService: mockCreateAccountService,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Call the method
|
|
||||||
res, _ := h.MaxAmount(ctx, "max_amount", []byte("check_balance"))
|
|
||||||
|
|
||||||
//Assert that the balance that was set as the result content is what was returned by Check Balance
|
|
||||||
assert.Equal(t, expectedBalance, res.Content)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetSender(t *testing.T) {
|
func TestGetSender(t *testing.T) {
|
||||||
mockStore := new(mocks.MockUserDataStore)
|
mockStore := new(mocks.MockUserDataStore)
|
||||||
|
|
||||||
@ -1444,59 +1416,81 @@ func TestValidateAmount(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
input []byte
|
input []byte
|
||||||
publicKey []byte
|
publicKey []byte
|
||||||
|
activeSym []byte
|
||||||
|
activeBal []byte
|
||||||
balance string
|
balance string
|
||||||
expectedResult resource.Result
|
expectedResult resource.Result
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Test with valid amount",
|
name: "Test with valid amount and active symbol",
|
||||||
input: []byte("0.001"),
|
input: []byte("0.001"),
|
||||||
balance: "0.003 CELO",
|
|
||||||
publicKey: []byte("0xrqeqrequuq"),
|
publicKey: []byte("0xrqeqrequuq"),
|
||||||
|
activeSym: []byte("CELO"),
|
||||||
|
activeBal: []byte("0.003"),
|
||||||
expectedResult: resource.Result{
|
expectedResult: resource.Result{
|
||||||
Content: "0.001",
|
Content: "0.001",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Test with amount larger than balance",
|
name: "Test with amount larger than active balance",
|
||||||
input: []byte("0.02"),
|
input: []byte("0.02"),
|
||||||
balance: "0.003 CELO",
|
|
||||||
publicKey: []byte("0xrqeqrequuq"),
|
publicKey: []byte("0xrqeqrequuq"),
|
||||||
|
activeSym: []byte("CELO"),
|
||||||
|
activeBal: []byte("0.003"),
|
||||||
expectedResult: resource.Result{
|
expectedResult: resource.Result{
|
||||||
FlagSet: []uint32{flag_invalid_amount},
|
FlagSet: []uint32{flag_invalid_amount},
|
||||||
Content: "0.02",
|
Content: "0.02",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Test with invalid amount",
|
name: "Test with invalid amount format",
|
||||||
input: []byte("0.02ms"),
|
input: []byte("0.02ms"),
|
||||||
balance: "0.003 CELO",
|
|
||||||
publicKey: []byte("0xrqeqrequuq"),
|
publicKey: []byte("0xrqeqrequuq"),
|
||||||
|
balance: "0.003 CELO",
|
||||||
expectedResult: resource.Result{
|
expectedResult: resource.Result{
|
||||||
FlagSet: []uint32{flag_invalid_amount},
|
FlagSet: []uint32{flag_invalid_amount},
|
||||||
Content: "0.02ms",
|
Content: "0.02ms",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "Test fallback to current balance without active symbol",
|
||||||
|
input: []byte("0.001"),
|
||||||
|
publicKey: []byte("0xrqeqrequuq"),
|
||||||
|
balance: "0.003 CELO",
|
||||||
|
expectedResult: resource.Result{
|
||||||
|
Content: "0.001",
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Mock behavior for public key retrieval
|
||||||
mockDataStore.On("ReadEntry", ctx, sessionId, utils.DATA_PUBLIC_KEY).Return(tt.publicKey, nil)
|
mockDataStore.On("ReadEntry", ctx, sessionId, utils.DATA_PUBLIC_KEY).Return(tt.publicKey, nil)
|
||||||
mockCreateAccountService.On("CheckBalance", string(tt.publicKey)).Return(tt.balance, nil)
|
|
||||||
|
// Mock behavior for active symbol and balance retrieval (if present)
|
||||||
|
if len(tt.activeSym) > 0 {
|
||||||
|
mockDataStore.On("ReadEntry", ctx, sessionId, utils.DATA_ACTIVE_SYM).Return(tt.activeSym, nil)
|
||||||
|
mockDataStore.On("ReadEntry", ctx, sessionId, utils.DATA_ACTIVE_BAL).Return(tt.activeBal, nil)
|
||||||
|
} else {
|
||||||
|
mockDataStore.On("ReadEntry", ctx, sessionId, utils.DATA_ACTIVE_SYM).Return(nil, fmt.Errorf("not found"))
|
||||||
|
mockCreateAccountService.On("CheckBalance", string(tt.publicKey)).Return(tt.balance, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mock behavior for storing the amount (if valid)
|
||||||
mockDataStore.On("WriteEntry", ctx, sessionId, utils.DATA_AMOUNT, tt.input).Return(nil).Maybe()
|
mockDataStore.On("WriteEntry", ctx, sessionId, utils.DATA_AMOUNT, tt.input).Return(nil).Maybe()
|
||||||
|
|
||||||
// Call the method under test
|
// Call the method under test
|
||||||
res, _ := h.ValidateAmount(ctx, "test_validate_amount", tt.input)
|
res, _ := h.ValidateAmount(ctx, "test_validate_amount", tt.input)
|
||||||
|
|
||||||
// Assert that no errors occurred
|
// Assert no errors occurred
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
//Assert that the account created flag has been set to the result
|
// Assert the result matches the expected result
|
||||||
assert.Equal(t, res, tt.expectedResult, "Expected result should be equal to the actual result")
|
assert.Equal(t, tt.expectedResult, res, "Expected result should match actual result")
|
||||||
|
|
||||||
// Assert that expectations were met
|
// Assert all expectations were met
|
||||||
mockDataStore.AssertExpectations(t)
|
mockDataStore.AssertExpectations(t)
|
||||||
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user