updated the ValidateAmount to also check the active symbol, updated tests

This commit is contained in:
Alfred Kamanda 2024-10-12 20:07:06 +03:00
parent f5dbfe553d
commit 7df77a1343
Signed by untrusted user: Alfred-mk
GPG Key ID: 7EA3D01708908703
2 changed files with 72 additions and 92 deletions

View File

@ -761,74 +761,60 @@ func (h *Handlers) ResetTransactionAmount(ctx context.Context, sym string, input
return res, nil
}
// MaxAmount gets the current balance from the API and sets it as
// the result content.
func (h *Handlers) MaxAmount(ctx context.Context, sym string, input []byte) (resource.Result, error) {
var res resource.Result
var err error
sessionId, ok := ctx.Value("SessionId").(string)
if !ok {
return res, fmt.Errorf("missing session")
}
store := h.userdataStore
publicKey, _ := store.ReadEntry(ctx, sessionId, utils.DATA_PUBLIC_KEY)
balance, err := h.accountService.CheckBalance(string(publicKey))
if err != nil {
return res, nil
}
res.Content = balance
return res, nil
}
// ValidateAmount ensures that the given input is a valid amount and that
// it is not more than the current balance.
func (h *Handlers) ValidateAmount(ctx context.Context, sym string, input []byte) (resource.Result, error) {
var res resource.Result
var err error
sessionId, ok := ctx.Value("SessionId").(string)
if !ok {
return res, fmt.Errorf("missing session")
}
flag_invalid_amount, _ := h.flagManager.GetFlag("flag_invalid_amount")
store := h.userdataStore
publicKey, _ := store.ReadEntry(ctx, sessionId, utils.DATA_PUBLIC_KEY)
amountStr := string(input)
balanceStr, err := h.accountService.CheckBalance(string(publicKey))
publicKey, err := store.ReadEntry(ctx, sessionId, utils.DATA_PUBLIC_KEY)
if err != nil {
return res, err
}
res.Content = balanceStr
// Parse the balance
balanceParts := strings.Split(balanceStr, " ")
if len(balanceParts) != 2 {
return res, fmt.Errorf("unexpected balance format: %s", balanceStr)
}
balanceValue, err := strconv.ParseFloat(balanceParts[0], 64)
if err != nil {
return res, fmt.Errorf("failed to parse balance: %v", err)
// retrieve the active symbol
activeSym, err := store.ReadEntry(ctx, sessionId, utils.DATA_ACTIVE_SYM)
useActiveSymbol := err == nil && len(activeSym) > 0
var balanceValue float64
if useActiveSymbol {
// If active symbol is set, retrieve its balance
activeBal, err := store.ReadEntry(ctx, sessionId, utils.DATA_ACTIVE_BAL)
if err != nil {
return res, fmt.Errorf("failed to get active balance: %v", err)
}
balanceValue, err = strconv.ParseFloat(string(activeBal), 64)
if err != nil {
return res, fmt.Errorf("failed to parse active balance: %v", err)
}
} else {
// If no active symbol, use the current balance from the API
balanceStr, err := h.accountService.CheckBalance(string(publicKey))
if err != nil {
return res, fmt.Errorf("failed to check balance: %v", err)
}
res.Content = balanceStr
// Parse the balance string
balanceParts := strings.Split(balanceStr, " ")
if len(balanceParts) != 2 {
return res, fmt.Errorf("unexpected balance format: %s", balanceStr)
}
balanceValue, err = strconv.ParseFloat(balanceParts[0], 64)
if err != nil {
return res, fmt.Errorf("failed to parse balance: %v", err)
}
}
// Extract numeric part from input
re := regexp.MustCompile(`^(\d+(\.\d+)?)\s*(?:CELO)?$`)
matches := re.FindStringSubmatch(strings.TrimSpace(amountStr))
if len(matches) < 2 {
res.FlagSet = append(res.FlagSet, flag_invalid_amount)
res.Content = amountStr
return res, nil
}
inputAmount, err := strconv.ParseFloat(matches[1], 64)
// Extract numeric part from the input amount
amountStr := strings.TrimSpace(string(input))
inputAmount, err := strconv.ParseFloat(amountStr, 64)
if err != nil {
res.FlagSet = append(res.FlagSet, flag_invalid_amount)
res.Content = amountStr
@ -841,12 +827,12 @@ func (h *Handlers) ValidateAmount(ctx context.Context, sym string, input []byte)
return res, nil
}
res.Content = fmt.Sprintf("%.3f", inputAmount) // Format to 3 decimal places
err = store.WriteEntry(ctx, sessionId, utils.DATA_AMOUNT, []byte(amountStr))
if err != nil {
return res, err
}
res.Content = fmt.Sprintf("%.3f", inputAmount)
return res, nil
}

View File

@ -434,34 +434,6 @@ func TestCheckIdentifier(t *testing.T) {
}
}
func TestMaxAmount(t *testing.T) {
mockStore := new(mocks.MockUserDataStore)
mockCreateAccountService := new(mocks.MockAccountService)
// Define test data
sessionId := "session123"
ctx := context.WithValue(context.Background(), "SessionId", sessionId)
publicKey := "0xcasgatweksalw1018221"
expectedBalance := "0.003CELO"
// Set up the expected behavior of the mock
mockStore.On("ReadEntry", ctx, sessionId, utils.DATA_PUBLIC_KEY).Return([]byte(publicKey), nil)
mockCreateAccountService.On("CheckBalance", publicKey).Return(expectedBalance, nil)
// Create the Handlers instance with the mock store
h := &Handlers{
userdataStore: mockStore,
accountService: mockCreateAccountService,
}
// Call the method
res, _ := h.MaxAmount(ctx, "max_amount", []byte("check_balance"))
//Assert that the balance that was set as the result content is what was returned by Check Balance
assert.Equal(t, expectedBalance, res.Content)
}
func TestGetSender(t *testing.T) {
mockStore := new(mocks.MockUserDataStore)
@ -1444,59 +1416,81 @@ func TestValidateAmount(t *testing.T) {
name string
input []byte
publicKey []byte
activeSym []byte
activeBal []byte
balance string
expectedResult resource.Result
}{
{
name: "Test with valid amount",
name: "Test with valid amount and active symbol",
input: []byte("0.001"),
balance: "0.003 CELO",
publicKey: []byte("0xrqeqrequuq"),
activeSym: []byte("CELO"),
activeBal: []byte("0.003"),
expectedResult: resource.Result{
Content: "0.001",
},
},
{
name: "Test with amount larger than balance",
name: "Test with amount larger than active balance",
input: []byte("0.02"),
balance: "0.003 CELO",
publicKey: []byte("0xrqeqrequuq"),
activeSym: []byte("CELO"),
activeBal: []byte("0.003"),
expectedResult: resource.Result{
FlagSet: []uint32{flag_invalid_amount},
Content: "0.02",
},
},
{
name: "Test with invalid amount",
name: "Test with invalid amount format",
input: []byte("0.02ms"),
balance: "0.003 CELO",
publicKey: []byte("0xrqeqrequuq"),
balance: "0.003 CELO",
expectedResult: resource.Result{
FlagSet: []uint32{flag_invalid_amount},
Content: "0.02ms",
},
},
{
name: "Test fallback to current balance without active symbol",
input: []byte("0.001"),
publicKey: []byte("0xrqeqrequuq"),
balance: "0.003 CELO",
expectedResult: resource.Result{
Content: "0.001",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Mock behavior for public key retrieval
mockDataStore.On("ReadEntry", ctx, sessionId, utils.DATA_PUBLIC_KEY).Return(tt.publicKey, nil)
mockCreateAccountService.On("CheckBalance", string(tt.publicKey)).Return(tt.balance, nil)
// Mock behavior for active symbol and balance retrieval (if present)
if len(tt.activeSym) > 0 {
mockDataStore.On("ReadEntry", ctx, sessionId, utils.DATA_ACTIVE_SYM).Return(tt.activeSym, nil)
mockDataStore.On("ReadEntry", ctx, sessionId, utils.DATA_ACTIVE_BAL).Return(tt.activeBal, nil)
} else {
mockDataStore.On("ReadEntry", ctx, sessionId, utils.DATA_ACTIVE_SYM).Return(nil, fmt.Errorf("not found"))
mockCreateAccountService.On("CheckBalance", string(tt.publicKey)).Return(tt.balance, nil)
}
// Mock behavior for storing the amount (if valid)
mockDataStore.On("WriteEntry", ctx, sessionId, utils.DATA_AMOUNT, tt.input).Return(nil).Maybe()
// Call the method under test
res, _ := h.ValidateAmount(ctx, "test_validate_amount", tt.input)
// Assert that no errors occurred
// Assert no errors occurred
assert.NoError(t, err)
//Assert that the account created flag has been set to the result
assert.Equal(t, res, tt.expectedResult, "Expected result should be equal to the actual result")
// Assert the result matches the expected result
assert.Equal(t, tt.expectedResult, res, "Expected result should match actual result")
// Assert that expectations were met
// Assert all expectations were met
mockDataStore.AssertExpectations(t)
})
}
}