Merge branch 'lash/purify-more' into postgres-switch-for-tests

This commit is contained in:
lash
2025-01-08 22:05:12 +00:00
31 changed files with 584 additions and 212 deletions

View File

@@ -128,6 +128,7 @@ func (ls *LocalHandlerService) GetHandler(accountService remote.AccountServiceIn
ls.DbRs.AddLocalFunc("view_statement", ussdHandlers.ViewTransactionStatement)
ls.DbRs.AddLocalFunc("update_all_profile_items", ussdHandlers.UpdateAllProfileItems)
ls.DbRs.AddLocalFunc("set_back", ussdHandlers.SetBack)
ls.DbRs.AddLocalFunc("show_blocked_account", ussdHandlers.ShowBlockedAccount)
return ussdHandlers, nil
}

View File

@@ -734,11 +734,23 @@ func (h *Handlers) Authorize(ctx context.Context, sym string, input []byte) (res
if h.st.MatchFlag(flag_account_authorized, false) {
res.FlagReset = append(res.FlagReset, flag_incorrect_pin)
res.FlagSet = append(res.FlagSet, flag_allow_update, flag_account_authorized)
err := h.resetIncorrectPINAttempts(ctx, sessionId)
if err != nil {
return res, err
}
} else {
res.FlagSet = append(res.FlagSet, flag_allow_update)
res.FlagReset = append(res.FlagReset, flag_account_authorized)
err := h.resetIncorrectPINAttempts(ctx, sessionId)
if err != nil {
return res, err
}
}
} else {
err := h.incrementIncorrectPINAttempts(ctx, sessionId)
if err != nil {
return res, err
}
res.FlagSet = append(res.FlagSet, flag_incorrect_pin)
res.FlagReset = append(res.FlagReset, flag_account_authorized)
return res, nil
@@ -752,8 +764,34 @@ func (h *Handlers) Authorize(ctx context.Context, sym string, input []byte) (res
// ResetIncorrectPin resets the incorrect pin flag after a new PIN attempt.
func (h *Handlers) ResetIncorrectPin(ctx context.Context, sym string, input []byte) (resource.Result, error) {
var res resource.Result
store := h.userdataStore
flag_incorrect_pin, _ := h.flagManager.GetFlag("flag_incorrect_pin")
flag_account_blocked, _ := h.flagManager.GetFlag("flag_account_blocked")
sessionId, ok := ctx.Value("SessionId").(string)
if !ok {
return res, fmt.Errorf("missing session")
}
res.FlagReset = append(res.FlagReset, flag_incorrect_pin)
currentWrongPinAttempts, err := store.ReadEntry(ctx, sessionId, common.DATA_INCORRECT_PIN_ATTEMPTS)
if err != nil {
if !db.IsNotFound(err) {
return res, err
}
}
pinAttemptsValue, _ := strconv.ParseUint(string(currentWrongPinAttempts), 0, 64)
remainingPINAttempts := common.AllowedPINAttempts - uint8(pinAttemptsValue)
if remainingPINAttempts == 0 {
res.FlagSet = append(res.FlagSet, flag_account_blocked)
return res, nil
}
if remainingPINAttempts < common.AllowedPINAttempts {
res.Content = strconv.Itoa(int(remainingPINAttempts))
}
return res, nil
}
@@ -840,6 +878,16 @@ func (h *Handlers) QuitWithHelp(ctx context.Context, sym string, input []byte) (
return res, nil
}
// ShowBlockedAccount displays a message after an account has been blocked and how to reach support.
func (h *Handlers) ShowBlockedAccount(ctx context.Context, sym string, input []byte) (resource.Result, error) {
var res resource.Result
code := codeFromCtx(ctx)
l := gotext.NewLocale(translationDir, code)
l.AddDomain("default")
res.Content = l.Get("Your account has been locked. For help on how to unblock your account, contact support at: 0757628885")
return res, nil
}
// VerifyYob verifies the length of the given input.
func (h *Handlers) VerifyYob(ctx context.Context, sym string, input []byte) (resource.Result, error) {
var res resource.Result
@@ -2075,3 +2123,53 @@ func (h *Handlers) UpdateAllProfileItems(ctx context.Context, sym string, input
}
return res, nil
}
// incrementIncorrectPINAttempts keeps track of the number of incorrect PIN attempts
func (h *Handlers) incrementIncorrectPINAttempts(ctx context.Context, sessionId string) error {
var pinAttemptsCount uint8
store := h.userdataStore
currentWrongPinAttempts, err := store.ReadEntry(ctx, sessionId, common.DATA_INCORRECT_PIN_ATTEMPTS)
if err != nil {
if db.IsNotFound(err) {
//First time Wrong PIN attempt: initialize with a count of 1
pinAttemptsCount = 1
err = store.WriteEntry(ctx, sessionId, common.DATA_INCORRECT_PIN_ATTEMPTS, []byte(strconv.Itoa(int(pinAttemptsCount))))
if err != nil {
logg.ErrorCtxf(ctx, "failed to write incorrect PIN attempts ", "key", common.DATA_INCORRECT_PIN_ATTEMPTS, "value", currentWrongPinAttempts, "error", err)
return err
}
return nil
}
}
pinAttemptsValue, _ := strconv.ParseUint(string(currentWrongPinAttempts), 0, 64)
pinAttemptsCount = uint8(pinAttemptsValue) + 1
err = store.WriteEntry(ctx, sessionId, common.DATA_INCORRECT_PIN_ATTEMPTS, []byte(strconv.Itoa(int(pinAttemptsCount))))
if err != nil {
logg.ErrorCtxf(ctx, "failed to write incorrect PIN attempts ", "key", common.DATA_INCORRECT_PIN_ATTEMPTS, "value", pinAttemptsCount, "error", err)
return err
}
return nil
}
// resetIncorrectPINAttempts resets the number of incorrect PIN attempts after a correct PIN entry
func (h *Handlers) resetIncorrectPINAttempts(ctx context.Context, sessionId string) error {
store := h.userdataStore
currentWrongPinAttempts, err := store.ReadEntry(ctx, sessionId, common.DATA_INCORRECT_PIN_ATTEMPTS)
if err != nil {
if db.IsNotFound(err) {
return nil
}
return err
}
currentWrongPinAttemptsCount, _ := strconv.ParseUint(string(currentWrongPinAttempts), 0, 64)
if currentWrongPinAttemptsCount <= uint64(common.AllowedPINAttempts) {
err = store.WriteEntry(ctx, sessionId, common.DATA_INCORRECT_PIN_ATTEMPTS, []byte(string("0")))
if err != nil {
logg.ErrorCtxf(ctx, "failed to reset incorrect PIN attempts ", "key", common.DATA_INCORRECT_PIN_ATTEMPTS, "value", common.AllowedPINAttempts, "error", err)
return err
}
}
return nil
}

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"log"
"path"
"strconv"
"strings"
"testing"
@@ -907,37 +908,79 @@ func TestResetAccountAuthorized(t *testing.T) {
}
func TestIncorrectPinReset(t *testing.T) {
sessionId := "session123"
ctx, store := InitializeTestStore(t)
fm, err := NewFlagManager(flagsPath)
if err != nil {
log.Fatal(err)
}
flag_incorrect_pin, _ := fm.parser.GetFlag("flag_incorrect_pin")
flag_account_blocked, _ := fm.parser.GetFlag("flag_account_blocked")
ctx = context.WithValue(ctx, "SessionId", sessionId)
// Define test cases
tests := []struct {
name string
input []byte
attempts uint8
expectedResult resource.Result
}{
{
name: "Test incorrect pin reset",
name: "Test when incorrect PIN attempts is 2",
input: []byte(""),
expectedResult: resource.Result{
FlagReset: []uint32{flag_incorrect_pin},
Content: "1", //Expected remaining PIN attempts
},
attempts: 2,
},
{
name: "Test incorrect pin reset when incorrect PIN attempts is 1",
input: []byte(""),
expectedResult: resource.Result{
FlagReset: []uint32{flag_incorrect_pin},
Content: "2", //Expected remaining PIN attempts
},
attempts: 1,
},
{
name: "Test incorrect pin reset when incorrect PIN attempts is 1",
input: []byte(""),
expectedResult: resource.Result{
FlagReset: []uint32{flag_incorrect_pin},
Content: "2", //Expected remaining PIN attempts
},
attempts: 1,
},
{
name: "Test incorrect pin reset when incorrect PIN attempts is 3(account expected to be blocked)",
input: []byte(""),
expectedResult: resource.Result{
FlagReset: []uint32{flag_incorrect_pin},
FlagSet: []uint32{flag_account_blocked},
},
attempts: 3,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := store.WriteEntry(ctx, sessionId, common.DATA_INCORRECT_PIN_ATTEMPTS, []byte(strconv.Itoa(int(tt.attempts)))); err != nil {
t.Fatal(err)
}
// Create the Handlers instance with the mock flag manager
h := &Handlers{
flagManager: fm.parser,
flagManager: fm.parser,
userdataStore: store,
}
// Call the method
res, err := h.ResetIncorrectPin(context.Background(), "reset_incorrect_pin", tt.input)
res, err := h.ResetIncorrectPin(ctx, "reset_incorrect_pin", tt.input)
if err != nil {
t.Error(err)
}
@@ -2190,3 +2233,55 @@ func TestGetVoucherDetails(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, expectedResult, res)
}
func TestCountIncorrectPINAttempts(t *testing.T) {
ctx, store := InitializeTestStore(t)
sessionId := "session123"
ctx = context.WithValue(ctx, "SessionId", sessionId)
attempts := uint8(2)
h := &Handlers{
userdataStore: store,
}
err := store.WriteEntry(ctx, sessionId, common.DATA_INCORRECT_PIN_ATTEMPTS, []byte(strconv.Itoa(int(attempts))))
if err != nil {
t.Logf(err.Error())
}
err = h.incrementIncorrectPINAttempts(ctx, sessionId)
if err != nil {
t.Logf(err.Error())
}
attemptsAfterCount, err := store.ReadEntry(ctx, sessionId, common.DATA_INCORRECT_PIN_ATTEMPTS)
if err != nil {
t.Logf(err.Error())
}
pinAttemptsValue, _ := strconv.ParseUint(string(attemptsAfterCount), 0, 64)
pinAttemptsCount := uint8(pinAttemptsValue)
expectedAttempts := attempts + 1
assert.Equal(t, pinAttemptsCount, expectedAttempts)
}
func TestResetIncorrectPINAttempts(t *testing.T) {
ctx, store := InitializeTestStore(t)
sessionId := "session123"
ctx = context.WithValue(ctx, "SessionId", sessionId)
err := store.WriteEntry(ctx, sessionId, common.DATA_INCORRECT_PIN_ATTEMPTS, []byte(string("2")))
if err != nil {
t.Logf(err.Error())
}
h := &Handlers{
userdataStore: store,
}
h.resetIncorrectPINAttempts(ctx, sessionId)
incorrectAttempts, err := store.ReadEntry(ctx, sessionId, common.DATA_INCORRECT_PIN_ATTEMPTS)
if err != nil {
t.Logf(err.Error())
}
assert.Equal(t, "0", string(incorrectAttempts))
}

View File

@@ -41,6 +41,7 @@ func NewAuther(ctx context.Context, keyStore *SshKeyStore) *auther {
}
func(a *auther) Check(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
logg.TraceCtxf(a.Ctx, "looking for publickey", "pubkey", fmt.Sprintf("%x", pubKey))
va, err := a.keyStore.Get(a.Ctx, pubKey)
if err != nil {
return nil, err
@@ -71,6 +72,20 @@ func(a *auther) Get(k []byte) (string, error) {
return v, nil
}
type SshRunner struct {
Ctx context.Context
Cfg engine.Config
FlagFile string
Conn storage.ConnData
ResourceDir string
Debug bool
SrvKeyFile string
Host string
Port uint
wg sync.WaitGroup
lst net.Listener
}
func(s *SshRunner) serve(ctx context.Context, sessionId string, ch ssh.NewChannel, en engine.Engine) error {
if ch == nil {
return errors.New("nil channel")
@@ -128,32 +143,13 @@ func(s *SshRunner) serve(ctx context.Context, sessionId string, ch ssh.NewChanne
return nil
}
type SshRunner struct {
Ctx context.Context
Cfg engine.Config
FlagFile string
DbDir string
ResourceDir string
Debug bool
SrvKeyFile string
Host string
Port uint
wg sync.WaitGroup
lst net.Listener
}
func(s *SshRunner) Stop() error {
return s.lst.Close()
}
func(s *SshRunner) GetEngine(sessionId string) (engine.Engine, func(), error) {
ctx := s.Ctx
menuStorageService := storage.NewMenuStorageService(s.DbDir, s.ResourceDir)
err := menuStorageService.EnsureDbDir()
if err != nil {
return nil, nil, err
}
menuStorageService := storage.NewMenuStorageService(s.Conn, s.ResourceDir)
rs, err := menuStorageService.GetResource(ctx)
if err != nil {
@@ -208,6 +204,7 @@ func(s *SshRunner) GetEngine(sessionId string) (engine.Engine, func(), error) {
// adapted example from crypto/ssh package, NewServerConn doc
func(s *SshRunner) Run(ctx context.Context, keyStore *SshKeyStore) {
s.Ctx = ctx
running := true
// TODO: waitgroup should probably not be global

69
internal/storage/parse.go Normal file
View File

@@ -0,0 +1,69 @@
package storage
import (
"fmt"
"net/url"
"path"
)
const (
DBTYPE_MEM = iota
DBTYPE_GDBM
DBTYPE_POSTGRES
)
type ConnData struct {
typ int
str string
}
func (cd *ConnData) DbType() int {
return cd.typ
}
func (cd *ConnData) String() string {
return cd.str
}
func probePostgres(s string) (string, bool) {
v, err := url.Parse(s)
if err != nil {
return "", false
}
if v.Scheme != "postgres" {
return "", false
}
return s, true
}
func probeGdbm(s string) (string, bool) {
if !path.IsAbs(s) {
return "", false
}
s = path.Clean(s)
return s, true
}
func ToConnData(connStr string) (ConnData, error) {
var o ConnData
if connStr == "" {
return o, nil
}
v, ok := probePostgres(connStr)
if ok {
o.typ = DBTYPE_POSTGRES
o.str = v
return o, nil
}
v, ok = probeGdbm(connStr)
if ok {
o.typ = DBTYPE_GDBM
o.str = v
return o, nil
}
return o, fmt.Errorf("invalid connection string: %s", connStr)
}

View File

@@ -0,0 +1,28 @@
package storage
import (
"testing"
)
func TestParseConnStr(t *testing.T) {
_, err := ToConnData("postgres://foo:bar@localhost:5432/baz")
if err != nil {
t.Fatal(err)
}
_, err = ToConnData("/foo/bar")
if err != nil {
t.Fatal(err)
}
_, err = ToConnData("/foo/bar/")
if err != nil {
t.Fatal(err)
}
_, err = ToConnData("foo/bar")
if err == nil {
t.Fatalf("expected error")
}
_, err = ToConnData("http://foo/bar")
if err == nil {
t.Fatalf("expected error")
}
}

View File

@@ -13,7 +13,6 @@ import (
"git.defalsify.org/vise.git/logging"
"git.defalsify.org/vise.git/persist"
"git.defalsify.org/vise.git/resource"
"git.grassecon.net/urdt/ussd/initializers"
gdbmstorage "git.grassecon.net/urdt/ussd/internal/storage/db/gdbm"
"github.com/jackc/pgx/v5/pgxpool"
)
@@ -26,11 +25,10 @@ type StorageService interface {
GetPersister(ctx context.Context) (*persist.Persister, error)
GetUserdataDb(ctx context.Context) db.Db
GetResource(ctx context.Context) (resource.Resource, error)
EnsureDbDir() error
}
type MenuStorageService struct {
dbDir string
conn ConnData
resourceDir string
poResource resource.Resource
resourceStore db.Db
@@ -38,29 +36,53 @@ type MenuStorageService struct {
userDataStore db.Db
}
func BuildConnStr() string {
host := initializers.GetEnv("DB_HOST", "localhost")
user := initializers.GetEnv("DB_USER", "postgres")
password := initializers.GetEnv("DB_PASSWORD", "")
dbName := initializers.GetEnv("DB_NAME", "")
port := initializers.GetEnv("DB_PORT", "5432")
connString := fmt.Sprintf(
"postgres://%s:%s@%s:%s/%s",
user, password, host, port, dbName,
)
logg.Debugf("pg conn string", "conn", connString)
return connString
}
func NewMenuStorageService(dbDir string, resourceDir string) *MenuStorageService {
func NewMenuStorageService(conn ConnData, resourceDir string) *MenuStorageService {
return &MenuStorageService{
dbDir: dbDir,
conn: conn,
resourceDir: resourceDir,
}
}
func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.Db, section string) (db.Db, error) {
var newDb db.Db
var err error
if existingDb != nil {
return existingDb, nil
}
connStr := ms.conn.String()
dbTyp := ms.conn.DbType()
if dbTyp == DBTYPE_POSTGRES {
// // Ensure the schema exists
// err = ensureSchemaExists(ctx, connStr, schema)
// if err != nil {
// return nil, fmt.Errorf("failed to ensure schema exists: %w", err)
// }
//
// newDb = postgres.NewPgDb().WithSchema(schema)
newDb = postgres.NewPgDb()
} else if dbTyp == DBTYPE_GDBM {
err = ms.ensureDbDir()
if err != nil {
return nil, err
}
connStr = path.Join(connStr, section)
newDb = gdbmstorage.NewThreadGdbmDb()
} else {
return nil, fmt.Errorf("unsupported connection string: '%s'\n", ms.conn.String())
}
logg.DebugCtxf(ctx, "connecting to db", "conn", connStr)
err = newDb.Connect(ctx, connStr)
if err != nil {
return nil, err
}
return newDb, nil
}
// WithGettext triggers use of gettext for translation of templates and menus.
//
// The first language in `lns` will be used as default language, to resolve node keys to
@@ -83,48 +105,6 @@ func (ms *MenuStorageService) WithGettext(path string, lns []lang.Language) *Men
return ms
}
func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.Db, fileName string) (db.Db, error) {
database, ok := ctx.Value("Database").(string)
if !ok {
return nil, fmt.Errorf("failed to select the database")
}
schema, ok := ctx.Value("Schema").(string)
if !ok {
return nil, fmt.Errorf("failed to select the schema")
}
if existingDb != nil {
return existingDb, nil
}
var newDb db.Db
var err error
if database == "postgres" {
connStr := BuildConnStr()
// Ensure the schema exists
err = ensureSchemaExists(ctx, connStr, schema)
if err != nil {
return nil, fmt.Errorf("failed to ensure schema exists: %w", err)
}
newDb = postgres.NewPgDb().WithSchema(schema)
err = newDb.Connect(ctx, connStr)
} else {
newDb = gdbmstorage.NewThreadGdbmDb()
storeFile := path.Join(ms.dbDir, fileName)
err = newDb.Connect(ctx, storeFile)
}
if err != nil {
return nil, err
}
return newDb, nil
}
// ensureSchemaExists creates a new schema if it does not exist
func ensureSchemaExists(ctx context.Context, connStr, schema string) error {
conn, err := pgxpool.New(ctx, connStr)
@@ -196,8 +176,8 @@ func (ms *MenuStorageService) GetStateStore(ctx context.Context) (db.Db, error)
return ms.stateStore, nil
}
func (ms *MenuStorageService) EnsureDbDir() error {
err := os.MkdirAll(ms.dbDir, 0700)
func (ms *MenuStorageService) ensureDbDir() error {
err := os.MkdirAll(ms.conn.String(), 0700)
if err != nil {
return fmt.Errorf("state dir create exited with error: %v\n", err)
}

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"os"
"path"
"path/filepath"
"time"
"git.defalsify.org/vise.git/engine"
@@ -16,11 +17,9 @@ import (
"git.grassecon.net/urdt/ussd/internal/testutil/testservice"
"git.grassecon.net/urdt/ussd/internal/testutil/testtag"
"git.grassecon.net/urdt/ussd/remote"
testdataloader "github.com/peteole/testdata-loader"
)
var (
baseDir = testdataloader.GetBasePath()
logg = logging.NewVanilla()
scriptDir = path.Join(baseDir, "services", "registration")
selectedDatabase = ""
@@ -28,7 +27,7 @@ var (
)
func init() {
initializers.LoadEnvVariables(baseDir)
initializers.LoadEnvVariables()
}
// SetDatabase updates the database used by TestEngine
@@ -40,9 +39,6 @@ func SetDatabase(dbType string, dbSchema string) {
func TestEngine(sessionId string) (engine.Engine, func(), chan bool) {
ctx := context.Background()
ctx = context.WithValue(ctx, "SessionId", sessionId)
ctx = context.WithValue(ctx, "Database", selectedDatabase)
ctx = context.WithValue(ctx, "Schema", selectedDbSchema)
pfp := path.Join(scriptDir, "pp.csv")
var eventChannel = make(chan bool)
@@ -54,37 +50,40 @@ func TestEngine(sessionId string) (engine.Engine, func(), chan bool) {
FlagCount: uint32(128),
}
dbDir := ".test_state"
resourceDir := scriptDir
menuStorageService := storage.NewMenuStorageService(dbDir, resourceDir)
err := menuStorageService.EnsureDbDir()
connStr, err := filepath.Abs(".test_state/state.gdbm")
if err != nil {
fmt.Fprintf(os.Stderr, err.Error())
fmt.Fprintf(os.Stderr, "connstr err: %v", err)
os.Exit(1)
}
conn, err := storage.ToConnData(connStr)
if err != nil {
fmt.Fprintf(os.Stderr, "connstr parse err: %v", err)
os.Exit(1)
}
resourceDir := scriptDir
menuStorageService := storage.NewMenuStorageService(conn, resourceDir)
rs, err := menuStorageService.GetResource(ctx)
if err != nil {
fmt.Fprintf(os.Stderr, err.Error())
fmt.Fprintf(os.Stderr, "resource error: %v", err)
os.Exit(1)
}
pe, err := menuStorageService.GetPersister(ctx)
if err != nil {
fmt.Fprintf(os.Stderr, err.Error())
fmt.Fprintf(os.Stderr, "persister error: %v", err)
os.Exit(1)
}
userDataStore, err := menuStorageService.GetUserdataDb(ctx)
if err != nil {
fmt.Fprintf(os.Stderr, err.Error())
fmt.Fprintf(os.Stderr, "userdb error: %v", err)
os.Exit(1)
}
dbResource, ok := rs.(*resource.DbResource)
if !ok {
fmt.Fprintf(os.Stderr, err.Error())
fmt.Fprintf(os.Stderr, "dbresource cast error")
os.Exit(1)
}

View File

@@ -0,0 +1,15 @@
package testutil
import (
"testing"
)
func TestCreateEngine(t *testing.T) {
o, clean, eventC := TestEngine("foo")
defer clean()
defer func() {
<-eventC
close(eventC)
}()
_ = o
}