diff --git a/common/pin.go b/common/pin.go new file mode 100644 index 0000000..6db9d15 --- /dev/null +++ b/common/pin.go @@ -0,0 +1,33 @@ +package common + +import ( + "regexp" + + "golang.org/x/crypto/bcrypt" +) + +// Define the regex pattern as a constant +const ( + pinPattern = `^\d{4}$` +) + +// checks whether the given input is a 4 digit number +func IsValidPIN(pin string) bool { + match, _ := regexp.MatchString(pinPattern, pin) + return match +} + +// HashPIN uses bcrypt with 8 salt rounds to hash the PIN +func HashPIN(pin string) (string, error) { + hash, err := bcrypt.GenerateFromPassword([]byte(pin), 8) + if err != nil { + return "", err + } + return string(hash), nil +} + +// VerifyPIN compareS the hashed PIN with the plaintext PIN +func VerifyPIN(hashedPIN, pin string) bool { + err := bcrypt.CompareHashAndPassword([]byte(hashedPIN), []byte(pin)) + return err == nil +} diff --git a/common/pin_test.go b/common/pin_test.go new file mode 100644 index 0000000..154ab06 --- /dev/null +++ b/common/pin_test.go @@ -0,0 +1,173 @@ +package common + +import ( + "testing" + + "golang.org/x/crypto/bcrypt" +) + +func TestIsValidPIN(t *testing.T) { + tests := []struct { + name string + pin string + expected bool + }{ + { + name: "Valid PIN with 4 digits", + pin: "1234", + expected: true, + }, + { + name: "Valid PIN with leading zeros", + pin: "0001", + expected: true, + }, + { + name: "Invalid PIN with less than 4 digits", + pin: "123", + expected: false, + }, + { + name: "Invalid PIN with more than 4 digits", + pin: "12345", + expected: false, + }, + { + name: "Invalid PIN with letters", + pin: "abcd", + expected: false, + }, + { + name: "Invalid PIN with special characters", + pin: "12@#", + expected: false, + }, + { + name: "Empty PIN", + pin: "", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actual := IsValidPIN(tt.pin) + if actual != tt.expected { + t.Errorf("IsValidPIN(%q) = %v; expected %v", tt.pin, actual, tt.expected) + } + }) + } +} + +func TestHashPIN(t *testing.T) { + tests := []struct { + name string + pin string + }{ + { + name: "Valid PIN with 4 digits", + pin: "1234", + }, + { + name: "Valid PIN with leading zeros", + pin: "0001", + }, + { + name: "Empty PIN", + pin: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hashedPIN, err := HashPIN(tt.pin) + if err != nil { + t.Errorf("HashPIN(%q) returned an error: %v", tt.pin, err) + return + } + + if hashedPIN == "" { + t.Errorf("HashPIN(%q) returned an empty hash", tt.pin) + } + + // Ensure the hash can be verified with bcrypt + err = bcrypt.CompareHashAndPassword([]byte(hashedPIN), []byte(tt.pin)) + if tt.pin != "" && err != nil { + t.Errorf("HashPIN(%q) produced a hash that does not match: %v", tt.pin, err) + } + }) + } +} + +func TestVerifyMigratedHashPin(t *testing.T) { + tests := []struct { + pin string + hash string + }{ + { + pin: "1234", + hash: "$2b$08$dTvIGxCCysJtdvrSnaLStuylPoOS/ZLYYkxvTeR5QmTFY3TSvPQC6", + }, + } + + for _, tt := range tests { + t.Run(tt.pin, func(t *testing.T) { + ok := VerifyPIN(tt.hash, tt.pin) + if !ok { + t.Errorf("VerifyPIN could not verify migrated PIN: %v", tt.pin) + } + }) + } +} + +func TestVerifyPIN(t *testing.T) { + tests := []struct { + name string + pin string + hashedPIN string + shouldPass bool + }{ + { + name: "Valid PIN verification", + pin: "1234", + hashedPIN: hashPINHelper("1234"), + shouldPass: true, + }, + { + name: "Invalid PIN verification with incorrect PIN", + pin: "5678", + hashedPIN: hashPINHelper("1234"), + shouldPass: false, + }, + { + name: "Invalid PIN verification with empty PIN", + pin: "", + hashedPIN: hashPINHelper("1234"), + shouldPass: false, + }, + { + name: "Invalid PIN verification with invalid hash", + pin: "1234", + hashedPIN: "invalidhash", + shouldPass: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := VerifyPIN(tt.hashedPIN, tt.pin) + if result != tt.shouldPass { + t.Errorf("VerifyPIN(%q, %q) = %v; expected %v", tt.hashedPIN, tt.pin, result, tt.shouldPass) + } + }) + } +} + +// Helper function to hash a PIN for testing purposes +func hashPINHelper(pin string) string { + hashedPIN, err := HashPIN(pin) + if err != nil { + panic("Failed to hash PIN for test setup: " + err.Error()) + } + return hashedPIN +} diff --git a/internal/handlers/ussd/menuhandler.go b/internal/handlers/ussd/menuhandler.go index 640517f..3919595 100644 --- a/internal/handlers/ussd/menuhandler.go +++ b/internal/handlers/ussd/menuhandler.go @@ -5,7 +5,6 @@ import ( "context" "fmt" "path" - "regexp" "strconv" "strings" @@ -34,17 +33,6 @@ var ( translationDir = path.Join(scriptDir, "locale") ) -// Define the regex patterns as constants -const ( - pinPattern = `^\d{4}$` -) - -// checks whether the given input is a 4 digit number -func isValidPIN(pin string) bool { - match, _ := regexp.MatchString(pinPattern, pin) - return match -} - // FlagManager handles centralized flag management type FlagManager struct { parser *asm.FlagParser @@ -281,7 +269,7 @@ func (h *Handlers) VerifyNewPin(ctx context.Context, sym string, input []byte) ( flag_valid_pin, _ := h.flagManager.GetFlag("flag_valid_pin") pinInput := string(input) // Validate that the PIN is a 4-digit number. - if isValidPIN(pinInput) { + if common.IsValidPIN(pinInput) { res.FlagSet = append(res.FlagSet, flag_valid_pin) } else { res.FlagReset = append(res.FlagReset, flag_valid_pin) @@ -306,7 +294,7 @@ func (h *Handlers) SaveTemporaryPin(ctx context.Context, sym string, input []byt accountPIN := string(input) // Validate that the PIN is a 4-digit number. - if !isValidPIN(accountPIN) { + if !common.IsValidPIN(accountPIN) { res.FlagSet = append(res.FlagSet, flag_incorrect_pin) return res, nil } @@ -368,11 +356,20 @@ func (h *Handlers) ConfirmPinChange(ctx context.Context, sym string, input []byt res.FlagReset = append(res.FlagReset, flag_pin_mismatch) } else { res.FlagSet = append(res.FlagSet, flag_pin_mismatch) + return res, nil } - // If matched, save the confirmed PIN as the new account PIN - err = store.WriteEntry(ctx, sessionId, common.DATA_ACCOUNT_PIN, []byte(temporaryPin)) + + // Hash the PIN + hashedPIN, err := common.HashPIN(string(temporaryPin)) if err != nil { - logg.ErrorCtxf(ctx, "failed to write temporaryPin entry with", "key", common.DATA_ACCOUNT_PIN, "value", temporaryPin, "error", err) + logg.ErrorCtxf(ctx, "failed to hash temporaryPin", "error", err) + return res, err + } + + // save the hashed PIN as the new account PIN + err = store.WriteEntry(ctx, sessionId, common.DATA_ACCOUNT_PIN, []byte(hashedPIN)) + if err != nil { + logg.ErrorCtxf(ctx, "failed to write DATA_ACCOUNT_PIN entry with", "key", common.DATA_ACCOUNT_PIN, "hashedPIN value", hashedPIN, "error", err) return res, err } return res, nil @@ -404,11 +401,19 @@ func (h *Handlers) VerifyCreatePin(ctx context.Context, sym string, input []byte res.FlagSet = append(res.FlagSet, flag_pin_set) } else { res.FlagSet = []uint32{flag_pin_mismatch} + return res, nil } - err = store.WriteEntry(ctx, sessionId, common.DATA_ACCOUNT_PIN, []byte(temporaryPin)) + // Hash the PIN + hashedPIN, err := common.HashPIN(string(temporaryPin)) if err != nil { - logg.ErrorCtxf(ctx, "failed to write temporaryPin entry with", "key", common.DATA_ACCOUNT_PIN, "value", temporaryPin, "error", err) + logg.ErrorCtxf(ctx, "failed to hash temporaryPin", "error", err) + return res, err + } + + err = store.WriteEntry(ctx, sessionId, common.DATA_ACCOUNT_PIN, []byte(hashedPIN)) + if err != nil { + logg.ErrorCtxf(ctx, "failed to write DATA_ACCOUNT_PIN entry with", "key", common.DATA_ACCOUNT_PIN, "value", hashedPIN, "error", err) return res, err } @@ -722,7 +727,7 @@ func (h *Handlers) Authorize(ctx context.Context, sym string, input []byte) (res return res, err } if len(input) == 4 { - if bytes.Equal(input, AccountPin) { + if common.VerifyPIN(string(AccountPin), string(input)) { if h.st.MatchFlag(flag_account_authorized, false) { res.FlagReset = append(res.FlagReset, flag_incorrect_pin) res.FlagSet = append(res.FlagSet, flag_allow_update, flag_account_authorized) @@ -949,7 +954,15 @@ func (h *Handlers) ResetOthersPin(ctx context.Context, sym string, input []byte) logg.ErrorCtxf(ctx, "failed to read temporaryPin entry with", "key", common.DATA_TEMPORARY_VALUE, "error", err) return res, err } - err = store.WriteEntry(ctx, string(blockedPhonenumber), common.DATA_ACCOUNT_PIN, []byte(temporaryPin)) + + // Hash the PIN + hashedPIN, err := common.HashPIN(string(temporaryPin)) + if err != nil { + logg.ErrorCtxf(ctx, "failed to hash temporaryPin", "error", err) + return res, err + } + + err = store.WriteEntry(ctx, string(blockedPhonenumber), common.DATA_ACCOUNT_PIN, []byte(hashedPIN)) if err != nil { return res, nil } @@ -1400,7 +1413,6 @@ func (h *Handlers) GetCurrentProfileInfo(ctx context.Context, sym string, input defaultValue = "Not Provided" } - sm, _ := h.st.Where() parts := strings.SplitN(sm, "_", 2) filename := parts[1] diff --git a/internal/handlers/ussd/menuhandler_test.go b/internal/handlers/ussd/menuhandler_test.go index 2b168f2..12ed5c2 100644 --- a/internal/handlers/ussd/menuhandler_test.go +++ b/internal/handlers/ussd/menuhandler_test.go @@ -1047,7 +1047,14 @@ func TestAuthorize(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err = store.WriteEntry(ctx, sessionId, common.DATA_ACCOUNT_PIN, []byte(accountPIN)) + // Hash the PIN + hashedPIN, err := common.HashPIN(accountPIN) + if err != nil { + logg.ErrorCtxf(ctx, "failed to hash temporaryPin", "error", err) + t.Fatal(err) + } + + err = store.WriteEntry(ctx, sessionId, common.DATA_ACCOUNT_PIN, []byte(hashedPIN)) if err != nil { t.Fatal(err) } @@ -1499,59 +1506,6 @@ func TestQuit(t *testing.T) { } } -func TestIsValidPIN(t *testing.T) { - tests := []struct { - name string - pin string - expected bool - }{ - { - name: "Valid PIN with 4 digits", - pin: "1234", - expected: true, - }, - { - name: "Valid PIN with leading zeros", - pin: "0001", - expected: true, - }, - { - name: "Invalid PIN with less than 4 digits", - pin: "123", - expected: false, - }, - { - name: "Invalid PIN with more than 4 digits", - pin: "12345", - expected: false, - }, - { - name: "Invalid PIN with letters", - pin: "abcd", - expected: false, - }, - { - name: "Invalid PIN with special characters", - pin: "12@#", - expected: false, - }, - { - name: "Empty PIN", - pin: "", - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - actual := isValidPIN(tt.pin) - if actual != tt.expected { - t.Errorf("isValidPIN(%q) = %v; expected %v", tt.pin, actual, tt.expected) - } - }) - } -} - func TestValidateAmount(t *testing.T) { fm, err := NewFlagManager(flagsPath) if err != nil {