diff --git a/.dockerignore b/.dockerignore index a118f64..2c2b83b 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,5 +1,6 @@ /** !/cmd/africastalking +!/cmd/ssh !/common !/config !/initializers diff --git a/.env.example b/.env.example index c636fa8..6d0368f 100644 --- a/.env.example +++ b/.env.example @@ -6,15 +6,15 @@ HOST=127.0.0.1 AT_ENDPOINT=/ussd/africastalking #PostgreSQL -DB_HOST=localhost -DB_USER=postgres -DB_PASSWORD=strongpass -DB_NAME=urdt_ussd -DB_PORT=5432 -DB_SSLMODE=disable -DB_TIMEZONE=Africa/Nairobi +DB_CONN=postgres://postgres:strongpass@localhost:5432/urdt_ussd +#DB_TIMEZONE=Africa/Nairobi +#DB_SCHEMA=vise #External API Calls CUSTODIAL_URL_BASE=http://localhost:5003 BEARER_TOKEN=eyJeSIsInRcCI6IkpXVCJ.yJwdWJsaWNLZXkiOiIwrrrrrr DATA_URL_BASE=http://localhost:5006 + +#Language +DEFAULT_LANGUAGE=eng +LANGUAGES=eng, swa diff --git a/Dockerfile b/Dockerfile index 3a5da7d..d68733c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -19,6 +19,7 @@ WORKDIR /build RUN echo "Building on $BUILDPLATFORM, building for $TARGETPLATFORM" RUN go mod download RUN go build -tags logtrace -o ussd-africastalking -ldflags="-X main.build=${BUILD} -s -w" cmd/africastalking/main.go +RUN go build -tags logtrace -o ussd-ssh -ldflags="-X main.build=${BUILD} -s -w" cmd/ssh/main.go FROM debian:bookworm-slim @@ -30,6 +31,7 @@ RUN apt-get clean && rm -rf /var/lib/apt/lists/* WORKDIR /service COPY --from=build /build/ussd-africastalking . +COPY --from=build /build/ussd-ssh . COPY --from=build /build/LICENSE . COPY --from=build /build/README.md . COPY --from=build /build/services ./services @@ -37,5 +39,6 @@ COPY --from=build /build/.env.example . RUN mv .env.example .env EXPOSE 7123 +EXPOSE 7122 CMD ["./ussd-africastalking"] \ No newline at end of file diff --git a/cmd/africastalking/main.go b/cmd/africastalking/main.go index 0019239..24812a1 100644 --- a/cmd/africastalking/main.go +++ b/cmd/africastalking/main.go @@ -12,14 +12,15 @@ import ( "syscall" "git.defalsify.org/vise.git/engine" + "git.defalsify.org/vise.git/lang" "git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/resource" "git.grassecon.net/urdt/ussd/config" "git.grassecon.net/urdt/ussd/initializers" + "git.grassecon.net/urdt/ussd/internal/args" "git.grassecon.net/urdt/ussd/internal/handlers" "git.grassecon.net/urdt/ussd/internal/http/at" - httpserver "git.grassecon.net/urdt/ussd/internal/http/at" "git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/remote" ) @@ -34,29 +35,49 @@ var ( func init() { initializers.LoadEnvVariables() } + func main() { config.LoadConfig() - var dbDir string + var connStr string var resourceDir string var size uint - var database string var engineDebug bool var host string var port uint - flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from") + var err error + var gettextDir string + var langs args.LangVar + flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir") - flag.StringVar(&database, "db", "gdbm", "database to be used") + flag.StringVar(&connStr, "c", "", "connection string") flag.BoolVar(&engineDebug, "d", false, "use engine debug output") flag.UintVar(&size, "s", 160, "max size of output") flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host") flag.UintVar(&port, "p", initializers.GetEnvUint("PORT", 7123), "http port") + flag.StringVar(&gettextDir, "gettext", "", "use gettext translations from given directory") + flag.Var(&langs, "language", "add symbol resolution for language") flag.Parse() - logg.Infof("start command", "build", build, "dbdir", dbDir, "resourcedir", resourceDir, "outputsize", size) + if connStr == "" { + connStr = config.DbConn + } + connData, err := storage.ToConnData(connStr) + if err != nil { + fmt.Fprintf(os.Stderr, "connstr err: %v", err) + os.Exit(1) + } + + logg.Infof("start command", "build", build, "conn", connData, "resourcedir", resourceDir, "outputsize", size) ctx := context.Background() - ctx = context.WithValue(ctx, "Database", database) + ln, err := lang.LanguageFromCode(config.DefaultLanguage) + if err != nil { + fmt.Fprintf(os.Stderr, "default language set error: %v", err) + os.Exit(1) + } + ctx = context.WithValue(ctx, "Language", ln) + pfp := path.Join(scriptDir, "pp.csv") cfg := engine.Config{ @@ -70,14 +91,13 @@ func main() { cfg.EngineDebug = true } - menuStorageService := storage.NewMenuStorageService(dbDir, resourceDir) - rs, err := menuStorageService.GetResource(ctx) + menuStorageService := storage.NewMenuStorageService(connData, resourceDir) if err != nil { fmt.Fprintf(os.Stderr, err.Error()) os.Exit(1) } - err = menuStorageService.EnsureDbDir() + rs, err := menuStorageService.GetResource(ctx) if err != nil { fmt.Fprintf(os.Stderr, err.Error()) os.Exit(1) @@ -121,11 +141,9 @@ func main() { } defer stateStore.Close() - rp := &at.ATRequestParser{ - Context: ctx, - } + rp := &at.ATRequestParser{} bsh := handlers.NewBaseSessionHandler(cfg, rs, stateStore, userdataStore, rp, hl) - sh := httpserver.NewATSessionHandler(bsh) + sh := at.NewATSessionHandler(bsh) mux := http.NewServeMux() mux.Handle(initializers.GetEnv("AT_ENDPOINT", "/"), sh) diff --git a/cmd/async/main.go b/cmd/async/main.go index 9cd04b3..27db453 100644 --- a/cmd/async/main.go +++ b/cmd/async/main.go @@ -10,19 +10,21 @@ import ( "syscall" "git.defalsify.org/vise.git/engine" + "git.defalsify.org/vise.git/lang" "git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/resource" "git.grassecon.net/urdt/ussd/config" "git.grassecon.net/urdt/ussd/initializers" + "git.grassecon.net/urdt/ussd/internal/args" "git.grassecon.net/urdt/ussd/internal/handlers" "git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/remote" ) var ( - logg = logging.NewVanilla() - scriptDir = path.Join("services", "registration") + logg = logging.NewVanilla() + scriptDir = path.Join("services", "registration") menuSeparator = ": " ) @@ -35,7 +37,7 @@ type asyncRequestParser struct { input []byte } -func (p *asyncRequestParser) GetSessionId(r any) (string, error) { +func (p *asyncRequestParser) GetSessionId(ctx context.Context, r any) (string, error) { return p.sessionId, nil } @@ -46,28 +48,48 @@ func (p *asyncRequestParser) GetInput(r any) ([]byte, error) { func main() { config.LoadConfig() + var connStr string var sessionId string - var dbDir string var resourceDir string var size uint - var database string var engineDebug bool var host string var port uint + var err error + var gettextDir string + var langs args.LangVar + flag.StringVar(&sessionId, "session-id", "075xx2123", "session id") - flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from") flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir") - flag.StringVar(&database, "db", "gdbm", "database to be used") + flag.StringVar(&connStr, "c", "", "connection string") flag.BoolVar(&engineDebug, "d", false, "use engine debug output") flag.UintVar(&size, "s", 160, "max size of output") flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host") flag.UintVar(&port, "p", initializers.GetEnvUint("PORT", 7123), "http port") + flag.StringVar(&gettextDir, "gettext", "", "use gettext translations from given directory") + flag.Var(&langs, "language", "add symbol resolution for language") flag.Parse() - logg.Infof("start command", "dbdir", dbDir, "resourcedir", resourceDir, "outputsize", size, "sessionId", sessionId) + if connStr == "" { + connStr = config.DbConn + } + connData, err := storage.ToConnData(connStr) + if err != nil { + fmt.Fprintf(os.Stderr, "connstr err: %v", err) + os.Exit(1) + } + + logg.Infof("start command", "conn", connData, "resourcedir", resourceDir, "outputsize", size, "sessionId", sessionId) ctx := context.Background() - ctx = context.WithValue(ctx, "Database", database) + + ln, err := lang.LanguageFromCode(config.DefaultLanguage) + if err != nil { + fmt.Fprintf(os.Stderr, "default language set error: %v", err) + os.Exit(1) + } + ctx = context.WithValue(ctx, "Language", ln) + pfp := path.Join(scriptDir, "pp.csv") cfg := engine.Config{ @@ -81,14 +103,13 @@ func main() { cfg.EngineDebug = true } - menuStorageService := storage.NewMenuStorageService(dbDir, resourceDir) - rs, err := menuStorageService.GetResource(ctx) + menuStorageService := storage.NewMenuStorageService(connData, resourceDir) if err != nil { fmt.Fprintf(os.Stderr, err.Error()) os.Exit(1) } - err = menuStorageService.EnsureDbDir() + rs, err := menuStorageService.GetResource(ctx) if err != nil { fmt.Fprintf(os.Stderr, err.Error()) os.Exit(1) diff --git a/cmd/http/main.go b/cmd/http/main.go index 6ddfded..6617ca5 100644 --- a/cmd/http/main.go +++ b/cmd/http/main.go @@ -12,11 +12,13 @@ import ( "syscall" "git.defalsify.org/vise.git/engine" + "git.defalsify.org/vise.git/lang" "git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/resource" "git.grassecon.net/urdt/ussd/config" "git.grassecon.net/urdt/ussd/initializers" + "git.grassecon.net/urdt/ussd/internal/args" "git.grassecon.net/urdt/ussd/internal/handlers" httpserver "git.grassecon.net/urdt/ussd/internal/http" "git.grassecon.net/urdt/ussd/internal/storage" @@ -24,8 +26,8 @@ import ( ) var ( - logg = logging.NewVanilla() - scriptDir = path.Join("services", "registration") + logg = logging.NewVanilla() + scriptDir = path.Join("services", "registration") menuSeparator = ": " ) @@ -36,26 +38,46 @@ func init() { func main() { config.LoadConfig() - var dbDir string + var connStr string var resourceDir string var size uint - var database string var engineDebug bool var host string var port uint - flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from") + var err error + var gettextDir string + var langs args.LangVar + flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir") - flag.StringVar(&database, "db", "gdbm", "database to be used") + flag.StringVar(&connStr, "c", "", "connection string") flag.BoolVar(&engineDebug, "d", false, "use engine debug output") flag.UintVar(&size, "s", 160, "max size of output") flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host") flag.UintVar(&port, "p", initializers.GetEnvUint("PORT", 7123), "http port") + flag.StringVar(&gettextDir, "gettext", "", "use gettext translations from given directory") + flag.Var(&langs, "language", "add symbol resolution for language") flag.Parse() - logg.Infof("start command", "dbdir", dbDir, "resourcedir", resourceDir, "outputsize", size) + if connStr == "" { + connStr = config.DbConn + } + connData, err := storage.ToConnData(connStr) + if err != nil { + fmt.Fprintf(os.Stderr, "connstr err: %v", err) + os.Exit(1) + } + + logg.Infof("start command", "conn", connData, "resourcedir", resourceDir, "outputsize", size) ctx := context.Background() - ctx = context.WithValue(ctx, "Database", database) + + ln, err := lang.LanguageFromCode(config.DefaultLanguage) + if err != nil { + fmt.Fprintf(os.Stderr, "default language set error: %v", err) + os.Exit(1) + } + ctx = context.WithValue(ctx, "Language", ln) + pfp := path.Join(scriptDir, "pp.csv") cfg := engine.Config{ @@ -69,14 +91,9 @@ func main() { cfg.EngineDebug = true } - menuStorageService := storage.NewMenuStorageService(dbDir, resourceDir) - rs, err := menuStorageService.GetResource(ctx) - if err != nil { - fmt.Fprintf(os.Stderr, err.Error()) - os.Exit(1) - } + menuStorageService := storage.NewMenuStorageService(connData, resourceDir) - err = menuStorageService.EnsureDbDir() + rs, err := menuStorageService.GetResource(ctx) if err != nil { fmt.Fprintf(os.Stderr, err.Error()) os.Exit(1) diff --git a/cmd/main.go b/cmd/main.go index 4fd084f..d2fe0ba 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -8,10 +8,12 @@ import ( "path" "git.defalsify.org/vise.git/engine" + "git.defalsify.org/vise.git/lang" "git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/resource" "git.grassecon.net/urdt/ussd/config" "git.grassecon.net/urdt/ussd/initializers" + "git.grassecon.net/urdt/ussd/internal/args" "git.grassecon.net/urdt/ussd/internal/handlers" "git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/remote" @@ -27,26 +29,53 @@ func init() { initializers.LoadEnvVariables() } +// TODO: external script automatically generate language handler list from select language vise code OR consider dynamic menu generation script possibility func main() { config.LoadConfig() - var dbDir string + var connStr string var size uint var sessionId string - var database string var engineDebug bool + var resourceDir string + var err error + var gettextDir string + var langs args.LangVar + + flag.StringVar(&resourceDir, "resourcedir", scriptDir, "resource dir") flag.StringVar(&sessionId, "session-id", "075xx2123", "session id") - flag.StringVar(&database, "db", "gdbm", "database to be used") - flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from") + flag.StringVar(&connStr, "c", "", "connection string") flag.BoolVar(&engineDebug, "d", false, "use engine debug output") flag.UintVar(&size, "s", 160, "max size of output") + flag.StringVar(&gettextDir, "gettext", "", "use gettext translations from given directory") + flag.Var(&langs, "language", "add symbol resolution for language") flag.Parse() - logg.Infof("start command", "dbdir", dbDir, "outputsize", size) + if connStr == "" { + connStr = config.DbConn + } + connData, err := storage.ToConnData(connStr) + if err != nil { + fmt.Fprintf(os.Stderr, "connstr err: %v", err) + os.Exit(1) + } + + logg.Infof("start command", "conn", connData, "outputsize", size) + + if len(langs.Langs()) == 0 { + langs.Set(config.DefaultLanguage) + } ctx := context.Background() ctx = context.WithValue(ctx, "SessionId", sessionId) - ctx = context.WithValue(ctx, "Database", database) + + ln, err := lang.LanguageFromCode(config.DefaultLanguage) + if err != nil { + fmt.Fprintf(os.Stderr, "default language set error: %v", err) + os.Exit(1) + } + ctx = context.WithValue(ctx, "Language", ln) + pfp := path.Join(scriptDir, "pp.csv") cfg := engine.Config{ @@ -57,13 +86,10 @@ func main() { MenuSeparator: menuSeparator, } - resourceDir := scriptDir - menuStorageService := storage.NewMenuStorageService(dbDir, resourceDir) + menuStorageService := storage.NewMenuStorageService(connData, resourceDir) - err := menuStorageService.EnsureDbDir() - if err != nil { - fmt.Fprintf(os.Stderr, err.Error()) - os.Exit(1) + if gettextDir != "" { + menuStorageService = menuStorageService.WithGettext(gettextDir, langs.Langs()) } rs, err := menuStorageService.GetResource(ctx) diff --git a/cmd/ssh/README.md b/cmd/ssh/README.md new file mode 100644 index 0000000..ff325d7 --- /dev/null +++ b/cmd/ssh/README.md @@ -0,0 +1,34 @@ +# URDT-USSD SSH server + +An SSH server entry point for the vise engine. + + +## Adding public keys for access + +Map your (client) public key to a session identifier (e.g. phone number) + +``` +go run -v -tags logtrace ./cmd/ssh/sshkey/main.go -i [--dbdir ] +``` + + +## Create a private key for the server + +``` +ssh-keygen -N "" -f +``` + + +## Run the server + + +``` +go run -v -tags logtrace ./cmd/ssh/main.go -h -p [--dbdir ] +``` + + +## Connect to the server + +``` +ssh [-v] -T -p -i +``` diff --git a/cmd/ssh/main.go b/cmd/ssh/main.go new file mode 100644 index 0000000..51023e5 --- /dev/null +++ b/cmd/ssh/main.go @@ -0,0 +1,144 @@ +package main + +import ( + "context" + "flag" + "fmt" + "os" + "os/signal" + "path" + "sync" + "syscall" + + "git.defalsify.org/vise.git/db" + "git.defalsify.org/vise.git/engine" + "git.defalsify.org/vise.git/logging" + + "git.grassecon.net/urdt/ussd/config" + "git.grassecon.net/urdt/ussd/initializers" + "git.grassecon.net/urdt/ussd/internal/ssh" + "git.grassecon.net/urdt/ussd/internal/storage" +) + +var ( + wg sync.WaitGroup + keyStore db.Db + logg = logging.NewVanilla() + scriptDir = path.Join("services", "registration") + + build = "dev" +) + +func init() { + initializers.LoadEnvVariables() +} + +func main() { + config.LoadConfig() + + var connStr string + var authConnStr string + var resourceDir string + var size uint + var engineDebug bool + var stateDebug bool + var host string + var port uint + flag.StringVar(&connStr, "c", "", "connection string") + flag.StringVar(&authConnStr, "authdb", "", "auth connection string") + flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir") + flag.BoolVar(&engineDebug, "d", false, "use engine debug output") + flag.UintVar(&size, "s", 160, "max size of output") + flag.StringVar(&host, "h", "127.0.0.1", "socket host") + flag.UintVar(&port, "p", 7122, "socket port") + flag.Parse() + + if connStr == "" { + connStr = config.DbConn + } + if authConnStr == "" { + authConnStr = connStr + } + connData, err := storage.ToConnData(connStr) + if err != nil { + fmt.Fprintf(os.Stderr, "connstr err: %v", err) + os.Exit(1) + } + authConnData, err := storage.ToConnData(authConnStr) + if err != nil { + fmt.Fprintf(os.Stderr, "auth connstr err: %v", err) + os.Exit(1) + } + + sshKeyFile := flag.Arg(0) + _, err = os.Stat(sshKeyFile) + if err != nil { + fmt.Fprintf(os.Stderr, "cannot open ssh server private key file: %v\n", err) + os.Exit(1) + } + + ctx := context.Background() + logg.WarnCtxf(ctx, "!!!!! WARNING WARNING WARNING") + logg.WarnCtxf(ctx, "!!!!! =======================") + logg.WarnCtxf(ctx, "!!!!! This is not a production ready server!") + logg.WarnCtxf(ctx, "!!!!! Do not expose to internet and only use with tunnel!") + logg.WarnCtxf(ctx, "!!!!! (See ssh -L <...>)") + + logg.Infof("start command", "conn", connData, "authconn", authConnData, "resourcedir", resourceDir, "outputsize", size, "keyfile", sshKeyFile, "host", host, "port", port) + + pfp := path.Join(scriptDir, "pp.csv") + + cfg := engine.Config{ + Root: "root", + OutputSize: uint32(size), + FlagCount: uint32(16), + } + if stateDebug { + cfg.StateDebug = true + } + if engineDebug { + cfg.EngineDebug = true + } + + authKeyStore, err := ssh.NewSshKeyStore(ctx, authConnData.String()) + if err != nil { + fmt.Fprintf(os.Stderr, "keystore file open error: %v", err) + os.Exit(1) + } + defer func() { + logg.TraceCtxf(ctx, "shutdown auth key store reached") + err = authKeyStore.Close() + if err != nil { + logg.ErrorCtxf(ctx, "keystore close error", "err", err) + } + }() + + cint := make(chan os.Signal) + cterm := make(chan os.Signal) + signal.Notify(cint, os.Interrupt, syscall.SIGINT) + signal.Notify(cterm, os.Interrupt, syscall.SIGTERM) + + runner := &ssh.SshRunner{ + Cfg: cfg, + Debug: engineDebug, + FlagFile: pfp, + Conn: connData, + ResourceDir: resourceDir, + SrvKeyFile: sshKeyFile, + Host: host, + Port: port, + } + go func() { + select { + case _ = <-cint: + case _ = <-cterm: + } + logg.TraceCtxf(ctx, "shutdown runner reached") + err := runner.Stop() + if err != nil { + logg.ErrorCtxf(ctx, "runner stop error", "err", err) + } + + }() + runner.Run(ctx, authKeyStore) +} diff --git a/cmd/ssh/sshkey/main.go b/cmd/ssh/sshkey/main.go new file mode 100644 index 0000000..87b89a3 --- /dev/null +++ b/cmd/ssh/sshkey/main.go @@ -0,0 +1,44 @@ +package main + +import ( + "context" + "flag" + "fmt" + "os" + + "git.grassecon.net/urdt/ussd/internal/ssh" +) + +func main() { + var dbDir string + var sessionId string + flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from") + flag.StringVar(&sessionId, "i", "", "session id") + flag.Parse() + + if sessionId == "" { + fmt.Fprintf(os.Stderr, "empty session id\n") + os.Exit(1) + } + + ctx := context.Background() + + sshKeyFile := flag.Arg(0) + if sshKeyFile == "" { + fmt.Fprintf(os.Stderr, "missing key file argument\n") + os.Exit(1) + } + + store, err := ssh.NewSshKeyStore(ctx, dbDir) + if err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + os.Exit(1) + } + defer store.Close() + + err = store.AddFromFile(ctx, sshKeyFile, sessionId) + if err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + os.Exit(1) + } +} diff --git a/common/db.go b/common/db.go index a5cf1c1..2271716 100644 --- a/common/db.go +++ b/common/db.go @@ -7,7 +7,7 @@ import ( "git.defalsify.org/vise.git/logging" ) -// DataType is a subprefix value used in association with vise/db.DATATYPE_USERDATA. +// DataType is a subprefix value used in association with vise/db.DATATYPE_USERDATA. // // All keys are used only within the context of a single account. Unless otherwise specified, the user context is the session id. // @@ -55,6 +55,10 @@ const ( DATA_ACTIVE_DECIMAL // EVM address of the currently active voucher DATA_ACTIVE_ADDRESS + //Holds count of the number of incorrect PIN attempts + DATA_INCORRECT_PIN_ATTEMPTS + //ISO 639 code for the selected language. + DATA_SELECTED_LANGUAGE_CODE ) const ( diff --git a/common/pin.go b/common/pin.go index 6db9d15..13f21b3 100644 --- a/common/pin.go +++ b/common/pin.go @@ -6,9 +6,13 @@ import ( "golang.org/x/crypto/bcrypt" ) -// Define the regex pattern as a constant const ( + // Define the regex pattern as a constant pinPattern = `^\d{4}$` + + //Allowed incorrect PIN attempts + AllowedPINAttempts = uint8(3) + ) // checks whether the given input is a 4 digit number diff --git a/common/storage.go b/common/storage.go index d37bce3..2960578 100644 --- a/common/storage.go +++ b/common/storage.go @@ -23,17 +23,17 @@ type StorageServices interface { GetPersister(ctx context.Context) (*persist.Persister, error) GetUserdataDb(ctx context.Context) (db.Db, error) GetResource(ctx context.Context) (resource.Resource, error) - EnsureDbDir() error } type StorageService struct { svc *storage.MenuStorageService } -func NewStorageService(dbDir string) *StorageService { - return &StorageService{ - svc: storage.NewMenuStorageService(dbDir, ""), +func NewStorageService(conn storage.ConnData) (*StorageService, error) { + svc := &StorageService{ + svc: storage.NewMenuStorageService(conn, ""), } + return svc, nil } func(ss *StorageService) GetPersister(ctx context.Context) (*persist.Persister, error) { @@ -47,7 +47,3 @@ func(ss *StorageService) GetUserdataDb(ctx context.Context) (db.Db, error) { func(ss *StorageService) GetResource(ctx context.Context) (resource.Resource, error) { return nil, errors.New("not implemented") } - -func(ss *StorageService) EnsureDbDir() error { - return ss.svc.EnsureDbDir() -} diff --git a/config/config.go b/config/config.go index 3a8e8ed..4b43b42 100644 --- a/config/config.go +++ b/config/config.go @@ -2,6 +2,7 @@ package config import ( "net/url" + "strings" "git.grassecon.net/urdt/ussd/initializers" ) @@ -18,6 +19,11 @@ const ( AliasPrefix = "api/v1/alias" ) +var ( + defaultLanguage = "eng" + languages []string +) + var ( custodialURLBase string dataURLBase string @@ -34,8 +40,29 @@ var ( VoucherTransfersURL string VoucherDataURL string CheckAliasURL string + DbConn string + DefaultLanguage string + Languages []string ) +func setLanguage() error { + defaultLanguage = initializers.GetEnv("DEFAULT_LANGUAGE", defaultLanguage) + languages = strings.Split(initializers.GetEnv("LANGUAGES", defaultLanguage), ",") + haveDefaultLanguage := false + for i, v := range(languages) { + languages[i] = strings.ReplaceAll(v, " ", "") + if languages[i] == defaultLanguage { + haveDefaultLanguage = true + } + } + + if !haveDefaultLanguage { + languages = append([]string{defaultLanguage}, languages...) + } + + return nil +} + func setBase() error { var err error @@ -43,14 +70,20 @@ func setBase() error { dataURLBase = initializers.GetEnv("DATA_URL_BASE", "http://localhost:5006") BearerToken = initializers.GetEnv("BEARER_TOKEN", "") - _, err = url.JoinPath(custodialURLBase, "/foo") + _, err = url.Parse(custodialURLBase) if err != nil { return err } - _, err = url.JoinPath(dataURLBase, "/bar") + _, err = url.Parse(dataURLBase) if err != nil { return err } + + return nil +} + +func setConn() error { + DbConn = initializers.GetEnv("DB_CONN", "") return nil } @@ -60,6 +93,14 @@ func LoadConfig() error { if err != nil { return err } + err = setConn() + if err != nil { + return err + } + err = setLanguage() + if err != nil { + return err + } CreateAccountURL, _ = url.JoinPath(custodialURLBase, createAccountPath) TrackStatusURL, _ = url.JoinPath(custodialURLBase, trackStatusPath) BalanceURL, _ = url.JoinPath(custodialURLBase, balancePathPrefix) @@ -69,6 +110,8 @@ func LoadConfig() error { VoucherTransfersURL, _ = url.JoinPath(dataURLBase, voucherTransfersPathPrefix) VoucherDataURL, _ = url.JoinPath(dataURLBase, voucherDataPathPrefix) CheckAliasURL, _ = url.JoinPath(dataURLBase, AliasPrefix) + DefaultLanguage = defaultLanguage + Languages = languages return nil } diff --git a/devtools/lang/main.go b/devtools/lang/main.go new file mode 100644 index 0000000..83c68b3 --- /dev/null +++ b/devtools/lang/main.go @@ -0,0 +1,126 @@ +// create language files from environment +package main + +import ( + "flag" + "fmt" + "os" + "path" + "strings" + + "git.defalsify.org/vise.git/logging" + "git.defalsify.org/vise.git/lang" + "git.grassecon.net/urdt/ussd/config" + "git.grassecon.net/urdt/ussd/initializers" +) + +const ( + + changeHeadSrc = `LOAD reset_account_authorized 0 +LOAD reset_incorrect 0 +CATCH incorrect_pin flag_incorrect_pin 1 +CATCH pin_entry flag_account_authorized 0 +` + + selectSrc = `LOAD set_language 6 +RELOAD set_language +CATCH terms flag_account_created 0 +MOVE language_changed +` +) + +var ( + logg = logging.NewVanilla() + mouts string + incmps string +) + +func init() { + initializers.LoadEnvVariables() +} + +func toLanguageLabel(ln lang.Language) string { + s := ln.Name + v := strings.Split(s, " (") + if len(v) > 1 { + s = v[0] + } + return s +} + +func toLanguageKey(ln lang.Language) string { + s := toLanguageLabel(ln) + return strings.ToLower(s) +} + +func main() { + var srcDir string + + flag.StringVar(&srcDir, "o", ".", "resource dir write to") + flag.Parse() + + logg.Infof("start command", "dir", srcDir) + + err := config.LoadConfig() + if err != nil { + fmt.Fprintf(os.Stderr, "config load error: %v", err) + os.Exit(1) + } + logg.Tracef("using languages", "lang", config.Languages) + + for i, v := range(config.Languages) { + ln, err := lang.LanguageFromCode(v) + if err != nil { + fmt.Fprintf(os.Stderr, "error parsing language: %s\n", v) + os.Exit(1) + } + n := i + 1 + s := toLanguageKey(ln) + mouts += fmt.Sprintf("MOUT %s %v\n", s, n) + v = "set_" + ln.Code + incmps += fmt.Sprintf("INCMP %s %v\n", v, n) + + p := path.Join(srcDir, v) + w, err := os.OpenFile(p, os.O_WRONLY | os.O_CREATE | os.O_EXCL, 0600) + if err != nil { + fmt.Fprintf(os.Stderr, "failed open language set template output: %v\n", err) + os.Exit(1) + } + s = toLanguageLabel(ln) + defer w.Close() + _, err = w.Write([]byte(s)) + if err != nil { + fmt.Fprintf(os.Stderr, "failed write select language vis output: %v\n", err) + os.Exit(1) + } + } + src := mouts + "HALT\n" + incmps + src += "INCMP . *\n" + + p := path.Join(srcDir, "select_language.vis") + w, err := os.OpenFile(p, os.O_WRONLY | os.O_CREATE | os.O_EXCL, 0600) + if err != nil { + fmt.Fprintf(os.Stderr, "failed open select language vis output: %v\n", err) + os.Exit(1) + } + defer w.Close() + _, err = w.Write([]byte(src)) + if err != nil { + fmt.Fprintf(os.Stderr, "failed write select language vis output: %v\n", err) + os.Exit(1) + } + + src = changeHeadSrc + src + p = path.Join(srcDir, "change_language.vis") + w, err = os.OpenFile(p, os.O_WRONLY | os.O_CREATE | os.O_EXCL, 0600) + if err != nil { + fmt.Fprintf(os.Stderr, "failed open select language vis output: %v\n", err) + os.Exit(1) + } + defer w.Close() + _, err = w.Write([]byte(src)) + if err != nil { + fmt.Fprintf(os.Stderr, "failed write select language vis output: %v\n", err) + os.Exit(1) + } +} diff --git a/devtools/store/main.go b/devtools/store/dump/main.go similarity index 66% rename from devtools/store/main.go rename to devtools/store/dump/main.go index 8bd4d16..c84a134 100644 --- a/devtools/store/main.go +++ b/devtools/store/dump/main.go @@ -25,26 +25,46 @@ func init() { } +func formatItem(k []byte, v []byte) (string, error) { + o, err := debug.FromKey(k) + if err != nil { + return "", err + } + s := fmt.Sprintf("%vValue: %v\n\n", o, string(v)) + return s, nil +} + func main() { config.LoadConfig() - var dbDir string + var connStr string var sessionId string var database string var engineDebug bool + var err error flag.StringVar(&sessionId, "session-id", "075xx2123", "session id") - flag.StringVar(&database, "db", "gdbm", "database to be used") - flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from") + flag.StringVar(&connStr, "c", ".state", "connection string") flag.BoolVar(&engineDebug, "d", false, "use engine debug output") flag.Parse() + if connStr != "" { + connStr = config.DbConn + } + connData, err := storage.ToConnData(config.DbConn) + if err != nil { + fmt.Fprintf(os.Stderr, "connstr err: %v", err) + os.Exit(1) + } + + logg.Infof("start command", "conn", connData) + ctx := context.Background() ctx = context.WithValue(ctx, "SessionId", sessionId) ctx = context.WithValue(ctx, "Database", database) resourceDir := scriptDir - menuStorageService := storage.NewMenuStorageService(dbDir, resourceDir) + menuStorageService := storage.NewMenuStorageService(connData, resourceDir) store, err := menuStorageService.GetUserdataDb(ctx) if err != nil { @@ -64,12 +84,12 @@ func main() { if k == nil { break } - o, err := debug.FromKey(k) + r, err := formatItem(k, v) if err != nil { - fmt.Fprintf(os.Stderr, err.Error()) + fmt.Fprintf(os.Stderr, "format db item error: %v", err) os.Exit(1) } - fmt.Printf("%vValue: %v\n\n", o, string(v)) + fmt.Printf(r) } err = store.Close() diff --git a/devtools/gen/main.go b/devtools/store/generate/main.go similarity index 76% rename from devtools/gen/main.go rename to devtools/store/generate/main.go index b9e2aed..c421d1a 100644 --- a/devtools/gen/main.go +++ b/devtools/store/generate/main.go @@ -9,14 +9,16 @@ import ( "path" "git.defalsify.org/vise.git/logging" - "git.grassecon.net/urdt/ussd/config" - "git.grassecon.net/urdt/ussd/internal/storage" - "git.grassecon.net/urdt/ussd/initializers" "git.grassecon.net/urdt/ussd/common" + "git.grassecon.net/urdt/ussd/config" + "git.grassecon.net/urdt/ussd/initializers" + "git.grassecon.net/urdt/ussd/internal/storage" + testdataloader "github.com/peteole/testdata-loader" ) var ( logg = logging.NewVanilla() + baseDir = testdataloader.GetBasePath() scriptDir = path.Join("services", "registration") ) @@ -24,28 +26,38 @@ func init() { initializers.LoadEnvVariables() } - func main() { config.LoadConfig() - var dbDir string + var connStr string var sessionId string var database string var engineDebug bool + var err error flag.StringVar(&sessionId, "session-id", "075xx2123", "session id") - flag.StringVar(&database, "db", "gdbm", "database to be used") - flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from") + flag.StringVar(&connStr, "c", "", "connection string") flag.BoolVar(&engineDebug, "d", false, "use engine debug output") flag.Parse() + if connStr != "" { + connStr = config.DbConn + } + connData, err := storage.ToConnData(config.DbConn) + if err != nil { + fmt.Fprintf(os.Stderr, "connstr err: %v", err) + os.Exit(1) + } + + logg.Infof("start command", "conn", connData) + ctx := context.Background() ctx = context.WithValue(ctx, "SessionId", sessionId) ctx = context.WithValue(ctx, "Database", database) resourceDir := scriptDir - menuStorageService := storage.NewMenuStorageService(dbDir, resourceDir) - + menuStorageService := storage.NewMenuStorageService(connData, resourceDir) + store, err := menuStorageService.GetUserdataDb(ctx) if err != nil { fmt.Fprintf(os.Stderr, err.Error()) @@ -75,5 +87,4 @@ func main() { fmt.Fprintf(os.Stderr, err.Error()) os.Exit(1) } - } diff --git a/initializers/load.go b/initializers/load.go index 4ea5980..fc61746 100644 --- a/initializers/load.go +++ b/initializers/load.go @@ -3,24 +3,30 @@ package initializers import ( "log" "os" + "path" "strconv" "github.com/joho/godotenv" ) func LoadEnvVariables() { - err := godotenv.Load() + LoadEnvVariablesPath(".") +} + +func LoadEnvVariablesPath(dir string) { + fp := path.Join(dir, ".env") + err := godotenv.Load(fp) if err != nil { - log.Fatal("Error loading .env file") + log.Fatal("Error loading .env file", err) } } // Helper to get environment variables with a default fallback func GetEnv(key, defaultVal string) string { - if value, exists := os.LookupEnv(key); exists { - return value + if value, exists := os.LookupEnv(key); exists { + return value } - return defaultVal + return defaultVal } // Helper to safely convert environment variables to uint diff --git a/internal/args/lang.go b/internal/args/lang.go new file mode 100644 index 0000000..f9afdc9 --- /dev/null +++ b/internal/args/lang.go @@ -0,0 +1,34 @@ +package args + +import ( + "strings" + + "git.defalsify.org/vise.git/lang" +) + +type LangVar struct { + v []lang.Language +} + +func(lv *LangVar) Set(s string) error { + v, err := lang.LanguageFromCode(s) + if err != nil { + return err + } + lv.v = append(lv.v, v) + return err +} + +func(lv *LangVar) String() string { + var s []string + for _, v := range(lv.v) { + s = append(s, v.Code) + } + return strings.Join(s, ",") +} + +func(lv *LangVar) Langs() []lang.Language { + return lv.v +} + + diff --git a/internal/handlers/ussd/menuhandler.go b/internal/handlers/application/menu_handler.go similarity index 93% rename from internal/handlers/ussd/menuhandler.go rename to internal/handlers/application/menu_handler.go index 095d77b..193f7fe 100644 --- a/internal/handlers/ussd/menuhandler.go +++ b/internal/handlers/application/menu_handler.go @@ -1,4 +1,4 @@ -package ussd +package application import ( "bytes" @@ -28,7 +28,7 @@ import ( ) var ( - logg = logging.NewVanilla().WithDomain("ussdmenuhandler").WithContextKey("session-id") + logg = logging.NewVanilla().WithDomain("ussdmenuhandler").WithContextKey("SessionId") scriptDir = path.Join("services", "registration") translationDir = path.Join(scriptDir, "locale") ) @@ -124,7 +124,7 @@ func (h *Handlers) Init(ctx context.Context, sym string, input []byte) (resource sessionId, ok := ctx.Value("SessionId").(string) if ok { - context.WithValue(ctx, "session-id", sessionId) + ctx = context.WithValue(ctx, "SessionId", sessionId) } flag_admin_privilege, _ := h.flagManager.GetFlag("flag_admin_privilege") @@ -161,9 +161,12 @@ func (h *Handlers) SetLanguage(ctx context.Context, sym string, input []byte) (r //Fallback to english instead? code = "eng" } - res.FlagSet = append(res.FlagSet, state.FLAG_LANG) + err := h.persistLanguageCode(ctx, code) + if err != nil { + return res, err + } res.Content = code - + res.FlagSet = append(res.FlagSet, state.FLAG_LANG) languageSetFlag, err := h.flagManager.GetFlag("flag_language_set") if err != nil { logg.ErrorCtxf(ctx, "Error setting the languageSetFlag", "error", err) @@ -734,11 +737,23 @@ func (h *Handlers) Authorize(ctx context.Context, sym string, input []byte) (res if h.st.MatchFlag(flag_account_authorized, false) { res.FlagReset = append(res.FlagReset, flag_incorrect_pin) res.FlagSet = append(res.FlagSet, flag_allow_update, flag_account_authorized) + err := h.resetIncorrectPINAttempts(ctx, sessionId) + if err != nil { + return res, err + } } else { res.FlagSet = append(res.FlagSet, flag_allow_update) res.FlagReset = append(res.FlagReset, flag_account_authorized) + err := h.resetIncorrectPINAttempts(ctx, sessionId) + if err != nil { + return res, err + } } } else { + err := h.incrementIncorrectPINAttempts(ctx, sessionId) + if err != nil { + return res, err + } res.FlagSet = append(res.FlagSet, flag_incorrect_pin) res.FlagReset = append(res.FlagReset, flag_account_authorized) return res, nil @@ -752,8 +767,34 @@ func (h *Handlers) Authorize(ctx context.Context, sym string, input []byte) (res // ResetIncorrectPin resets the incorrect pin flag after a new PIN attempt. func (h *Handlers) ResetIncorrectPin(ctx context.Context, sym string, input []byte) (resource.Result, error) { var res resource.Result + store := h.userdataStore + flag_incorrect_pin, _ := h.flagManager.GetFlag("flag_incorrect_pin") + flag_account_blocked, _ := h.flagManager.GetFlag("flag_account_blocked") + + sessionId, ok := ctx.Value("SessionId").(string) + if !ok { + return res, fmt.Errorf("missing session") + } + res.FlagReset = append(res.FlagReset, flag_incorrect_pin) + + currentWrongPinAttempts, err := store.ReadEntry(ctx, sessionId, common.DATA_INCORRECT_PIN_ATTEMPTS) + if err != nil { + if !db.IsNotFound(err) { + return res, err + } + } + pinAttemptsValue, _ := strconv.ParseUint(string(currentWrongPinAttempts), 0, 64) + remainingPINAttempts := common.AllowedPINAttempts - uint8(pinAttemptsValue) + if remainingPINAttempts == 0 { + res.FlagSet = append(res.FlagSet, flag_account_blocked) + return res, nil + } + if remainingPINAttempts < common.AllowedPINAttempts { + res.Content = strconv.Itoa(int(remainingPINAttempts)) + } + return res, nil } @@ -835,11 +876,21 @@ func (h *Handlers) QuitWithHelp(ctx context.Context, sym string, input []byte) ( l := gotext.NewLocale(translationDir, code) l.AddDomain("default") - res.Content = l.Get("For more help,please call: 0757628885") + res.Content = l.Get("For more help, please call: 0757628885") res.FlagReset = append(res.FlagReset, flag_account_authorized) return res, nil } +// ShowBlockedAccount displays a message after an account has been blocked and how to reach support. +func (h *Handlers) ShowBlockedAccount(ctx context.Context, sym string, input []byte) (resource.Result, error) { + var res resource.Result + code := codeFromCtx(ctx) + l := gotext.NewLocale(translationDir, code) + l.AddDomain("default") + res.Content = l.Get("Your account has been locked. For help on how to unblock your account, contact support at: 0757628885") + return res, nil +} + // VerifyYob verifies the length of the given input. func (h *Handlers) VerifyYob(ctx context.Context, sym string, input []byte) (resource.Result, error) { var res resource.Result @@ -2075,3 +2126,68 @@ func (h *Handlers) UpdateAllProfileItems(ctx context.Context, sym string, input } return res, nil } + +// incrementIncorrectPINAttempts keeps track of the number of incorrect PIN attempts +func (h *Handlers) incrementIncorrectPINAttempts(ctx context.Context, sessionId string) error { + var pinAttemptsCount uint8 + store := h.userdataStore + + currentWrongPinAttempts, err := store.ReadEntry(ctx, sessionId, common.DATA_INCORRECT_PIN_ATTEMPTS) + if err != nil { + if db.IsNotFound(err) { + //First time Wrong PIN attempt: initialize with a count of 1 + pinAttemptsCount = 1 + err = store.WriteEntry(ctx, sessionId, common.DATA_INCORRECT_PIN_ATTEMPTS, []byte(strconv.Itoa(int(pinAttemptsCount)))) + if err != nil { + logg.ErrorCtxf(ctx, "failed to write incorrect PIN attempts ", "key", common.DATA_INCORRECT_PIN_ATTEMPTS, "value", currentWrongPinAttempts, "error", err) + return err + } + return nil + } + } + pinAttemptsValue, _ := strconv.ParseUint(string(currentWrongPinAttempts), 0, 64) + pinAttemptsCount = uint8(pinAttemptsValue) + 1 + + err = store.WriteEntry(ctx, sessionId, common.DATA_INCORRECT_PIN_ATTEMPTS, []byte(strconv.Itoa(int(pinAttemptsCount)))) + if err != nil { + logg.ErrorCtxf(ctx, "failed to write incorrect PIN attempts ", "key", common.DATA_INCORRECT_PIN_ATTEMPTS, "value", pinAttemptsCount, "error", err) + return err + } + return nil +} + +// resetIncorrectPINAttempts resets the number of incorrect PIN attempts after a correct PIN entry +func (h *Handlers) resetIncorrectPINAttempts(ctx context.Context, sessionId string) error { + store := h.userdataStore + currentWrongPinAttempts, err := store.ReadEntry(ctx, sessionId, common.DATA_INCORRECT_PIN_ATTEMPTS) + if err != nil { + if db.IsNotFound(err) { + return nil + } + return err + } + currentWrongPinAttemptsCount, _ := strconv.ParseUint(string(currentWrongPinAttempts), 0, 64) + if currentWrongPinAttemptsCount <= uint64(common.AllowedPINAttempts) { + err = store.WriteEntry(ctx, sessionId, common.DATA_INCORRECT_PIN_ATTEMPTS, []byte(string("0"))) + if err != nil { + logg.ErrorCtxf(ctx, "failed to reset incorrect PIN attempts ", "key", common.DATA_INCORRECT_PIN_ATTEMPTS, "value", common.AllowedPINAttempts, "error", err) + return err + } + } + return nil +} + +// persistLanguageCode persists the selected ISO 639 language code +func (h *Handlers) persistLanguageCode(ctx context.Context, code string) error { + store := h.userdataStore + sessionId, ok := ctx.Value("SessionId").(string) + if !ok { + return fmt.Errorf("missing session") + } + err := store.WriteEntry(ctx, sessionId, common.DATA_SELECTED_LANGUAGE_CODE, []byte(code)) + if err != nil { + logg.ErrorCtxf(ctx, "failed to persist language code", "key", common.DATA_SELECTED_LANGUAGE_CODE, "value", code, "error", err) + return err + } + return nil +} diff --git a/internal/handlers/ussd/menuhandler_test.go b/internal/handlers/application/menu_handler_test.go similarity index 93% rename from internal/handlers/ussd/menuhandler_test.go rename to internal/handlers/application/menu_handler_test.go index 914dffc..487fe2b 100644 --- a/internal/handlers/ussd/menuhandler_test.go +++ b/internal/handlers/application/menu_handler_test.go @@ -1,10 +1,11 @@ -package ussd +package application import ( "context" "fmt" "log" "path" + "strconv" "strings" "testing" @@ -774,6 +775,11 @@ func TestSetLanguage(t *testing.T) { log.Fatal(err) } + sessionId := "session123" + ctx, store := InitializeTestStore(t) + + ctx = context.WithValue(ctx, "SessionId", sessionId) + // Define test cases tests := []struct { name string @@ -806,12 +812,13 @@ func TestSetLanguage(t *testing.T) { // Create the Handlers instance with the mock flag manager h := &Handlers{ - flagManager: fm.parser, - st: mockState, + flagManager: fm.parser, + userdataStore: store, + st: mockState, } // Call the method - res, err := h.SetLanguage(context.Background(), "set_language", nil) + res, err := h.SetLanguage(ctx, "set_language", nil) if err != nil { t.Error(err) } @@ -907,37 +914,79 @@ func TestResetAccountAuthorized(t *testing.T) { } func TestIncorrectPinReset(t *testing.T) { + sessionId := "session123" + ctx, store := InitializeTestStore(t) fm, err := NewFlagManager(flagsPath) + if err != nil { log.Fatal(err) } flag_incorrect_pin, _ := fm.parser.GetFlag("flag_incorrect_pin") + flag_account_blocked, _ := fm.parser.GetFlag("flag_account_blocked") + + ctx = context.WithValue(ctx, "SessionId", sessionId) // Define test cases tests := []struct { name string input []byte + attempts uint8 expectedResult resource.Result }{ { - name: "Test incorrect pin reset", + name: "Test when incorrect PIN attempts is 2", input: []byte(""), expectedResult: resource.Result{ FlagReset: []uint32{flag_incorrect_pin}, + Content: "1", //Expected remaining PIN attempts }, + attempts: 2, + }, + { + name: "Test incorrect pin reset when incorrect PIN attempts is 1", + input: []byte(""), + expectedResult: resource.Result{ + FlagReset: []uint32{flag_incorrect_pin}, + Content: "2", //Expected remaining PIN attempts + }, + attempts: 1, + }, + { + name: "Test incorrect pin reset when incorrect PIN attempts is 1", + input: []byte(""), + expectedResult: resource.Result{ + FlagReset: []uint32{flag_incorrect_pin}, + Content: "2", //Expected remaining PIN attempts + }, + attempts: 1, + }, + { + name: "Test incorrect pin reset when incorrect PIN attempts is 3(account expected to be blocked)", + input: []byte(""), + expectedResult: resource.Result{ + FlagReset: []uint32{flag_incorrect_pin}, + FlagSet: []uint32{flag_account_blocked}, + }, + attempts: 3, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + + if err := store.WriteEntry(ctx, sessionId, common.DATA_INCORRECT_PIN_ATTEMPTS, []byte(strconv.Itoa(int(tt.attempts)))); err != nil { + t.Fatal(err) + } + // Create the Handlers instance with the mock flag manager h := &Handlers{ - flagManager: fm.parser, + flagManager: fm.parser, + userdataStore: store, } // Call the method - res, err := h.ResetIncorrectPin(context.Background(), "reset_incorrect_pin", tt.input) + res, err := h.ResetIncorrectPin(ctx, "reset_incorrect_pin", tt.input) if err != nil { t.Error(err) } @@ -2190,3 +2239,93 @@ func TestGetVoucherDetails(t *testing.T) { assert.NoError(t, err) assert.Equal(t, expectedResult, res) } + +func TestCountIncorrectPINAttempts(t *testing.T) { + ctx, store := InitializeTestStore(t) + sessionId := "session123" + ctx = context.WithValue(ctx, "SessionId", sessionId) + attempts := uint8(2) + + h := &Handlers{ + userdataStore: store, + } + err := store.WriteEntry(ctx, sessionId, common.DATA_INCORRECT_PIN_ATTEMPTS, []byte(strconv.Itoa(int(attempts)))) + if err != nil { + t.Logf(err.Error()) + } + err = h.incrementIncorrectPINAttempts(ctx, sessionId) + if err != nil { + t.Logf(err.Error()) + } + + attemptsAfterCount, err := store.ReadEntry(ctx, sessionId, common.DATA_INCORRECT_PIN_ATTEMPTS) + if err != nil { + t.Logf(err.Error()) + } + pinAttemptsValue, _ := strconv.ParseUint(string(attemptsAfterCount), 0, 64) + pinAttemptsCount := uint8(pinAttemptsValue) + expectedAttempts := attempts + 1 + assert.Equal(t, pinAttemptsCount, expectedAttempts) + +} + +func TestResetIncorrectPINAttempts(t *testing.T) { + ctx, store := InitializeTestStore(t) + sessionId := "session123" + ctx = context.WithValue(ctx, "SessionId", sessionId) + + err := store.WriteEntry(ctx, sessionId, common.DATA_INCORRECT_PIN_ATTEMPTS, []byte(string("2"))) + if err != nil { + t.Logf(err.Error()) + } + + h := &Handlers{ + userdataStore: store, + } + h.resetIncorrectPINAttempts(ctx, sessionId) + incorrectAttempts, err := store.ReadEntry(ctx, sessionId, common.DATA_INCORRECT_PIN_ATTEMPTS) + + if err != nil { + t.Logf(err.Error()) + } + assert.Equal(t, "0", string(incorrectAttempts)) + +} + +func TestPersistLanguageCode(t *testing.T) { + ctx, store := InitializeTestStore(t) + + sessionId := "session123" + ctx = context.WithValue(ctx, "SessionId", sessionId) + + h := &Handlers{ + userdataStore: store, + } + tests := []struct { + name string + code string + expectedLanguageCode string + }{ + { + name: "Set Default Language (English)", + code: "eng", + expectedLanguageCode: "eng", + }, + { + name: "Set Swahili Language", + code: "swa", + expectedLanguageCode: "swa", + }, + } + + for _, test := range tests { + err := h.persistLanguageCode(ctx, test.code) + if err != nil { + t.Logf(err.Error()) + } + code, err := store.ReadEntry(ctx, sessionId, common.DATA_SELECTED_LANGUAGE_CODE) + + assert.Equal(t, test.expectedLanguageCode, string(code)) + } + +} diff --git a/internal/handlers/base.go b/internal/handlers/base.go index 755cca4..6c77f49 100644 --- a/internal/handlers/base.go +++ b/internal/handlers/base.go @@ -6,46 +6,46 @@ import ( "git.defalsify.org/vise.git/persist" "git.defalsify.org/vise.git/resource" - "git.grassecon.net/urdt/ussd/internal/handlers/ussd" + "git.grassecon.net/urdt/ussd/internal/handlers/application" "git.grassecon.net/urdt/ussd/internal/storage" ) type BaseSessionHandler struct { cfgTemplate engine.Config - rp RequestParser - rs resource.Resource - hn *ussd.Handlers - provider storage.StorageProvider + rp RequestParser + rs resource.Resource + hn *application.Handlers + provider storage.StorageProvider } -func NewBaseSessionHandler(cfg engine.Config, rs resource.Resource, stateDb db.Db, userdataDb db.Db, rp RequestParser, hn *ussd.Handlers) *BaseSessionHandler { +func NewBaseSessionHandler(cfg engine.Config, rs resource.Resource, stateDb db.Db, userdataDb db.Db, rp RequestParser, hn *application.Handlers) *BaseSessionHandler { return &BaseSessionHandler{ cfgTemplate: cfg, - rs: rs, - hn: hn, - rp: rp, - provider: storage.NewSimpleStorageProvider(stateDb, userdataDb), + rs: rs, + hn: hn, + rp: rp, + provider: storage.NewSimpleStorageProvider(stateDb, userdataDb), } } -func(f* BaseSessionHandler) Shutdown() { +func (f *BaseSessionHandler) Shutdown() { err := f.provider.Close() if err != nil { logg.Errorf("handler shutdown error", "err", err) } } -func(f *BaseSessionHandler) GetEngine(cfg engine.Config, rs resource.Resource, pr *persist.Persister) engine.Engine { +func (f *BaseSessionHandler) GetEngine(cfg engine.Config, rs resource.Resource, pr *persist.Persister) engine.Engine { en := engine.NewEngine(cfg, rs) en = en.WithPersister(pr) return en } -func(f *BaseSessionHandler) Process(rqs RequestSession) (RequestSession, error) { +func (f *BaseSessionHandler) Process(rqs RequestSession) (RequestSession, error) { var r bool var err error var ok bool - + logg.InfoCtxf(rqs.Ctx, "new request", "data", rqs) rqs.Storage, err = f.provider.Get(rqs.Config.SessionId) @@ -84,25 +84,25 @@ func(f *BaseSessionHandler) Process(rqs RequestSession) (RequestSession, error) return rqs, err } - rqs.Continue = r + rqs.Continue = r return rqs, nil } -func(f *BaseSessionHandler) Output(rqs RequestSession) (RequestSession, error) { +func (f *BaseSessionHandler) Output(rqs RequestSession) (RequestSession, error) { var err error _, err = rqs.Engine.Flush(rqs.Ctx, rqs.Writer) return rqs, err } -func(f *BaseSessionHandler) Reset(rqs RequestSession) (RequestSession, error) { +func (f *BaseSessionHandler) Reset(rqs RequestSession) (RequestSession, error) { defer f.provider.Put(rqs.Config.SessionId, rqs.Storage) return rqs, rqs.Engine.Finish() } -func(f *BaseSessionHandler) GetConfig() engine.Config { +func (f *BaseSessionHandler) GetConfig() engine.Config { return f.cfgTemplate } -func(f *BaseSessionHandler) GetRequestParser() RequestParser { +func (f *BaseSessionHandler) GetRequestParser() RequestParser { return f.rp } diff --git a/internal/handlers/handler_service.go b/internal/handlers/handler_service.go new file mode 100644 index 0000000..6fb355b --- /dev/null +++ b/internal/handlers/handler_service.go @@ -0,0 +1,141 @@ +package handlers + +import ( + "context" + "strings" + + "git.defalsify.org/vise.git/asm" + "git.defalsify.org/vise.git/db" + "git.defalsify.org/vise.git/engine" + "git.defalsify.org/vise.git/persist" + "git.defalsify.org/vise.git/resource" + + "git.grassecon.net/urdt/ussd/internal/handlers/application" + "git.grassecon.net/urdt/ussd/internal/utils" + "git.grassecon.net/urdt/ussd/remote" +) + +type HandlerService interface { + GetHandler() (*application.Handlers, error) +} + +func getParser(fp string, debug bool) (*asm.FlagParser, error) { + flagParser := asm.NewFlagParser().WithDebug() + _, err := flagParser.Load(fp) + if err != nil { + return nil, err + } + return flagParser, nil +} + +type LocalHandlerService struct { + Parser *asm.FlagParser + DbRs *resource.DbResource + Pe *persist.Persister + UserdataStore *db.Db + AdminStore *utils.AdminStore + Cfg engine.Config + Rs resource.Resource +} + +func NewLocalHandlerService(ctx context.Context, fp string, debug bool, dbResource *resource.DbResource, cfg engine.Config, rs resource.Resource) (*LocalHandlerService, error) { + parser, err := getParser(fp, debug) + if err != nil { + return nil, err + } + adminstore, err := utils.NewAdminStore(ctx, "admin_numbers") + if err != nil { + return nil, err + } + return &LocalHandlerService{ + Parser: parser, + DbRs: dbResource, + AdminStore: adminstore, + Cfg: cfg, + Rs: rs, + }, nil +} + +func (ls *LocalHandlerService) SetPersister(Pe *persist.Persister) { + ls.Pe = Pe +} + +func (ls *LocalHandlerService) SetDataStore(db *db.Db) { + ls.UserdataStore = db +} + +func (ls *LocalHandlerService) GetHandler(accountService remote.AccountServiceInterface) (*application.Handlers, error) { + replaceSeparatorFunc := func(input string) string { + return strings.ReplaceAll(input, ":", ls.Cfg.MenuSeparator) + } + + appHandlers, err := application.NewHandlers(ls.Parser, *ls.UserdataStore, ls.AdminStore, accountService, replaceSeparatorFunc) + if err != nil { + return nil, err + } + appHandlers = appHandlers.WithPersister(ls.Pe) + ls.DbRs.AddLocalFunc("set_language", appHandlers.SetLanguage) + ls.DbRs.AddLocalFunc("create_account", appHandlers.CreateAccount) + ls.DbRs.AddLocalFunc("save_temporary_pin", appHandlers.SaveTemporaryPin) + ls.DbRs.AddLocalFunc("verify_create_pin", appHandlers.VerifyCreatePin) + ls.DbRs.AddLocalFunc("check_identifier", appHandlers.CheckIdentifier) + ls.DbRs.AddLocalFunc("check_account_status", appHandlers.CheckAccountStatus) + ls.DbRs.AddLocalFunc("authorize_account", appHandlers.Authorize) + ls.DbRs.AddLocalFunc("quit", appHandlers.Quit) + ls.DbRs.AddLocalFunc("check_balance", appHandlers.CheckBalance) + ls.DbRs.AddLocalFunc("validate_recipient", appHandlers.ValidateRecipient) + ls.DbRs.AddLocalFunc("transaction_reset", appHandlers.TransactionReset) + ls.DbRs.AddLocalFunc("invite_valid_recipient", appHandlers.InviteValidRecipient) + ls.DbRs.AddLocalFunc("max_amount", appHandlers.MaxAmount) + ls.DbRs.AddLocalFunc("validate_amount", appHandlers.ValidateAmount) + ls.DbRs.AddLocalFunc("reset_transaction_amount", appHandlers.ResetTransactionAmount) + ls.DbRs.AddLocalFunc("get_recipient", appHandlers.GetRecipient) + ls.DbRs.AddLocalFunc("get_sender", appHandlers.GetSender) + ls.DbRs.AddLocalFunc("get_amount", appHandlers.GetAmount) + ls.DbRs.AddLocalFunc("reset_incorrect", appHandlers.ResetIncorrectPin) + ls.DbRs.AddLocalFunc("save_firstname", appHandlers.SaveFirstname) + ls.DbRs.AddLocalFunc("save_familyname", appHandlers.SaveFamilyname) + ls.DbRs.AddLocalFunc("save_gender", appHandlers.SaveGender) + ls.DbRs.AddLocalFunc("save_location", appHandlers.SaveLocation) + ls.DbRs.AddLocalFunc("save_yob", appHandlers.SaveYob) + ls.DbRs.AddLocalFunc("save_offerings", appHandlers.SaveOfferings) + ls.DbRs.AddLocalFunc("reset_account_authorized", appHandlers.ResetAccountAuthorized) + ls.DbRs.AddLocalFunc("reset_allow_update", appHandlers.ResetAllowUpdate) + ls.DbRs.AddLocalFunc("get_profile_info", appHandlers.GetProfileInfo) + ls.DbRs.AddLocalFunc("verify_yob", appHandlers.VerifyYob) + ls.DbRs.AddLocalFunc("reset_incorrect_date_format", appHandlers.ResetIncorrectYob) + ls.DbRs.AddLocalFunc("initiate_transaction", appHandlers.InitiateTransaction) + ls.DbRs.AddLocalFunc("verify_new_pin", appHandlers.VerifyNewPin) + ls.DbRs.AddLocalFunc("confirm_pin_change", appHandlers.ConfirmPinChange) + ls.DbRs.AddLocalFunc("quit_with_help", appHandlers.QuitWithHelp) + ls.DbRs.AddLocalFunc("fetch_community_balance", appHandlers.FetchCommunityBalance) + ls.DbRs.AddLocalFunc("set_default_voucher", appHandlers.SetDefaultVoucher) + ls.DbRs.AddLocalFunc("check_vouchers", appHandlers.CheckVouchers) + ls.DbRs.AddLocalFunc("get_vouchers", appHandlers.GetVoucherList) + ls.DbRs.AddLocalFunc("view_voucher", appHandlers.ViewVoucher) + ls.DbRs.AddLocalFunc("set_voucher", appHandlers.SetVoucher) + ls.DbRs.AddLocalFunc("get_voucher_details", appHandlers.GetVoucherDetails) + ls.DbRs.AddLocalFunc("reset_valid_pin", appHandlers.ResetValidPin) + ls.DbRs.AddLocalFunc("check_pin_mismatch", appHandlers.CheckBlockedNumPinMisMatch) + ls.DbRs.AddLocalFunc("validate_blocked_number", appHandlers.ValidateBlockedNumber) + ls.DbRs.AddLocalFunc("retrieve_blocked_number", appHandlers.RetrieveBlockedNumber) + ls.DbRs.AddLocalFunc("reset_unregistered_number", appHandlers.ResetUnregisteredNumber) + ls.DbRs.AddLocalFunc("reset_others_pin", appHandlers.ResetOthersPin) + ls.DbRs.AddLocalFunc("save_others_temporary_pin", appHandlers.SaveOthersTemporaryPin) + ls.DbRs.AddLocalFunc("get_current_profile_info", appHandlers.GetCurrentProfileInfo) + ls.DbRs.AddLocalFunc("check_transactions", appHandlers.CheckTransactions) + ls.DbRs.AddLocalFunc("get_transactions", appHandlers.GetTransactionsList) + ls.DbRs.AddLocalFunc("view_statement", appHandlers.ViewTransactionStatement) + ls.DbRs.AddLocalFunc("update_all_profile_items", appHandlers.UpdateAllProfileItems) + ls.DbRs.AddLocalFunc("set_back", appHandlers.SetBack) + ls.DbRs.AddLocalFunc("show_blocked_account", appHandlers.ShowBlockedAccount) + + return appHandlers, nil +} + +// TODO: enable setting of sessionId on engine init time +func (ls *LocalHandlerService) GetEngine() *engine.DefaultEngine { + en := engine.NewEngine(ls.Cfg, ls.Rs) + en = en.WithPersister(ls.Pe) + return en +} diff --git a/internal/handlers/handlerservice.go b/internal/handlers/handlerservice.go deleted file mode 100644 index 1da28c3..0000000 --- a/internal/handlers/handlerservice.go +++ /dev/null @@ -1,140 +0,0 @@ -package handlers - -import ( - "context" - "strings" - - "git.defalsify.org/vise.git/asm" - "git.defalsify.org/vise.git/db" - "git.defalsify.org/vise.git/engine" - "git.defalsify.org/vise.git/persist" - "git.defalsify.org/vise.git/resource" - - "git.grassecon.net/urdt/ussd/internal/handlers/ussd" - "git.grassecon.net/urdt/ussd/internal/utils" - "git.grassecon.net/urdt/ussd/remote" -) - -type HandlerService interface { - GetHandler() (*ussd.Handlers, error) -} - -func getParser(fp string, debug bool) (*asm.FlagParser, error) { - flagParser := asm.NewFlagParser().WithDebug() - _, err := flagParser.Load(fp) - if err != nil { - return nil, err - } - return flagParser, nil -} - -type LocalHandlerService struct { - Parser *asm.FlagParser - DbRs *resource.DbResource - Pe *persist.Persister - UserdataStore *db.Db - AdminStore *utils.AdminStore - Cfg engine.Config - Rs resource.Resource -} - -func NewLocalHandlerService(ctx context.Context, fp string, debug bool, dbResource *resource.DbResource, cfg engine.Config, rs resource.Resource) (*LocalHandlerService, error) { - parser, err := getParser(fp, debug) - if err != nil { - return nil, err - } - adminstore, err := utils.NewAdminStore(ctx, "admin_numbers") - if err != nil { - return nil, err - } - return &LocalHandlerService{ - Parser: parser, - DbRs: dbResource, - AdminStore: adminstore, - Cfg: cfg, - Rs: rs, - }, nil -} - -func (ls *LocalHandlerService) SetPersister(Pe *persist.Persister) { - ls.Pe = Pe -} - -func (ls *LocalHandlerService) SetDataStore(db *db.Db) { - ls.UserdataStore = db -} - -func (ls *LocalHandlerService) GetHandler(accountService remote.AccountServiceInterface) (*ussd.Handlers, error) { - replaceSeparatorFunc := func(input string) string { - return strings.ReplaceAll(input, ":", ls.Cfg.MenuSeparator) - } - - ussdHandlers, err := ussd.NewHandlers(ls.Parser, *ls.UserdataStore, ls.AdminStore, accountService, replaceSeparatorFunc) - if err != nil { - return nil, err - } - ussdHandlers = ussdHandlers.WithPersister(ls.Pe) - ls.DbRs.AddLocalFunc("set_language", ussdHandlers.SetLanguage) - ls.DbRs.AddLocalFunc("create_account", ussdHandlers.CreateAccount) - ls.DbRs.AddLocalFunc("save_temporary_pin", ussdHandlers.SaveTemporaryPin) - ls.DbRs.AddLocalFunc("verify_create_pin", ussdHandlers.VerifyCreatePin) - ls.DbRs.AddLocalFunc("check_identifier", ussdHandlers.CheckIdentifier) - ls.DbRs.AddLocalFunc("check_account_status", ussdHandlers.CheckAccountStatus) - ls.DbRs.AddLocalFunc("authorize_account", ussdHandlers.Authorize) - ls.DbRs.AddLocalFunc("quit", ussdHandlers.Quit) - ls.DbRs.AddLocalFunc("check_balance", ussdHandlers.CheckBalance) - ls.DbRs.AddLocalFunc("validate_recipient", ussdHandlers.ValidateRecipient) - ls.DbRs.AddLocalFunc("transaction_reset", ussdHandlers.TransactionReset) - ls.DbRs.AddLocalFunc("invite_valid_recipient", ussdHandlers.InviteValidRecipient) - ls.DbRs.AddLocalFunc("max_amount", ussdHandlers.MaxAmount) - ls.DbRs.AddLocalFunc("validate_amount", ussdHandlers.ValidateAmount) - ls.DbRs.AddLocalFunc("reset_transaction_amount", ussdHandlers.ResetTransactionAmount) - ls.DbRs.AddLocalFunc("get_recipient", ussdHandlers.GetRecipient) - ls.DbRs.AddLocalFunc("get_sender", ussdHandlers.GetSender) - ls.DbRs.AddLocalFunc("get_amount", ussdHandlers.GetAmount) - ls.DbRs.AddLocalFunc("reset_incorrect", ussdHandlers.ResetIncorrectPin) - ls.DbRs.AddLocalFunc("save_firstname", ussdHandlers.SaveFirstname) - ls.DbRs.AddLocalFunc("save_familyname", ussdHandlers.SaveFamilyname) - ls.DbRs.AddLocalFunc("save_gender", ussdHandlers.SaveGender) - ls.DbRs.AddLocalFunc("save_location", ussdHandlers.SaveLocation) - ls.DbRs.AddLocalFunc("save_yob", ussdHandlers.SaveYob) - ls.DbRs.AddLocalFunc("save_offerings", ussdHandlers.SaveOfferings) - ls.DbRs.AddLocalFunc("reset_account_authorized", ussdHandlers.ResetAccountAuthorized) - ls.DbRs.AddLocalFunc("reset_allow_update", ussdHandlers.ResetAllowUpdate) - ls.DbRs.AddLocalFunc("get_profile_info", ussdHandlers.GetProfileInfo) - ls.DbRs.AddLocalFunc("verify_yob", ussdHandlers.VerifyYob) - ls.DbRs.AddLocalFunc("reset_incorrect_date_format", ussdHandlers.ResetIncorrectYob) - ls.DbRs.AddLocalFunc("initiate_transaction", ussdHandlers.InitiateTransaction) - ls.DbRs.AddLocalFunc("verify_new_pin", ussdHandlers.VerifyNewPin) - ls.DbRs.AddLocalFunc("confirm_pin_change", ussdHandlers.ConfirmPinChange) - ls.DbRs.AddLocalFunc("quit_with_help", ussdHandlers.QuitWithHelp) - ls.DbRs.AddLocalFunc("fetch_community_balance", ussdHandlers.FetchCommunityBalance) - ls.DbRs.AddLocalFunc("set_default_voucher", ussdHandlers.SetDefaultVoucher) - ls.DbRs.AddLocalFunc("check_vouchers", ussdHandlers.CheckVouchers) - ls.DbRs.AddLocalFunc("get_vouchers", ussdHandlers.GetVoucherList) - ls.DbRs.AddLocalFunc("view_voucher", ussdHandlers.ViewVoucher) - ls.DbRs.AddLocalFunc("set_voucher", ussdHandlers.SetVoucher) - ls.DbRs.AddLocalFunc("get_voucher_details", ussdHandlers.GetVoucherDetails) - ls.DbRs.AddLocalFunc("reset_valid_pin", ussdHandlers.ResetValidPin) - ls.DbRs.AddLocalFunc("check_pin_mismatch", ussdHandlers.CheckBlockedNumPinMisMatch) - ls.DbRs.AddLocalFunc("validate_blocked_number", ussdHandlers.ValidateBlockedNumber) - ls.DbRs.AddLocalFunc("retrieve_blocked_number", ussdHandlers.RetrieveBlockedNumber) - ls.DbRs.AddLocalFunc("reset_unregistered_number", ussdHandlers.ResetUnregisteredNumber) - ls.DbRs.AddLocalFunc("reset_others_pin", ussdHandlers.ResetOthersPin) - ls.DbRs.AddLocalFunc("save_others_temporary_pin", ussdHandlers.SaveOthersTemporaryPin) - ls.DbRs.AddLocalFunc("get_current_profile_info", ussdHandlers.GetCurrentProfileInfo) - ls.DbRs.AddLocalFunc("check_transactions", ussdHandlers.CheckTransactions) - ls.DbRs.AddLocalFunc("get_transactions", ussdHandlers.GetTransactionsList) - ls.DbRs.AddLocalFunc("view_statement", ussdHandlers.ViewTransactionStatement) - ls.DbRs.AddLocalFunc("update_all_profile_items", ussdHandlers.UpdateAllProfileItems) - ls.DbRs.AddLocalFunc("set_back", ussdHandlers.SetBack) - - return ussdHandlers, nil -} - -// TODO: enable setting of sessionId on engine init time -func (ls *LocalHandlerService) GetEngine() *engine.DefaultEngine { - en := engine.NewEngine(ls.Cfg, ls.Rs) - en = en.WithPersister(ls.Pe) - return en -} diff --git a/internal/handlers/single.go b/internal/handlers/single.go index 6929617..1b11a64 100644 --- a/internal/handlers/single.go +++ b/internal/handlers/single.go @@ -6,9 +6,9 @@ import ( "io" "git.defalsify.org/vise.git/engine" - "git.defalsify.org/vise.git/resource" - "git.defalsify.org/vise.git/persist" "git.defalsify.org/vise.git/logging" + "git.defalsify.org/vise.git/persist" + "git.defalsify.org/vise.git/resource" "git.grassecon.net/urdt/ussd/internal/storage" ) @@ -20,33 +20,33 @@ var ( var ( ErrInvalidRequest = errors.New("invalid request for context") ErrSessionMissing = errors.New("missing session") - ErrInvalidInput = errors.New("invalid input") - ErrStorage = errors.New("storage retrieval fail") - ErrEngineType = errors.New("incompatible engine") - ErrEngineInit = errors.New("engine init fail") - ErrEngineExec = errors.New("engine exec fail") + ErrInvalidInput = errors.New("invalid input") + ErrStorage = errors.New("storage retrieval fail") + ErrEngineType = errors.New("incompatible engine") + ErrEngineInit = errors.New("engine init fail") + ErrEngineExec = errors.New("engine exec fail") ) type RequestSession struct { - Ctx context.Context - Config engine.Config - Engine engine.Engine - Input []byte - Storage *storage.Storage - Writer io.Writer + Ctx context.Context + Config engine.Config + Engine engine.Engine + Input []byte + Storage *storage.Storage + Writer io.Writer Continue bool } // TODO: seems like can remove this. type RequestParser interface { - GetSessionId(rq any) (string, error) + GetSessionId(context context.Context, rq any) (string, error) GetInput(rq any) ([]byte, error) } type RequestHandler interface { GetConfig() engine.Config GetRequestParser() RequestParser - GetEngine(cfg engine.Config, rs resource.Resource, pe *persist.Persister) engine.Engine + GetEngine(cfg engine.Config, rs resource.Resource, pe *persist.Persister) engine.Engine Process(rs RequestSession) (RequestSession, error) Output(rs RequestSession) (RequestSession, error) Reset(rs RequestSession) (RequestSession, error) diff --git a/internal/http/at/parse.go b/internal/http/at/parse.go index d2696ed..76e84e7 100644 --- a/internal/http/at/parse.go +++ b/internal/http/at/parse.go @@ -15,16 +15,14 @@ import ( ) type ATRequestParser struct { - Context context.Context } -func (arp *ATRequestParser) GetSessionId(rq any) (string, error) { +func (arp *ATRequestParser) GetSessionId(ctx context.Context, rq any) (string, error) { rqv, ok := rq.(*http.Request) if !ok { logg.Warnf("got an invalid request", "req", rq) return "", handlers.ErrInvalidRequest } - // Capture body (if any) for logging body, err := io.ReadAll(rqv.Body) if err != nil { @@ -43,9 +41,9 @@ func (arp *ATRequestParser) GetSessionId(rq any) (string, error) { decodedStr := string(logBytes) sessionId, err := extractATSessionId(decodedStr) if err != nil { - context.WithValue(arp.Context, "at-session-id", sessionId) + ctx = context.WithValue(ctx, "AT-SessionId", sessionId) } - logg.Debugf("Received request:", decodedStr) + logg.DebugCtxf(ctx, "Received request:", decodedStr) } if err := rqv.ParseForm(); err != nil { @@ -83,7 +81,8 @@ func (arp *ATRequestParser) GetInput(rq any) ([]byte, error) { return nil, fmt.Errorf("no input found") } - return []byte(parts[len(parts)-1]), nil + trimmedInput := strings.TrimSpace(parts[len(parts)-1]) + return []byte(trimmedInput), nil } func parseQueryParams(query string) map[string]string { diff --git a/internal/http/at/server.go b/internal/http/at/server.go index 705ff76..3399dd5 100644 --- a/internal/http/at/server.go +++ b/internal/http/at/server.go @@ -10,7 +10,7 @@ import ( ) var ( - logg = logging.NewVanilla().WithDomain("atserver") + logg = logging.NewVanilla().WithDomain("atserver").WithContextKey("SessionId").WithContextKey("AT-SessionId") ) type ATSessionHandler struct { @@ -34,7 +34,7 @@ func (ash *ATSessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) rp := ash.GetRequestParser() cfg := ash.GetConfig() - cfg.SessionId, err = rp.GetSessionId(req) + cfg.SessionId, err = rp.GetSessionId(req.Context(), req) if err != nil { logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err) ash.WriteError(w, 400, err) @@ -48,7 +48,7 @@ func (ash *ATSessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) return } - rqs, err = ash.Process(rqs) + rqs, err = ash.Process(rqs) switch err { case nil: // set code to 200 if no err code = 200 diff --git a/internal/http/parse.go b/internal/http/parse.go index ec8e00b..b4e784d 100644 --- a/internal/http/parse.go +++ b/internal/http/parse.go @@ -1,6 +1,7 @@ package http import ( + "context" "io/ioutil" "net/http" @@ -10,7 +11,7 @@ import ( type DefaultRequestParser struct { } -func (rp *DefaultRequestParser) GetSessionId(rq any) (string, error) { +func (rp *DefaultRequestParser) GetSessionId(ctx context.Context, rq any) (string, error) { rqv, ok := rq.(*http.Request) if !ok { return "", handlers.ErrInvalidRequest @@ -34,5 +35,3 @@ func (rp *DefaultRequestParser) GetInput(rq any) ([]byte, error) { } return v, nil } - - diff --git a/internal/http/server.go b/internal/http/server.go index 9cadfa3..0a2533e 100644 --- a/internal/http/server.go +++ b/internal/http/server.go @@ -46,7 +46,7 @@ func (f *SessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { rp := f.GetRequestParser() cfg := f.GetConfig() - cfg.SessionId, err = rp.GetSessionId(req) + cfg.SessionId, err = rp.GetSessionId(req.Context(), req) if err != nil { logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err) f.WriteError(w, 400, err) diff --git a/internal/http/server_test.go b/internal/http/server_test.go index a46f98e..23afd5d 100644 --- a/internal/http/server_test.go +++ b/internal/http/server_test.go @@ -2,6 +2,7 @@ package http import ( "bytes" + "context" "errors" "net/http" "net/http/httptest" @@ -161,7 +162,7 @@ func TestDefaultRequestParser_GetSessionId(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - id, err := parser.GetSessionId(tt.request) + id, err := parser.GetSessionId(context.Background(),tt.request) if id != tt.expectedID { t.Errorf("Expected session ID %s, got %s", tt.expectedID, id) diff --git a/internal/ssh/keystore.go b/internal/ssh/keystore.go new file mode 100644 index 0000000..206d684 --- /dev/null +++ b/internal/ssh/keystore.go @@ -0,0 +1,65 @@ +package ssh + +import ( + "context" + "fmt" + "os" + "path" + + "golang.org/x/crypto/ssh" + + "git.defalsify.org/vise.git/db" + + "git.grassecon.net/urdt/ussd/internal/storage" + dbstorage "git.grassecon.net/urdt/ussd/internal/storage/db/gdbm" +) + +type SshKeyStore struct { + store db.Db +} + +func NewSshKeyStore(ctx context.Context, dbDir string) (*SshKeyStore, error) { + keyStore := &SshKeyStore{} + keyStoreFile := path.Join(dbDir, "ssh_authorized_keys.gdbm") + keyStore.store = dbstorage.NewThreadGdbmDb() + err := keyStore.store.Connect(ctx, keyStoreFile) + if err != nil { + return nil, err + } + return keyStore, nil +} + +func(s *SshKeyStore) AddFromFile(ctx context.Context, fp string, sessionId string) error { + _, err := os.Stat(fp) + if err != nil { + return fmt.Errorf("cannot open ssh server public key file: %v\n", err) + } + + publicBytes, err := os.ReadFile(fp) + if err != nil { + return fmt.Errorf("Failed to load public key: %v", err) + } + pubKey, _, _, _, err := ssh.ParseAuthorizedKey(publicBytes) + if err != nil { + return fmt.Errorf("Failed to parse public key: %v", err) + } + k := append([]byte{0x01}, pubKey.Marshal()...) + s.store.SetPrefix(storage.DATATYPE_EXTEND) + logg.Infof("Added key", "sessionId", sessionId, "public key", string(publicBytes)) + return s.store.Put(ctx, k, []byte(sessionId)) +} + +func(s *SshKeyStore) Get(ctx context.Context, pubKey ssh.PublicKey) (string, error) { + s.store.SetLanguage(nil) + s.store.SetPrefix(storage.DATATYPE_EXTEND) + k := append([]byte{0x01}, pubKey.Marshal()...) + v, err := s.store.Get(ctx, k) + if err != nil { + return "", err + } + return string(v), nil +} + +func(s *SshKeyStore) Close() error { + return s.store.Close() +} diff --git a/internal/ssh/ssh.go b/internal/ssh/ssh.go new file mode 100644 index 0000000..8209187 --- /dev/null +++ b/internal/ssh/ssh.go @@ -0,0 +1,284 @@ +package ssh + +import ( + "context" + "encoding/hex" + "encoding/base64" + "errors" + "fmt" + "net" + "os" + "sync" + + "golang.org/x/crypto/ssh" + + "git.defalsify.org/vise.git/engine" + "git.defalsify.org/vise.git/logging" + "git.defalsify.org/vise.git/resource" + "git.defalsify.org/vise.git/state" + + "git.grassecon.net/urdt/ussd/internal/handlers" + "git.grassecon.net/urdt/ussd/internal/storage" + "git.grassecon.net/urdt/ussd/remote" +) + +var ( + logg = logging.NewVanilla().WithDomain("ssh") +) + +type auther struct { + Ctx context.Context + keyStore *SshKeyStore + auth map[string]string +} + +func NewAuther(ctx context.Context, keyStore *SshKeyStore) *auther { + return &auther{ + Ctx: ctx, + keyStore: keyStore, + auth: make(map[string]string), + } +} + +func(a *auther) Check(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { + logg.TraceCtxf(a.Ctx, "looking for publickey", "pubkey", fmt.Sprintf("%x", pubKey)) + va, err := a.keyStore.Get(a.Ctx, pubKey) + if err != nil { + return nil, err + } + ka := hex.EncodeToString(conn.SessionID()) + a.auth[ka] = va + fmt.Fprintf(os.Stderr, "connect: %s -> %s\n", ka, va) + return nil, nil +} + +func(a *auther) FromConn(c *ssh.ServerConn) (string, error) { + if c == nil { + return "", errors.New("nil server conn") + } + if c.Conn == nil { + return "", errors.New("nil underlying conn") + } + return a.Get(c.Conn.SessionID()) +} + + +func(a *auther) Get(k []byte) (string, error) { + ka := hex.EncodeToString(k) + v, ok := a.auth[ka] + if !ok { + return "", errors.New("not found") + } + return v, nil +} + +type SshRunner struct { + Ctx context.Context + Cfg engine.Config + FlagFile string + Conn storage.ConnData + ResourceDir string + Debug bool + SrvKeyFile string + Host string + Port uint + wg sync.WaitGroup + lst net.Listener +} + +func(s *SshRunner) serve(ctx context.Context, sessionId string, ch ssh.NewChannel, en engine.Engine) error { + if ch == nil { + return errors.New("nil channel") + } + if ch.ChannelType() != "session" { + ch.Reject(ssh.UnknownChannelType, "that is not the channel you are looking for") + return errors.New("not a session") + } + channel, requests, err := ch.Accept() + if err != nil { + panic(err) + } + defer channel.Close() + s.wg.Add(1) + go func(reqIn <-chan *ssh.Request) { + defer s.wg.Done() + for req := range reqIn { + req.Reply(req.Type == "shell", nil) + } + _ = requests + }(requests) + + cont, err := en.Exec(ctx, []byte{}) + if err != nil { + return fmt.Errorf("initial engine exec err: %v", err) + } + + var input [state.INPUT_LIMIT]byte + for cont { + c, err := en.Flush(ctx, channel) + if err != nil { + return fmt.Errorf("flush err: %v", err) + } + _, err = channel.Write([]byte{0x0a}) + if err != nil { + return fmt.Errorf("newline err: %v", err) + } + c, err = channel.Read(input[:]) + if err != nil { + return fmt.Errorf("read input fail: %v", err) + } + logg.TraceCtxf(ctx, "input read", "c", c, "input", input[:c-1]) + cont, err = en.Exec(ctx, input[:c-1]) + if err != nil { + return fmt.Errorf("engine exec err: %v", err) + } + logg.TraceCtxf(ctx, "exec cont", "cont", cont, "en", en) + _ = c + } + c, err := en.Flush(ctx, channel) + if err != nil { + return fmt.Errorf("last flush err: %v", err) + } + _ = c + return nil +} + +func(s *SshRunner) Stop() error { + return s.lst.Close() +} + +func(s *SshRunner) GetEngine(sessionId string) (engine.Engine, func(), error) { + ctx := s.Ctx + menuStorageService := storage.NewMenuStorageService(s.Conn, s.ResourceDir) + + rs, err := menuStorageService.GetResource(ctx) + if err != nil { + return nil, nil, err + } + + pe, err := menuStorageService.GetPersister(ctx) + if err != nil { + return nil, nil, err + } + + userdatastore, err := menuStorageService.GetUserdataDb(ctx) + if err != nil { + return nil, nil, err + } + + dbResource, ok := rs.(*resource.DbResource) + if !ok { + return nil, nil, err + } + + lhs, err := handlers.NewLocalHandlerService(ctx, s.FlagFile, true, dbResource, s.Cfg, rs) + lhs.SetDataStore(&userdatastore) + lhs.SetPersister(pe) + lhs.Cfg.SessionId = sessionId + + if err != nil { + return nil, nil, err + } + + // TODO: clear up why pointer here and by-value other cmds + accountService := &remote.AccountService{} + hl, err := lhs.GetHandler(accountService) + if err != nil { + return nil, nil, err + } + + en := lhs.GetEngine() + en = en.WithFirst(hl.Init) + if s.Debug { + en = en.WithDebug(nil) + } + // TODO: this is getting very hacky! + closer := func() { + err := menuStorageService.Close() + if err != nil { + logg.ErrorCtxf(ctx, "menu storage service cleanup fail", "err", err) + } + } + return en, closer, nil +} + +// adapted example from crypto/ssh package, NewServerConn doc +func(s *SshRunner) Run(ctx context.Context, keyStore *SshKeyStore) { + s.Ctx = ctx + running := true + + // TODO: waitgroup should probably not be global + defer s.wg.Wait() + + auth := NewAuther(ctx, keyStore) + cfg := ssh.ServerConfig{ + PublicKeyCallback: auth.Check, + } + + privateBytes, err := os.ReadFile(s.SrvKeyFile) + if err != nil { + logg.ErrorCtxf(ctx, "Failed to load private key", "err", err) + } + private, err := ssh.ParsePrivateKey(privateBytes) + if err != nil { + logg.ErrorCtxf(ctx, "Failed to parse private key", "err", err) + } + srvPub := private.PublicKey() + srvPubStr := base64.StdEncoding.EncodeToString(srvPub.Marshal()) + logg.InfoCtxf(ctx, "have server key", "type", srvPub.Type(), "public", srvPubStr) + cfg.AddHostKey(private) + + s.lst, err = net.Listen("tcp", fmt.Sprintf("%s:%d", s.Host, s.Port)) + if err != nil { + panic(err) + } + + for running { + conn, err := s.lst.Accept() + if err != nil { + logg.ErrorCtxf(ctx, "ssh accept error", "err", err) + running = false + continue + } + + go func(conn net.Conn) { + defer conn.Close() + for true { + srvConn, nC, rC, err := ssh.NewServerConn(conn, &cfg) + if err != nil { + logg.InfoCtxf(ctx, "rejected client", "err", err) + return + } + logg.DebugCtxf(ctx, "ssh client connected", "conn", srvConn) + + s.wg.Add(1) + go func() { + ssh.DiscardRequests(rC) + s.wg.Done() + }() + + sessionId, err := auth.FromConn(srvConn) + if err != nil { + logg.ErrorCtxf(ctx, "Cannot find authentication") + return + } + en, closer, err := s.GetEngine(sessionId) + if err != nil { + logg.ErrorCtxf(ctx, "engine won't start", "err", err) + return + } + defer func() { + err := en.Finish() + if err != nil { + logg.ErrorCtxf(ctx, "engine won't stop", "err", err) + } + closer() + }() + for ch := range nC { + err = s.serve(ctx, sessionId, ch, en) + logg.ErrorCtxf(ctx, "ssh server finish", "err", err) + } + } + }(conn) + } +} diff --git a/internal/storage/parse.go b/internal/storage/parse.go new file mode 100644 index 0000000..41dac6b --- /dev/null +++ b/internal/storage/parse.go @@ -0,0 +1,86 @@ +package storage + +import ( + "fmt" + "net/url" + "path" +) + +const ( + DBTYPE_MEM = iota + DBTYPE_GDBM + DBTYPE_POSTGRES +) + +type ConnData struct { + typ int + str string + domain string +} + +func (cd *ConnData) DbType() int { + return cd.typ +} + +func (cd *ConnData) String() string { + return cd.str +} + +func (cd *ConnData) Domain() string { + return cd.domain +} + +func (cd *ConnData) Path() string { + v, _ := url.Parse(cd.str) + v.RawQuery = "" + return v.String() +} + +func probePostgres(s string) (string, string, bool) { + domain := "public" + v, err := url.Parse(s) + if err != nil { + return "", "", false + } + if v.Scheme != "postgres" { + return "", "", false + } + vv := v.Query() + if vv.Has("search_path") { + domain = vv.Get("search_path") + } + return s, domain, true +} + +func probeGdbm(s string) (string, string, bool) { + if !path.IsAbs(s) { + return "", "", false + } + s = path.Clean(s) + return s, "", true +} + +func ToConnData(connStr string) (ConnData, error) { + var o ConnData + + if connStr == "" { + return o, nil + } + + v, domain, ok := probePostgres(connStr) + if ok { + o.typ = DBTYPE_POSTGRES + o.str = v + o.domain = domain + return o, nil + } + + v, _, ok = probeGdbm(connStr) + if ok { + o.typ = DBTYPE_GDBM + o.str = v + return o, nil + } + + return o, fmt.Errorf("invalid connection string: %s", connStr) +} diff --git a/internal/storage/parse_test.go b/internal/storage/parse_test.go new file mode 100644 index 0000000..e18e57c --- /dev/null +++ b/internal/storage/parse_test.go @@ -0,0 +1,28 @@ +package storage + +import ( + "testing" +) + +func TestParseConnStr(t *testing.T) { + _, err := ToConnData("postgres://foo:bar@localhost:5432/baz") + if err != nil { + t.Fatal(err) + } + _, err = ToConnData("/foo/bar") + if err != nil { + t.Fatal(err) + } + _, err = ToConnData("/foo/bar/") + if err != nil { + t.Fatal(err) + } + _, err = ToConnData("foo/bar") + if err == nil { + t.Fatalf("expected error") + } + _, err = ToConnData("http://foo/bar") + if err == nil { + t.Fatalf("expected error") + } +} diff --git a/internal/storage/storage.go b/internal/storage/storage.go index 53f4392..231a1db 100644 --- a/internal/storage/storage.go +++ b/internal/storage/storage.go @@ -5,6 +5,10 @@ import ( "git.defalsify.org/vise.git/persist" ) +const ( + DATATYPE_EXTEND = 128 +) + type Storage struct { Persister *persist.Persister UserdataDb db.Db diff --git a/internal/storage/storageservice.go b/internal/storage/storage_service.go similarity index 55% rename from internal/storage/storageservice.go rename to internal/storage/storage_service.go index 04e75ce..374af74 100644 --- a/internal/storage/storageservice.go +++ b/internal/storage/storage_service.go @@ -9,11 +9,12 @@ import ( "git.defalsify.org/vise.git/db" fsdb "git.defalsify.org/vise.git/db/fs" "git.defalsify.org/vise.git/db/postgres" + "git.defalsify.org/vise.git/lang" "git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/persist" "git.defalsify.org/vise.git/resource" - "git.grassecon.net/urdt/ussd/initializers" gdbmstorage "git.grassecon.net/urdt/ussd/internal/storage/db/gdbm" + "github.com/jackc/pgx/v5/pgxpool" ) var ( @@ -24,63 +25,54 @@ type StorageService interface { GetPersister(ctx context.Context) (*persist.Persister, error) GetUserdataDb(ctx context.Context) db.Db GetResource(ctx context.Context) (resource.Resource, error) - EnsureDbDir() error } type MenuStorageService struct { - dbDir string + conn ConnData resourceDir string + poResource resource.Resource resourceStore db.Db stateStore db.Db userDataStore db.Db } -func buildConnStr() string { - host := initializers.GetEnv("DB_HOST", "localhost") - user := initializers.GetEnv("DB_USER", "postgres") - password := initializers.GetEnv("DB_PASSWORD", "") - dbName := initializers.GetEnv("DB_NAME", "") - port := initializers.GetEnv("DB_PORT", "5432") - - connString := fmt.Sprintf( - "postgres://%s:%s@%s:%s/%s", - user, password, host, port, dbName, - ) - logg.Debugf("pg conn string", "conn", connString) - - return connString -} - -func NewMenuStorageService(dbDir string, resourceDir string) *MenuStorageService { +func NewMenuStorageService(conn ConnData, resourceDir string) *MenuStorageService { return &MenuStorageService{ - dbDir: dbDir, + conn: conn, resourceDir: resourceDir, } } -func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.Db, fileName string) (db.Db, error) { - database, ok := ctx.Value("Database").(string) - if !ok { - return nil, fmt.Errorf("failed to select the database") - } +func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.Db, section string) (db.Db, error) { + var newDb db.Db + var err error if existingDb != nil { return existingDb, nil } - var newDb db.Db - var err error - if database == "postgres" { - newDb = postgres.NewPgDb() - connStr := buildConnStr() - err = newDb.Connect(ctx, connStr) - } else { + connStr := ms.conn.String() + dbTyp := ms.conn.DbType() + if dbTyp == DBTYPE_POSTGRES { + // TODO: move to vise + err = ensureSchemaExists(ctx, ms.conn) + if err != nil { + return nil, err + } + newDb = postgres.NewPgDb().WithSchema(ms.conn.Domain()) + } else if dbTyp == DBTYPE_GDBM { + err = ms.ensureDbDir() + if err != nil { + return nil, err + } + connStr = path.Join(connStr, section) newDb = gdbmstorage.NewThreadGdbmDb() - storeFile := path.Join(ms.dbDir, fileName) - err = newDb.Connect(ctx, storeFile) + } else { + return nil, fmt.Errorf("unsupported connection string: '%s'\n", ms.conn.String()) } - + logg.DebugCtxf(ctx, "connecting to db", "conn", connStr, "conndata", ms.conn) + err = newDb.Connect(ctx, connStr) if err != nil { return nil, err } @@ -88,6 +80,45 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D return newDb, nil } +// WithGettext triggers use of gettext for translation of templates and menus. +// +// The first language in `lns` will be used as default language, to resolve node keys to +// language strings. +// +// If `lns` is an empty array, gettext will not be used. +func (ms *MenuStorageService) WithGettext(path string, lns []lang.Language) *MenuStorageService { + if len(lns) == 0 { + logg.Warnf("Gettext requested but no languages supplied") + return ms + } + rs := resource.NewPoResource(lns[0], path) + + for _, ln := range(lns) { + rs = rs.WithLanguage(ln) + } + + ms.poResource = rs + + return ms +} + +// ensureSchemaExists creates a new schema if it does not exist +func ensureSchemaExists(ctx context.Context, conn ConnData) error { + h, err := pgxpool.New(ctx, conn.Path()) + if err != nil { + return fmt.Errorf("failed to connect to the database: %w", err) + } + defer h.Close() + + query := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", conn.Domain()) + _, err = h.Exec(ctx, query) + if err != nil { + return fmt.Errorf("failed to create schema: %w", err) + } + + return nil +} + func (ms *MenuStorageService) GetPersister(ctx context.Context) (*persist.Persister, error) { stateStore, err := ms.GetStateStore(ctx) if err != nil { @@ -120,6 +151,11 @@ func (ms *MenuStorageService) GetResource(ctx context.Context) (resource.Resourc return nil, err } rfs := resource.NewDbResource(ms.resourceStore) + if ms.poResource != nil { + logg.InfoCtxf(ctx, "using poresource for menu and template") + rfs.WithMenuGetter(ms.poResource.GetMenu) + rfs.WithTemplateGetter(ms.poResource.GetTemplate) + } return rfs, nil } @@ -137,8 +173,8 @@ func (ms *MenuStorageService) GetStateStore(ctx context.Context) (db.Db, error) return ms.stateStore, nil } -func (ms *MenuStorageService) EnsureDbDir() error { - err := os.MkdirAll(ms.dbDir, 0700) +func (ms *MenuStorageService) ensureDbDir() error { + err := os.MkdirAll(ms.conn.String(), 0700) if err != nil { return fmt.Errorf("state dir create exited with error: %v\n", err) } diff --git a/internal/testutil/TestEngine.go b/internal/testutil/TestEngine.go deleted file mode 100644 index 3fcb307..0000000 --- a/internal/testutil/TestEngine.go +++ /dev/null @@ -1,124 +0,0 @@ -package testutil - -import ( - "context" - "fmt" - "os" - "path" - "time" - - "git.defalsify.org/vise.git/engine" - "git.defalsify.org/vise.git/logging" - "git.defalsify.org/vise.git/resource" - "git.grassecon.net/urdt/ussd/internal/handlers" - "git.grassecon.net/urdt/ussd/internal/storage" - "git.grassecon.net/urdt/ussd/internal/testutil/testservice" - "git.grassecon.net/urdt/ussd/internal/testutil/testtag" - testdataloader "github.com/peteole/testdata-loader" - "git.grassecon.net/urdt/ussd/remote" -) - -var ( - baseDir = testdataloader.GetBasePath() - logg = logging.NewVanilla() - scriptDir = path.Join(baseDir, "services", "registration") -) - -func TestEngine(sessionId string) (engine.Engine, func(), chan bool) { - ctx := context.Background() - ctx = context.WithValue(ctx, "SessionId", sessionId) - ctx = context.WithValue(ctx, "Database", "gdbm") - pfp := path.Join(scriptDir, "pp.csv") - - var eventChannel = make(chan bool) - - cfg := engine.Config{ - Root: "root", - SessionId: sessionId, - OutputSize: uint32(160), - FlagCount: uint32(128), - } - - dbDir := ".test_state" - resourceDir := scriptDir - menuStorageService := storage.NewMenuStorageService(dbDir, resourceDir) - - err := menuStorageService.EnsureDbDir() - if err != nil { - fmt.Fprintf(os.Stderr, err.Error()) - os.Exit(1) - } - - rs, err := menuStorageService.GetResource(ctx) - if err != nil { - fmt.Fprintf(os.Stderr, err.Error()) - os.Exit(1) - } - - pe, err := menuStorageService.GetPersister(ctx) - if err != nil { - fmt.Fprintf(os.Stderr, err.Error()) - os.Exit(1) - } - - userDataStore, err := menuStorageService.GetUserdataDb(ctx) - if err != nil { - fmt.Fprintf(os.Stderr, err.Error()) - os.Exit(1) - } - - dbResource, ok := rs.(*resource.DbResource) - if !ok { - fmt.Fprintf(os.Stderr, err.Error()) - os.Exit(1) - } - - lhs, err := handlers.NewLocalHandlerService(ctx, pfp, true, dbResource, cfg, rs) - lhs.SetDataStore(&userDataStore) - lhs.SetPersister(pe) - - if err != nil { - fmt.Fprintf(os.Stderr, err.Error()) - os.Exit(1) - } - - if testtag.AccountService == nil { - testtag.AccountService = &remote.AccountService{} - } - - switch testtag.AccountService.(type) { - case *testservice.TestAccountService: - go func() { - eventChannel <- false - }() - case *remote.AccountService: - go func() { - time.Sleep(5 * time.Second) // Wait for 5 seconds - eventChannel <- true - }() - default: - panic("Unknown account service type") - } - - hl, err := lhs.GetHandler(testtag.AccountService) - if err != nil { - fmt.Fprintf(os.Stderr, err.Error()) - os.Exit(1) - } - - en := lhs.GetEngine() - en = en.WithFirst(hl.Init) - cleanFn := func() { - err := en.Finish() - if err != nil { - logg.Errorf(err.Error()) - } - - err = menuStorageService.Close() - if err != nil { - logg.Errorf(err.Error()) - } - logg.Infof("testengine storage closed") - } - return en, cleanFn, eventChannel -} diff --git a/internal/testutil/engine.go b/internal/testutil/engine.go new file mode 100644 index 0000000..5d581ba --- /dev/null +++ b/internal/testutil/engine.go @@ -0,0 +1,209 @@ +package testutil + +import ( + "context" + "fmt" + "log" + "net/url" + "os" + "path" + "path/filepath" + "time" + + "git.defalsify.org/vise.git/engine" + "git.defalsify.org/vise.git/logging" + "git.defalsify.org/vise.git/resource" + "git.grassecon.net/urdt/ussd/config" + "git.grassecon.net/urdt/ussd/initializers" + "git.grassecon.net/urdt/ussd/internal/handlers" + "git.grassecon.net/urdt/ussd/internal/storage" + "git.grassecon.net/urdt/ussd/internal/testutil/testservice" + "git.grassecon.net/urdt/ussd/internal/testutil/testtag" + "git.grassecon.net/urdt/ussd/remote" + "github.com/jackc/pgx/v5/pgxpool" + testdataloader "github.com/peteole/testdata-loader" +) + +var ( + logg = logging.NewVanilla() + baseDir = testdataloader.GetBasePath() + scriptDir = path.Join(baseDir, "services", "registration") + setDbType string + setConnStr string + setDbSchema string +) + +func init() { + initializers.LoadEnvVariablesPath(baseDir) + config.LoadConfig() +} + +// SetDatabase updates the database used by TestEngine +func SetDatabase(database, connStr, dbSchema string) { + setDbType = database + setConnStr = connStr + setDbSchema = dbSchema +} + +// CleanDatabase removes all test data from the database +func CleanDatabase() { + if setDbType == "postgres" { + ctx := context.Background() + // Update the connection string with the new search path + updatedConnStr, err := updateSearchPath(setConnStr, setDbSchema) + if err != nil { + log.Fatalf("Failed to update search path: %v", err) + } + + dbConn, err := pgxpool.New(ctx, updatedConnStr) + if err != nil { + log.Fatalf("Failed to connect to database for cleanup: %v", err) + } + defer dbConn.Close() + + query := fmt.Sprintf("DELETE FROM %s.kv_vise;", setDbSchema) + _, execErr := dbConn.Exec(ctx, query) + if execErr != nil { + log.Printf("Failed to cleanup table %s.kv_vise: %v", setDbSchema, execErr) + } else { + log.Printf("Successfully cleaned up table %s.kv_vise", setDbSchema) + } + } else { + setConnStr, _ := filepath.Abs(setConnStr) + if err := os.RemoveAll(setConnStr); err != nil { + log.Fatalf("Failed to delete state store %s: %v", setConnStr, err) + } + } +} + +// updateSearchPath updates the search_path (schema) to be used in the connection +func updateSearchPath(connStr string, newSearchPath string) (string, error) { + u, err := url.Parse(connStr) + if err != nil { + return "", fmt.Errorf("invalid connection string: %w", err) + } + + // Parse the query parameters + q := u.Query() + + // Update or add the search_path parameter + q.Set("search_path", newSearchPath) + + // Rebuild the connection string with updated parameters + u.RawQuery = q.Encode() + + return u.String(), nil +} + +func TestEngine(sessionId string) (engine.Engine, func(), chan bool) { + var err error + ctx := context.Background() + ctx = context.WithValue(ctx, "SessionId", sessionId) + pfp := path.Join(scriptDir, "pp.csv") + + var eventChannel = make(chan bool) + + cfg := engine.Config{ + Root: "root", + SessionId: sessionId, + OutputSize: uint32(160), + FlagCount: uint32(128), + } + + if setDbType == "postgres" { + setConnStr = config.DbConn + setConnStr, err = updateSearchPath(setConnStr, setDbSchema) + if err != nil { + fmt.Println("Error:", err) + os.Exit(1) + } + } else { + setConnStr, err = filepath.Abs(setConnStr) + if err != nil { + fmt.Fprintf(os.Stderr, "connstr err: %v", err) + os.Exit(1) + } + } + + conn, err := storage.ToConnData(setConnStr) + if err != nil { + fmt.Fprintf(os.Stderr, "connstr parse err: %v", err) + os.Exit(1) + } + + resourceDir := scriptDir + menuStorageService := storage.NewMenuStorageService(conn, resourceDir) + + rs, err := menuStorageService.GetResource(ctx) + if err != nil { + fmt.Fprintf(os.Stderr, "resource error: %v", err) + os.Exit(1) + } + + pe, err := menuStorageService.GetPersister(ctx) + if err != nil { + fmt.Fprintf(os.Stderr, "persister error: %v", err) + os.Exit(1) + } + + userDataStore, err := menuStorageService.GetUserdataDb(ctx) + if err != nil { + fmt.Fprintf(os.Stderr, "userdb error: %v", err) + os.Exit(1) + } + + dbResource, ok := rs.(*resource.DbResource) + if !ok { + fmt.Fprintf(os.Stderr, "dbresource cast error") + os.Exit(1) + } + + lhs, err := handlers.NewLocalHandlerService(ctx, pfp, true, dbResource, cfg, rs) + lhs.SetDataStore(&userDataStore) + lhs.SetPersister(pe) + + if err != nil { + fmt.Fprintf(os.Stderr, err.Error()) + os.Exit(1) + } + + if testtag.AccountService == nil { + testtag.AccountService = &remote.AccountService{} + } + + switch testtag.AccountService.(type) { + case *testservice.TestAccountService: + go func() { + eventChannel <- false + }() + case *remote.AccountService: + go func() { + time.Sleep(5 * time.Second) // Wait for 5 seconds + eventChannel <- true + }() + default: + panic("Unknown account service type") + } + + hl, err := lhs.GetHandler(testtag.AccountService) + if err != nil { + fmt.Fprintf(os.Stderr, err.Error()) + os.Exit(1) + } + + en := lhs.GetEngine() + en = en.WithFirst(hl.Init) + cleanFn := func() { + err := en.Finish() + if err != nil { + logg.Errorf(err.Error()) + } + + err = menuStorageService.Close() + if err != nil { + logg.Errorf(err.Error()) + } + logg.Infof("testengine storage closed") + } + return en, cleanFn, eventChannel +} diff --git a/internal/testutil/engine_test.go b/internal/testutil/engine_test.go new file mode 100644 index 0000000..f747468 --- /dev/null +++ b/internal/testutil/engine_test.go @@ -0,0 +1,15 @@ +package testutil + +import ( + "testing" +) + +func TestCreateEngine(t *testing.T) { + o, clean, eventC := TestEngine("foo") + defer clean() + defer func() { + <-eventC + close(eventC) + }() + _ = o +} diff --git a/internal/testutil/mocks/httpmocks/requesthandlermock.go b/internal/testutil/mocks/httpmocks/request_handler_mock.go similarity index 100% rename from internal/testutil/mocks/httpmocks/requesthandlermock.go rename to internal/testutil/mocks/httpmocks/request_handler_mock.go diff --git a/internal/testutil/mocks/httpmocks/requestparsermock.go b/internal/testutil/mocks/httpmocks/request_parser_mock.go similarity index 76% rename from internal/testutil/mocks/httpmocks/requestparsermock.go rename to internal/testutil/mocks/httpmocks/request_parser_mock.go index 54b16bf..3c19e12 100644 --- a/internal/testutil/mocks/httpmocks/requestparsermock.go +++ b/internal/testutil/mocks/httpmocks/request_parser_mock.go @@ -1,12 +1,14 @@ package httpmocks +import "context" + // MockRequestParser implements the handlers.RequestParser interface for testing type MockRequestParser struct { GetSessionIdFunc func(any) (string, error) GetInputFunc func(any) ([]byte, error) } -func (m *MockRequestParser) GetSessionId(rq any) (string, error) { +func (m *MockRequestParser) GetSessionId(ctx context.Context, rq any) (string, error) { return m.GetSessionIdFunc(rq) } diff --git a/internal/testutil/mocks/httpmocks/writermock.go b/internal/testutil/mocks/httpmocks/writer_mock.go similarity index 100% rename from internal/testutil/mocks/httpmocks/writermock.go rename to internal/testutil/mocks/httpmocks/writer_mock.go diff --git a/internal/testutil/mocks/servicemock.go b/internal/testutil/mocks/service_mock.go similarity index 100% rename from internal/testutil/mocks/servicemock.go rename to internal/testutil/mocks/service_mock.go diff --git a/internal/testutil/testservice/TestAccountService.go b/internal/testutil/testservice/account_service.go similarity index 100% rename from internal/testutil/testservice/TestAccountService.go rename to internal/testutil/testservice/account_service.go diff --git a/menutraversal_test/group_test.json b/menutraversal_test/group_test.json index f35beb9..0ffb49f 100644 --- a/menutraversal_test/group_test.json +++ b/menutraversal_test/group_test.json @@ -54,7 +54,7 @@ }, { "input": "1235", - "expectedContent": "Incorrect PIN\n1:Retry\n9:Quit" + "expectedContent": "Incorrect PIN. You have: 2 remaining attempt(s).\n1:Retry\n9:Quit" }, { "input": "1", @@ -95,7 +95,7 @@ }, { "input": "1235", - "expectedContent": "Incorrect PIN\n1:Retry\n9:Quit" + "expectedContent": "Incorrect PIN. You have: 2 remaining attempt(s).\n1:Retry\n9:Quit" }, { "input": "1", @@ -107,8 +107,7 @@ }, { "input": "0", - "expectedContent": "Balances:\n1:My balance\n2:Community balance\n0:Back" - + "expectedContent": "Balances:\n1:My balance\n2:Community balance\n0:Back" }, { "input": "0", @@ -141,7 +140,7 @@ }, { "input": "1235", - "expectedContent": "Incorrect PIN\n1:Retry\n9:Quit" + "expectedContent": "Incorrect PIN. You have: 2 remaining attempt(s).\n1:Retry\n9:Quit" }, { "input": "1", @@ -153,8 +152,7 @@ }, { "input": "0", - "expectedContent": "Balances:\n1:My balance\n2:Community balance\n0:Back" - + "expectedContent": "Balances:\n1:My balance\n2:Community balance\n0:Back" }, { "input": "0", @@ -195,7 +193,7 @@ }, { "input": "1", - "expectedContent": "Enter your year of birth\n0:Back" + "expectedContent": "Enter your year of birth\n0:Back" }, { "input": "1940", @@ -258,7 +256,6 @@ "input": "0", "expectedContent": "{balance}\n\n1:Send\n2:My Vouchers\n3:My Account\n4:Help\n9:Quit" } - ] }, { @@ -443,10 +440,4 @@ ] } ] -} - - - - - - +} \ No newline at end of file diff --git a/menutraversal_test/menu_traversal_test.go b/menutraversal_test/menu_traversal_test.go index 6b6b3da..4aee26e 100644 --- a/menutraversal_test/menu_traversal_test.go +++ b/menutraversal_test/menu_traversal_test.go @@ -6,7 +6,6 @@ import ( "flag" "log" "math/rand" - "os" "regexp" "testing" @@ -17,13 +16,15 @@ import ( var ( testData = driver.ReadData() - testStore = ".test_state" sessionID string src = rand.NewSource(42) g = rand.New(src) ) var groupTestFile = flag.String("test-file", "group_test.json", "The test file to use for running the group tests") +var database = flag.String("db", "gdbm", "Specify the database (gdbm or postgres)") +var connStr = flag.String("conn", ".test_state", "connection string") +var dbSchema = flag.String("schema", "test", "Specify the database schema (default test)") func GenerateSessionId() string { uu := uuid.NewGenWithOptions(uuid.WithRandomReader(g)) @@ -79,12 +80,15 @@ func extractSendAmount(response []byte) string { } func TestMain(m *testing.M) { + // Parse the flags + flag.Parse() sessionID = GenerateSessionId() - defer func() { - if err := os.RemoveAll(testStore); err != nil { - log.Fatalf("Failed to delete state store %s: %v", testStore, err) - } - }() + // set the db + testutil.SetDatabase(*database, *connStr, *dbSchema) + + // Cleanup the db after tests + defer testutil.CleanDatabase() + m.Run() } @@ -121,7 +125,6 @@ func TestAccountCreationSuccessful(t *testing.T) { } } <-eventChannel - } func TestAccountRegistrationRejectTerms(t *testing.T) { diff --git a/models/accountresponse.go b/models/account_response.go similarity index 100% rename from models/accountresponse.go rename to models/account_response.go diff --git a/models/balanceresponse.go b/models/balance_response.go similarity index 100% rename from models/balanceresponse.go rename to models/balance_response.go diff --git a/models/trackstatusresponse.go b/models/track_status_response.go similarity index 100% rename from models/trackstatusresponse.go rename to models/track_status_response.go diff --git a/remote/accountservice.go b/remote/account_service.go similarity index 100% rename from remote/accountservice.go rename to remote/account_service.go diff --git a/services/registration/blocked_account.vis b/services/registration/blocked_account.vis new file mode 100644 index 0000000..d8adab2 --- /dev/null +++ b/services/registration/blocked_account.vis @@ -0,0 +1,2 @@ +LOAD show_blocked_account 0 +HALT diff --git a/services/registration/incorrect_pin b/services/registration/incorrect_pin index 7fcf610..13a9562 100644 --- a/services/registration/incorrect_pin +++ b/services/registration/incorrect_pin @@ -1 +1 @@ -Incorrect PIN \ No newline at end of file +Incorrect PIN. You have: {{.reset_incorrect}} remaining attempt(s). \ No newline at end of file diff --git a/services/registration/incorrect_pin.vis b/services/registration/incorrect_pin.vis index 844f3d6..167364a 100644 --- a/services/registration/incorrect_pin.vis +++ b/services/registration/incorrect_pin.vis @@ -1,5 +1,7 @@ LOAD reset_incorrect 0 RELOAD reset_incorrect +MAP reset_incorrect +CATCH blocked_account flag_account_blocked 1 MOUT retry 1 MOUT quit 9 HALT diff --git a/services/registration/incorrect_pin_swa b/services/registration/incorrect_pin_swa index 34a0b28..ed22beb 100644 --- a/services/registration/incorrect_pin_swa +++ b/services/registration/incorrect_pin_swa @@ -1 +1 @@ -PIN ulioeka sio sahihi \ No newline at end of file +PIN ulioeka sio sahihi, una majaribio: {{.reset_incorrect}} yaliyobaki \ No newline at end of file diff --git a/services/registration/locale/swa/default.po b/services/registration/locale/swa/default.po index 4bf876b..6155063 100644 --- a/services/registration/locale/swa/default.po +++ b/services/registration/locale/swa/default.po @@ -7,8 +7,11 @@ msgstr "Ombi lako limetumwa. %s atapokea %s %s kutoka kwa %s." msgid "Thank you for using Sarafu. Goodbye!" msgstr "Asante kwa kutumia huduma ya Sarafu. Kwaheri!" -msgid "For more help,please call: 0757628885" -msgstr "Kwa usaidizi zaidi,piga: 0757628885" +msgid "For more help, please call: 0757628885" +msgstr "Kwa usaidizi zaidi, piga: 0757628885" + +msgid "Your account has been locked. For help on how to unblock your account, contact support at: 0757628885" +msgstr "Akaunti yako imefungwa. Kwa usaidizi wa jinsi ya kufungua akaunti yako, wasiliana na usaidizi kwa: 0757628885" msgid "Balance: %s\n" msgstr "Salio: %s\n" diff --git a/services/registration/pp.csv b/services/registration/pp.csv index 26a8833..aa1eb05 100644 --- a/services/registration/pp.csv +++ b/services/registration/pp.csv @@ -28,3 +28,5 @@ flag,flag_gender_set,34,this is set when the gender of the profile is set flag,flag_location_set,35,this is set when the location of the profile is set flag,flag_offerings_set,36,this is set when the offerings of the profile is set flag,flag_back_set,37,this is set when it is a back navigation +flag,flag_account_blocked,38,this is set when an account has been blocked after the allowed incorrect PIN attempts have been exceeded + diff --git a/services/registration/root.vis b/services/registration/root.vis index 02ef9e9..102e6e5 100644 --- a/services/registration/root.vis +++ b/services/registration/root.vis @@ -1,3 +1,4 @@ +CATCH blocked_account flag_account_blocked 1 CATCH select_language flag_language_set 0 CATCH terms flag_account_created 0 LOAD check_account_status 0