From 7df77a134307823e058ca05541a004283189d0e1 Mon Sep 17 00:00:00 2001 From: alfred-mk Date: Sat, 12 Oct 2024 20:07:06 +0300 Subject: [PATCH] updated the ValidateAmount to also check the active symbol, updated tests --- internal/handlers/ussd/menuhandler.go | 88 +++++++++------------- internal/handlers/ussd/menuhandler_test.go | 76 +++++++++---------- 2 files changed, 72 insertions(+), 92 deletions(-) diff --git a/internal/handlers/ussd/menuhandler.go b/internal/handlers/ussd/menuhandler.go index d6d8102..3e164af 100644 --- a/internal/handlers/ussd/menuhandler.go +++ b/internal/handlers/ussd/menuhandler.go @@ -761,74 +761,60 @@ func (h *Handlers) ResetTransactionAmount(ctx context.Context, sym string, input return res, nil } -// MaxAmount gets the current balance from the API and sets it as -// the result content. -func (h *Handlers) MaxAmount(ctx context.Context, sym string, input []byte) (resource.Result, error) { - var res resource.Result - var err error - - sessionId, ok := ctx.Value("SessionId").(string) - if !ok { - return res, fmt.Errorf("missing session") - } - store := h.userdataStore - publicKey, _ := store.ReadEntry(ctx, sessionId, utils.DATA_PUBLIC_KEY) - - balance, err := h.accountService.CheckBalance(string(publicKey)) - if err != nil { - return res, nil - } - - res.Content = balance - - return res, nil -} - // ValidateAmount ensures that the given input is a valid amount and that // it is not more than the current balance. func (h *Handlers) ValidateAmount(ctx context.Context, sym string, input []byte) (resource.Result, error) { var res resource.Result - var err error sessionId, ok := ctx.Value("SessionId").(string) if !ok { return res, fmt.Errorf("missing session") } - flag_invalid_amount, _ := h.flagManager.GetFlag("flag_invalid_amount") - store := h.userdataStore - publicKey, _ := store.ReadEntry(ctx, sessionId, utils.DATA_PUBLIC_KEY) - - amountStr := string(input) - - balanceStr, err := h.accountService.CheckBalance(string(publicKey)) + publicKey, err := store.ReadEntry(ctx, sessionId, utils.DATA_PUBLIC_KEY) if err != nil { return res, err } - res.Content = balanceStr - // Parse the balance - balanceParts := strings.Split(balanceStr, " ") - if len(balanceParts) != 2 { - return res, fmt.Errorf("unexpected balance format: %s", balanceStr) - } - balanceValue, err := strconv.ParseFloat(balanceParts[0], 64) - if err != nil { - return res, fmt.Errorf("failed to parse balance: %v", err) + // retrieve the active symbol + activeSym, err := store.ReadEntry(ctx, sessionId, utils.DATA_ACTIVE_SYM) + useActiveSymbol := err == nil && len(activeSym) > 0 + + var balanceValue float64 + if useActiveSymbol { + // If active symbol is set, retrieve its balance + activeBal, err := store.ReadEntry(ctx, sessionId, utils.DATA_ACTIVE_BAL) + if err != nil { + return res, fmt.Errorf("failed to get active balance: %v", err) + } + balanceValue, err = strconv.ParseFloat(string(activeBal), 64) + if err != nil { + return res, fmt.Errorf("failed to parse active balance: %v", err) + } + } else { + // If no active symbol, use the current balance from the API + balanceStr, err := h.accountService.CheckBalance(string(publicKey)) + if err != nil { + return res, fmt.Errorf("failed to check balance: %v", err) + } + res.Content = balanceStr + + // Parse the balance string + balanceParts := strings.Split(balanceStr, " ") + if len(balanceParts) != 2 { + return res, fmt.Errorf("unexpected balance format: %s", balanceStr) + } + balanceValue, err = strconv.ParseFloat(balanceParts[0], 64) + if err != nil { + return res, fmt.Errorf("failed to parse balance: %v", err) + } } - // Extract numeric part from input - re := regexp.MustCompile(`^(\d+(\.\d+)?)\s*(?:CELO)?$`) - matches := re.FindStringSubmatch(strings.TrimSpace(amountStr)) - if len(matches) < 2 { - res.FlagSet = append(res.FlagSet, flag_invalid_amount) - res.Content = amountStr - return res, nil - } - - inputAmount, err := strconv.ParseFloat(matches[1], 64) + // Extract numeric part from the input amount + amountStr := strings.TrimSpace(string(input)) + inputAmount, err := strconv.ParseFloat(amountStr, 64) if err != nil { res.FlagSet = append(res.FlagSet, flag_invalid_amount) res.Content = amountStr @@ -841,12 +827,12 @@ func (h *Handlers) ValidateAmount(ctx context.Context, sym string, input []byte) return res, nil } - res.Content = fmt.Sprintf("%.3f", inputAmount) // Format to 3 decimal places err = store.WriteEntry(ctx, sessionId, utils.DATA_AMOUNT, []byte(amountStr)) if err != nil { return res, err } + res.Content = fmt.Sprintf("%.3f", inputAmount) return res, nil } diff --git a/internal/handlers/ussd/menuhandler_test.go b/internal/handlers/ussd/menuhandler_test.go index 9159722..4fd9d22 100644 --- a/internal/handlers/ussd/menuhandler_test.go +++ b/internal/handlers/ussd/menuhandler_test.go @@ -434,34 +434,6 @@ func TestCheckIdentifier(t *testing.T) { } } -func TestMaxAmount(t *testing.T) { - mockStore := new(mocks.MockUserDataStore) - mockCreateAccountService := new(mocks.MockAccountService) - - // Define test data - sessionId := "session123" - ctx := context.WithValue(context.Background(), "SessionId", sessionId) - publicKey := "0xcasgatweksalw1018221" - expectedBalance := "0.003CELO" - - // Set up the expected behavior of the mock - mockStore.On("ReadEntry", ctx, sessionId, utils.DATA_PUBLIC_KEY).Return([]byte(publicKey), nil) - mockCreateAccountService.On("CheckBalance", publicKey).Return(expectedBalance, nil) - - // Create the Handlers instance with the mock store - h := &Handlers{ - userdataStore: mockStore, - accountService: mockCreateAccountService, - } - - // Call the method - res, _ := h.MaxAmount(ctx, "max_amount", []byte("check_balance")) - - //Assert that the balance that was set as the result content is what was returned by Check Balance - assert.Equal(t, expectedBalance, res.Content) - -} - func TestGetSender(t *testing.T) { mockStore := new(mocks.MockUserDataStore) @@ -1444,59 +1416,81 @@ func TestValidateAmount(t *testing.T) { name string input []byte publicKey []byte + activeSym []byte + activeBal []byte balance string expectedResult resource.Result }{ { - name: "Test with valid amount", + name: "Test with valid amount and active symbol", input: []byte("0.001"), - balance: "0.003 CELO", publicKey: []byte("0xrqeqrequuq"), + activeSym: []byte("CELO"), + activeBal: []byte("0.003"), expectedResult: resource.Result{ Content: "0.001", }, }, { - name: "Test with amount larger than balance", + name: "Test with amount larger than active balance", input: []byte("0.02"), - balance: "0.003 CELO", publicKey: []byte("0xrqeqrequuq"), + activeSym: []byte("CELO"), + activeBal: []byte("0.003"), expectedResult: resource.Result{ FlagSet: []uint32{flag_invalid_amount}, Content: "0.02", }, }, { - name: "Test with invalid amount", + name: "Test with invalid amount format", input: []byte("0.02ms"), - balance: "0.003 CELO", publicKey: []byte("0xrqeqrequuq"), + balance: "0.003 CELO", expectedResult: resource.Result{ FlagSet: []uint32{flag_invalid_amount}, Content: "0.02ms", }, }, + { + name: "Test fallback to current balance without active symbol", + input: []byte("0.001"), + publicKey: []byte("0xrqeqrequuq"), + balance: "0.003 CELO", + expectedResult: resource.Result{ + Content: "0.001", + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - + // Mock behavior for public key retrieval mockDataStore.On("ReadEntry", ctx, sessionId, utils.DATA_PUBLIC_KEY).Return(tt.publicKey, nil) - mockCreateAccountService.On("CheckBalance", string(tt.publicKey)).Return(tt.balance, nil) + + // Mock behavior for active symbol and balance retrieval (if present) + if len(tt.activeSym) > 0 { + mockDataStore.On("ReadEntry", ctx, sessionId, utils.DATA_ACTIVE_SYM).Return(tt.activeSym, nil) + mockDataStore.On("ReadEntry", ctx, sessionId, utils.DATA_ACTIVE_BAL).Return(tt.activeBal, nil) + } else { + mockDataStore.On("ReadEntry", ctx, sessionId, utils.DATA_ACTIVE_SYM).Return(nil, fmt.Errorf("not found")) + mockCreateAccountService.On("CheckBalance", string(tt.publicKey)).Return(tt.balance, nil) + } + + // Mock behavior for storing the amount (if valid) mockDataStore.On("WriteEntry", ctx, sessionId, utils.DATA_AMOUNT, tt.input).Return(nil).Maybe() // Call the method under test res, _ := h.ValidateAmount(ctx, "test_validate_amount", tt.input) - // Assert that no errors occurred + // Assert no errors occurred assert.NoError(t, err) - //Assert that the account created flag has been set to the result - assert.Equal(t, res, tt.expectedResult, "Expected result should be equal to the actual result") + // Assert the result matches the expected result + assert.Equal(t, tt.expectedResult, res, "Expected result should match actual result") - // Assert that expectations were met + // Assert all expectations were met mockDataStore.AssertExpectations(t) - }) } }