updated the ValidateAmount to also check the active symbol, updated tests

This commit is contained in:
Alfred Kamanda 2024-10-12 20:07:06 +03:00
parent f5dbfe553d
commit 7df77a1343
Signed by untrusted user: Alfred-mk
GPG Key ID: 7EA3D01708908703
2 changed files with 72 additions and 92 deletions

View File

@ -761,74 +761,60 @@ func (h *Handlers) ResetTransactionAmount(ctx context.Context, sym string, input
return res, nil return res, nil
} }
// MaxAmount gets the current balance from the API and sets it as
// the result content.
func (h *Handlers) MaxAmount(ctx context.Context, sym string, input []byte) (resource.Result, error) {
var res resource.Result
var err error
sessionId, ok := ctx.Value("SessionId").(string)
if !ok {
return res, fmt.Errorf("missing session")
}
store := h.userdataStore
publicKey, _ := store.ReadEntry(ctx, sessionId, utils.DATA_PUBLIC_KEY)
balance, err := h.accountService.CheckBalance(string(publicKey))
if err != nil {
return res, nil
}
res.Content = balance
return res, nil
}
// ValidateAmount ensures that the given input is a valid amount and that // ValidateAmount ensures that the given input is a valid amount and that
// it is not more than the current balance. // it is not more than the current balance.
func (h *Handlers) ValidateAmount(ctx context.Context, sym string, input []byte) (resource.Result, error) { func (h *Handlers) ValidateAmount(ctx context.Context, sym string, input []byte) (resource.Result, error) {
var res resource.Result var res resource.Result
var err error
sessionId, ok := ctx.Value("SessionId").(string) sessionId, ok := ctx.Value("SessionId").(string)
if !ok { if !ok {
return res, fmt.Errorf("missing session") return res, fmt.Errorf("missing session")
} }
flag_invalid_amount, _ := h.flagManager.GetFlag("flag_invalid_amount") flag_invalid_amount, _ := h.flagManager.GetFlag("flag_invalid_amount")
store := h.userdataStore store := h.userdataStore
publicKey, _ := store.ReadEntry(ctx, sessionId, utils.DATA_PUBLIC_KEY)
amountStr := string(input)
balanceStr, err := h.accountService.CheckBalance(string(publicKey))
publicKey, err := store.ReadEntry(ctx, sessionId, utils.DATA_PUBLIC_KEY)
if err != nil { if err != nil {
return res, err return res, err
} }
res.Content = balanceStr
// Parse the balance // retrieve the active symbol
balanceParts := strings.Split(balanceStr, " ") activeSym, err := store.ReadEntry(ctx, sessionId, utils.DATA_ACTIVE_SYM)
if len(balanceParts) != 2 { useActiveSymbol := err == nil && len(activeSym) > 0
return res, fmt.Errorf("unexpected balance format: %s", balanceStr)
} var balanceValue float64
balanceValue, err := strconv.ParseFloat(balanceParts[0], 64) if useActiveSymbol {
if err != nil { // If active symbol is set, retrieve its balance
return res, fmt.Errorf("failed to parse balance: %v", err) activeBal, err := store.ReadEntry(ctx, sessionId, utils.DATA_ACTIVE_BAL)
if err != nil {
return res, fmt.Errorf("failed to get active balance: %v", err)
}
balanceValue, err = strconv.ParseFloat(string(activeBal), 64)
if err != nil {
return res, fmt.Errorf("failed to parse active balance: %v", err)
}
} else {
// If no active symbol, use the current balance from the API
balanceStr, err := h.accountService.CheckBalance(string(publicKey))
if err != nil {
return res, fmt.Errorf("failed to check balance: %v", err)
}
res.Content = balanceStr
// Parse the balance string
balanceParts := strings.Split(balanceStr, " ")
if len(balanceParts) != 2 {
return res, fmt.Errorf("unexpected balance format: %s", balanceStr)
}
balanceValue, err = strconv.ParseFloat(balanceParts[0], 64)
if err != nil {
return res, fmt.Errorf("failed to parse balance: %v", err)
}
} }
// Extract numeric part from input // Extract numeric part from the input amount
re := regexp.MustCompile(`^(\d+(\.\d+)?)\s*(?:CELO)?$`) amountStr := strings.TrimSpace(string(input))
matches := re.FindStringSubmatch(strings.TrimSpace(amountStr)) inputAmount, err := strconv.ParseFloat(amountStr, 64)
if len(matches) < 2 {
res.FlagSet = append(res.FlagSet, flag_invalid_amount)
res.Content = amountStr
return res, nil
}
inputAmount, err := strconv.ParseFloat(matches[1], 64)
if err != nil { if err != nil {
res.FlagSet = append(res.FlagSet, flag_invalid_amount) res.FlagSet = append(res.FlagSet, flag_invalid_amount)
res.Content = amountStr res.Content = amountStr
@ -841,12 +827,12 @@ func (h *Handlers) ValidateAmount(ctx context.Context, sym string, input []byte)
return res, nil return res, nil
} }
res.Content = fmt.Sprintf("%.3f", inputAmount) // Format to 3 decimal places
err = store.WriteEntry(ctx, sessionId, utils.DATA_AMOUNT, []byte(amountStr)) err = store.WriteEntry(ctx, sessionId, utils.DATA_AMOUNT, []byte(amountStr))
if err != nil { if err != nil {
return res, err return res, err
} }
res.Content = fmt.Sprintf("%.3f", inputAmount)
return res, nil return res, nil
} }

View File

@ -434,34 +434,6 @@ func TestCheckIdentifier(t *testing.T) {
} }
} }
func TestMaxAmount(t *testing.T) {
mockStore := new(mocks.MockUserDataStore)
mockCreateAccountService := new(mocks.MockAccountService)
// Define test data
sessionId := "session123"
ctx := context.WithValue(context.Background(), "SessionId", sessionId)
publicKey := "0xcasgatweksalw1018221"
expectedBalance := "0.003CELO"
// Set up the expected behavior of the mock
mockStore.On("ReadEntry", ctx, sessionId, utils.DATA_PUBLIC_KEY).Return([]byte(publicKey), nil)
mockCreateAccountService.On("CheckBalance", publicKey).Return(expectedBalance, nil)
// Create the Handlers instance with the mock store
h := &Handlers{
userdataStore: mockStore,
accountService: mockCreateAccountService,
}
// Call the method
res, _ := h.MaxAmount(ctx, "max_amount", []byte("check_balance"))
//Assert that the balance that was set as the result content is what was returned by Check Balance
assert.Equal(t, expectedBalance, res.Content)
}
func TestGetSender(t *testing.T) { func TestGetSender(t *testing.T) {
mockStore := new(mocks.MockUserDataStore) mockStore := new(mocks.MockUserDataStore)
@ -1444,59 +1416,81 @@ func TestValidateAmount(t *testing.T) {
name string name string
input []byte input []byte
publicKey []byte publicKey []byte
activeSym []byte
activeBal []byte
balance string balance string
expectedResult resource.Result expectedResult resource.Result
}{ }{
{ {
name: "Test with valid amount", name: "Test with valid amount and active symbol",
input: []byte("0.001"), input: []byte("0.001"),
balance: "0.003 CELO",
publicKey: []byte("0xrqeqrequuq"), publicKey: []byte("0xrqeqrequuq"),
activeSym: []byte("CELO"),
activeBal: []byte("0.003"),
expectedResult: resource.Result{ expectedResult: resource.Result{
Content: "0.001", Content: "0.001",
}, },
}, },
{ {
name: "Test with amount larger than balance", name: "Test with amount larger than active balance",
input: []byte("0.02"), input: []byte("0.02"),
balance: "0.003 CELO",
publicKey: []byte("0xrqeqrequuq"), publicKey: []byte("0xrqeqrequuq"),
activeSym: []byte("CELO"),
activeBal: []byte("0.003"),
expectedResult: resource.Result{ expectedResult: resource.Result{
FlagSet: []uint32{flag_invalid_amount}, FlagSet: []uint32{flag_invalid_amount},
Content: "0.02", Content: "0.02",
}, },
}, },
{ {
name: "Test with invalid amount", name: "Test with invalid amount format",
input: []byte("0.02ms"), input: []byte("0.02ms"),
balance: "0.003 CELO",
publicKey: []byte("0xrqeqrequuq"), publicKey: []byte("0xrqeqrequuq"),
balance: "0.003 CELO",
expectedResult: resource.Result{ expectedResult: resource.Result{
FlagSet: []uint32{flag_invalid_amount}, FlagSet: []uint32{flag_invalid_amount},
Content: "0.02ms", Content: "0.02ms",
}, },
}, },
{
name: "Test fallback to current balance without active symbol",
input: []byte("0.001"),
publicKey: []byte("0xrqeqrequuq"),
balance: "0.003 CELO",
expectedResult: resource.Result{
Content: "0.001",
},
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
// Mock behavior for public key retrieval
mockDataStore.On("ReadEntry", ctx, sessionId, utils.DATA_PUBLIC_KEY).Return(tt.publicKey, nil) mockDataStore.On("ReadEntry", ctx, sessionId, utils.DATA_PUBLIC_KEY).Return(tt.publicKey, nil)
mockCreateAccountService.On("CheckBalance", string(tt.publicKey)).Return(tt.balance, nil)
// Mock behavior for active symbol and balance retrieval (if present)
if len(tt.activeSym) > 0 {
mockDataStore.On("ReadEntry", ctx, sessionId, utils.DATA_ACTIVE_SYM).Return(tt.activeSym, nil)
mockDataStore.On("ReadEntry", ctx, sessionId, utils.DATA_ACTIVE_BAL).Return(tt.activeBal, nil)
} else {
mockDataStore.On("ReadEntry", ctx, sessionId, utils.DATA_ACTIVE_SYM).Return(nil, fmt.Errorf("not found"))
mockCreateAccountService.On("CheckBalance", string(tt.publicKey)).Return(tt.balance, nil)
}
// Mock behavior for storing the amount (if valid)
mockDataStore.On("WriteEntry", ctx, sessionId, utils.DATA_AMOUNT, tt.input).Return(nil).Maybe() mockDataStore.On("WriteEntry", ctx, sessionId, utils.DATA_AMOUNT, tt.input).Return(nil).Maybe()
// Call the method under test // Call the method under test
res, _ := h.ValidateAmount(ctx, "test_validate_amount", tt.input) res, _ := h.ValidateAmount(ctx, "test_validate_amount", tt.input)
// Assert that no errors occurred // Assert no errors occurred
assert.NoError(t, err) assert.NoError(t, err)
//Assert that the account created flag has been set to the result // Assert the result matches the expected result
assert.Equal(t, res, tt.expectedResult, "Expected result should be equal to the actual result") assert.Equal(t, tt.expectedResult, res, "Expected result should match actual result")
// Assert that expectations were met // Assert all expectations were met
mockDataStore.AssertExpectations(t) mockDataStore.AssertExpectations(t)
}) })
} }
} }