diff --git a/handlers/application/menuhandler.go b/handlers/application/menuhandler.go index 8b4a30b..77abdd2 100644 --- a/handlers/application/menuhandler.go +++ b/handlers/application/menuhandler.go @@ -1830,12 +1830,12 @@ func (h *MenuHandlers) ValidateAmount(ctx context.Context, sym string, input []b return res, fmt.Errorf("missing session") } flag_invalid_amount, _ := h.flagManager.GetFlag("flag_invalid_amount") - store := h.userdataStore + userStore := h.userdataStore var balanceValue float64 // retrieve the active balance - activeBal, err := store.ReadEntry(ctx, sessionId, storedb.DATA_ACTIVE_BAL) + activeBal, err := userStore.ReadEntry(ctx, sessionId, storedb.DATA_ACTIVE_BAL) if err != nil { logg.ErrorCtxf(ctx, "failed to read activeBal entry with", "key", storedb.DATA_ACTIVE_BAL, "error", err) return res, err @@ -1861,9 +1861,15 @@ func (h *MenuHandlers) ValidateAmount(ctx context.Context, sym string, input []b return res, nil } - // Format the amount with 2 decimal places before saving - formattedAmount := fmt.Sprintf("%.2f", inputAmount) - err = store.WriteEntry(ctx, sessionId, storedb.DATA_AMOUNT, []byte(formattedAmount)) + // Format the amount to 2 decimal places before saving (truncated) + formattedAmount, err := store.TruncateDecimalString(amountStr, 2) + if err != nil { + res.FlagSet = append(res.FlagSet, flag_invalid_amount) + res.Content = amountStr + return res, nil + } + + err = userStore.WriteEntry(ctx, sessionId, storedb.DATA_AMOUNT, []byte(formattedAmount)) if err != nil { logg.ErrorCtxf(ctx, "failed to write amount entry with", "key", storedb.DATA_AMOUNT, "value", formattedAmount, "error", err) return res, err @@ -3046,7 +3052,15 @@ func (h *MenuHandlers) SwapPreview(ctx context.Context, sym string, input []byte return res, nil } - finalAmountStr, err := store.ParseAndScaleAmount(inputStr, swapData.ActiveSwapFromDecimal) + // Format the amount to 2 decimal places + formattedAmount, err := store.TruncateDecimalString(inputStr, 2) + if err != nil { + res.FlagSet = append(res.FlagSet, flag_invalid_amount) + res.Content = inputStr + return res, nil + } + + finalAmountStr, err := store.ParseAndScaleAmount(formattedAmount, swapData.ActiveSwapFromDecimal) if err != nil { return res, err } diff --git a/handlers/application/menuhandler_test.go b/handlers/application/menuhandler_test.go index a2485c0..f7a23ba 100644 --- a/handlers/application/menuhandler_test.go +++ b/handlers/application/menuhandler_test.go @@ -1678,6 +1678,22 @@ func TestValidateAmount(t *testing.T) { Content: "0.02ms", }, }, + { + name: "Test with valid decimal amount", + input: []byte("0.149"), + activeBal: []byte("5"), + expectedResult: resource.Result{ + Content: "0.14", + }, + }, + { + name: "Test with valid large decimal amount", + input: []byte("1.8599999999"), + activeBal: []byte("5"), + expectedResult: resource.Result{ + Content: "1.85", + }, + }, } for _, tt := range tests { @@ -2529,11 +2545,11 @@ func TestCheckTransactions(t *testing.T) { mockTXResponse := []dataserviceapi.Last10TxResponse{ { - Sender: "0X13242618721", Recipient: "0x41c188d63Qa", TransferValue: "100", TokenAddress: "0X1324262343rfdGW23", + Sender: "0X13242618721", Recipient: "0x41c188d63Qa", TransferValue: "100", ContractAddress: "0X1324262343rfdGW23", TxHash: "0x123wefsf34rf", DateBlock: time.Now(), TokenSymbol: "SRF", TokenDecimals: "6", }, { - Sender: "0x41c188d63Qa", Recipient: "0X13242618721", TransferValue: "200", TokenAddress: "0X1324262343rfdGW23", + Sender: "0x41c188d63Qa", Recipient: "0X13242618721", TransferValue: "200", ContractAddress: "0X1324262343rfdGW23", TxHash: "0xq34wresfdb44", DateBlock: time.Now(), TokenSymbol: "SRF", TokenDecimals: "6", }, } @@ -2585,11 +2601,11 @@ func TestGetTransactionsList(t *testing.T) { mockTXResponse := []dataserviceapi.Last10TxResponse{ { - Sender: "0X13242618721", Recipient: "0x41c188d63Qa", TransferValue: "1000", TokenAddress: "0X1324262343rfdGW23", + Sender: "0X13242618721", Recipient: "0x41c188d63Qa", TransferValue: "1000", ContractAddress: "0X1324262343rfdGW23", TxHash: "0x123wefsf34rf", DateBlock: dateBlock, TokenSymbol: "SRF", TokenDecimals: "2", }, { - Sender: "0x41c188d63Qa", Recipient: "0X13242618721", TransferValue: "2000", TokenAddress: "0X1324262343rfdGW23", + Sender: "0x41c188d63Qa", Recipient: "0X13242618721", TransferValue: "2000", ContractAddress: "0X1324262343rfdGW23", TxHash: "0xq34wresfdb44", DateBlock: dateBlock, TokenSymbol: "SRF", TokenDecimals: "2", }, } @@ -2654,11 +2670,11 @@ func TestViewTransactionStatement(t *testing.T) { mockTXResponse := []dataserviceapi.Last10TxResponse{ { - Sender: "0X13242618721", Recipient: "0x41c188d63Qa", TransferValue: "1000", TokenAddress: "0X1324262343rfdGW23", + Sender: "0X13242618721", Recipient: "0x41c188d63Qa", TransferValue: "1000", ContractAddress: "0X1324262343rfdGW23", TxHash: "0x123wefsf34rf", DateBlock: dateBlock, TokenSymbol: "SRF", TokenDecimals: "2", }, { - Sender: "0x41c188d63Qa", Recipient: "0X13242618721", TransferValue: "2000", TokenAddress: "0X1324262343rfdGW23", + Sender: "0x41c188d63Qa", Recipient: "0X13242618721", TransferValue: "2000", ContractAddress: "0X1324262343rfdGW23", TxHash: "0xq34wresfdb44", DateBlock: dateBlock, TokenSymbol: "SRF", TokenDecimals: "2", }, } diff --git a/store/tokens.go b/store/tokens.go index a7770c7..49ac175 100644 --- a/store/tokens.go +++ b/store/tokens.go @@ -3,6 +3,7 @@ package store import ( "context" "errors" + "fmt" "math/big" "reflect" "strconv" @@ -20,6 +21,27 @@ type TransactionData struct { ActiveAddress string } +// TruncateDecimalString safely truncates the input amount to the specified decimal places +func TruncateDecimalString(input string, decimalPlaces int) (string, error) { + num, ok := new(big.Float).SetString(input) + if !ok { + return "", fmt.Errorf("invalid input") + } + + // Multiply by 10^decimalPlaces + scale := new(big.Float).SetInt(new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(decimalPlaces)), nil)) + scaled := new(big.Float).Mul(num, scale) + + // Truncate by converting to int (chops off decimals) + intPart, _ := scaled.Int(nil) + + // Divide back to get truncated float + truncated := new(big.Float).Quo(new(big.Float).SetInt(intPart), scale) + + // Format with fixed decimals + return truncated.Text('f', decimalPlaces), nil +} + func ParseAndScaleAmount(storedAmount, activeDecimal string) (string, error) { // Parse token decimal tokenDecimal, err := strconv.Atoi(activeDecimal) @@ -38,11 +60,8 @@ func ParseAndScaleAmount(storedAmount, activeDecimal string) (string, error) { multiplier := new(big.Float).SetInt(new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(tokenDecimal)), nil)) finalAmount := new(big.Float).Mul(amount, multiplier) - // Convert finalAmount to a string - finalAmountStr := new(big.Int) - finalAmount.Int(finalAmountStr) - - return finalAmountStr.String(), nil + // Return finalAmount as a string with 0 decimal places (rounded) + return finalAmount.Text('f', 0), nil } func ReadTransactionData(ctx context.Context, store DataStore, sessionId string) (TransactionData, error) { diff --git a/store/tokens_test.go b/store/tokens_test.go index b8c0082..625a65d 100644 --- a/store/tokens_test.go +++ b/store/tokens_test.go @@ -7,6 +7,109 @@ import ( "github.com/alecthomas/assert/v2" ) +func TestTruncateDecimalString(t *testing.T) { + tests := []struct { + name string + input string + decimalPlaces int + want string + expectError bool + }{ + { + name: "whole number", + input: "4", + decimalPlaces: 2, + want: "4.00", + expectError: false, + }, + { + name: "single decimal", + input: "4.1", + decimalPlaces: 2, + want: "4.10", + expectError: false, + }, + { + name: "one decimal place", + input: "4.19", + decimalPlaces: 1, + want: "4.1", + expectError: false, + }, + { + name: "truncates to 2 dp", + input: "0.149", + decimalPlaces: 2, + want: "0.14", + expectError: false, + }, + { + name: "does not round", + input: "1.8599999999", + decimalPlaces: 2, + want: "1.85", + expectError: false, + }, + { + name: "high precision input", + input: "123.456789", + decimalPlaces: 4, + want: "123.4567", + expectError: false, + }, + { + name: "zero", + input: "0", + decimalPlaces: 2, + want: "0.00", + expectError: false, + }, + { + name: "invalid input string", + input: "abc", + decimalPlaces: 2, + want: "", + expectError: true, + }, + { + name: "edge rounding case", + input: "4.99999999", + decimalPlaces: 2, + want: "4.99", + expectError: false, + }, + { + name: "small value", + input: "0.0001", + decimalPlaces: 2, + want: "0.00", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := TruncateDecimalString(tt.input, tt.decimalPlaces) + + if tt.expectError { + if err == nil { + t.Errorf("TruncateDecimalString(%q, %d) expected error, got nil", tt.input, tt.decimalPlaces) + } + return + } + + if err != nil { + t.Errorf("TruncateDecimalString(%q, %d) unexpected error: %v", tt.input, tt.decimalPlaces, err) + return + } + + if got != tt.want { + t.Errorf("TruncateDecimalString(%q, %d) = %q, want %q", tt.input, tt.decimalPlaces, got, tt.want) + } + }) + } +} + func TestParseAndScaleAmount(t *testing.T) { tests := []struct { name string @@ -64,6 +167,20 @@ func TestParseAndScaleAmount(t *testing.T) { want: "0", expectError: false, }, + { + name: "high decimals", + amount: "1.85", + decimals: "18", + want: "1850000000000000000", + expectError: false, + }, + { + name: "6 d.p", + amount: "2.32", + decimals: "6", + want: "2320000", + expectError: false, + }, } for _, tt := range tests {