Compare commits
2 Commits
lash/ssh
...
wip-unit-t
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2cd0b4434f
|
||
|
|
c735bf688d
|
@@ -12,23 +12,27 @@ import (
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"git.defalsify.org/vise.git/asm"
|
||||
"git.defalsify.org/vise.git/db"
|
||||
fsdb "git.defalsify.org/vise.git/db/fs"
|
||||
gdbmdb "git.defalsify.org/vise.git/db/gdbm"
|
||||
"git.defalsify.org/vise.git/engine"
|
||||
"git.defalsify.org/vise.git/logging"
|
||||
"git.defalsify.org/vise.git/resource"
|
||||
|
||||
"git.grassecon.net/urdt/ussd/internal/handlers"
|
||||
"git.grassecon.net/urdt/ussd/internal/handlers/ussd"
|
||||
httpserver "git.grassecon.net/urdt/ussd/internal/http"
|
||||
"git.grassecon.net/urdt/ussd/internal/storage"
|
||||
)
|
||||
|
||||
var (
|
||||
logg = logging.NewVanilla()
|
||||
logg = logging.NewVanilla()
|
||||
scriptDir = path.Join("services", "registration")
|
||||
)
|
||||
|
||||
type atRequestParser struct{}
|
||||
type atRequestParser struct {}
|
||||
|
||||
func (arp *atRequestParser) GetSessionId(rq any) (string, error) {
|
||||
func(arp *atRequestParser) GetSessionId(rq any) (string, error) {
|
||||
rqv, ok := rq.(*http.Request)
|
||||
if !ok {
|
||||
return "", handlers.ErrInvalidRequest
|
||||
@@ -45,7 +49,7 @@ func (arp *atRequestParser) GetSessionId(rq any) (string, error) {
|
||||
return phoneNumber, nil
|
||||
}
|
||||
|
||||
func (arp *atRequestParser) GetInput(rq any) ([]byte, error) {
|
||||
func(arp *atRequestParser) GetInput(rq any) ([]byte, error) {
|
||||
rqv, ok := rq.(*http.Request)
|
||||
if !ok {
|
||||
return nil, handlers.ErrInvalidRequest
|
||||
@@ -64,6 +68,96 @@ func (arp *atRequestParser) GetInput(rq any) ([]byte, error) {
|
||||
return []byte(parts[len(parts)-1]), nil
|
||||
}
|
||||
|
||||
|
||||
func getFlags(fp string, debug bool) (*asm.FlagParser, error) {
|
||||
flagParser := asm.NewFlagParser().WithDebug()
|
||||
_, err := flagParser.Load(fp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return flagParser, nil
|
||||
}
|
||||
|
||||
func getHandler(appFlags *asm.FlagParser, rs *resource.DbResource, userdataStore db.Db) (*ussd.Handlers, error) {
|
||||
|
||||
ussdHandlers, err := ussd.NewHandlers(appFlags, userdataStore)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rs.AddLocalFunc("set_language", ussdHandlers.SetLanguage)
|
||||
rs.AddLocalFunc("create_account", ussdHandlers.CreateAccount)
|
||||
rs.AddLocalFunc("save_pin", ussdHandlers.SavePin)
|
||||
rs.AddLocalFunc("verify_pin", ussdHandlers.VerifyPin)
|
||||
rs.AddLocalFunc("check_identifier", ussdHandlers.CheckIdentifier)
|
||||
rs.AddLocalFunc("check_account_status", ussdHandlers.CheckAccountStatus)
|
||||
rs.AddLocalFunc("authorize_account", ussdHandlers.Authorize)
|
||||
rs.AddLocalFunc("quit", ussdHandlers.Quit)
|
||||
rs.AddLocalFunc("check_balance", ussdHandlers.CheckBalance)
|
||||
rs.AddLocalFunc("validate_recipient", ussdHandlers.ValidateRecipient)
|
||||
rs.AddLocalFunc("transaction_reset", ussdHandlers.TransactionReset)
|
||||
rs.AddLocalFunc("max_amount", ussdHandlers.MaxAmount)
|
||||
rs.AddLocalFunc("validate_amount", ussdHandlers.ValidateAmount)
|
||||
rs.AddLocalFunc("reset_transaction_amount", ussdHandlers.ResetTransactionAmount)
|
||||
rs.AddLocalFunc("get_recipient", ussdHandlers.GetRecipient)
|
||||
rs.AddLocalFunc("get_sender", ussdHandlers.GetSender)
|
||||
rs.AddLocalFunc("get_amount", ussdHandlers.GetAmount)
|
||||
rs.AddLocalFunc("reset_incorrect", ussdHandlers.ResetIncorrectPin)
|
||||
rs.AddLocalFunc("save_firstname", ussdHandlers.SaveFirstname)
|
||||
rs.AddLocalFunc("save_familyname", ussdHandlers.SaveFamilyname)
|
||||
rs.AddLocalFunc("save_gender", ussdHandlers.SaveGender)
|
||||
rs.AddLocalFunc("save_location", ussdHandlers.SaveLocation)
|
||||
rs.AddLocalFunc("save_yob", ussdHandlers.SaveYob)
|
||||
rs.AddLocalFunc("save_offerings", ussdHandlers.SaveOfferings)
|
||||
rs.AddLocalFunc("quit_with_balance", ussdHandlers.QuitWithBalance)
|
||||
rs.AddLocalFunc("reset_account_authorized", ussdHandlers.ResetAccountAuthorized)
|
||||
rs.AddLocalFunc("reset_allow_update", ussdHandlers.ResetAllowUpdate)
|
||||
rs.AddLocalFunc("get_profile_info", ussdHandlers.GetProfileInfo)
|
||||
rs.AddLocalFunc("verify_yob", ussdHandlers.VerifyYob)
|
||||
rs.AddLocalFunc("reset_incorrect_date_format", ussdHandlers.ResetIncorrectYob)
|
||||
rs.AddLocalFunc("set_reset_single_edit", ussdHandlers.SetResetSingleEdit)
|
||||
rs.AddLocalFunc("initiate_transaction", ussdHandlers.InitiateTransaction)
|
||||
rs.AddLocalFunc("save_temporary_pin", ussdHandlers.SaveTemporaryPin)
|
||||
rs.AddLocalFunc("verify_new_pin", ussdHandlers.VerifyNewPin)
|
||||
rs.AddLocalFunc("confirm_pin_change", ussdHandlers.ConfirmPinChange)
|
||||
rs.AddLocalFunc("quit_with_help",ussdHandlers.QuitWithHelp)
|
||||
|
||||
return ussdHandlers, nil
|
||||
}
|
||||
|
||||
func ensureDbDir(dbDir string) error {
|
||||
err := os.MkdirAll(dbDir, 0700)
|
||||
if err != nil {
|
||||
return fmt.Errorf("state dir create exited with error: %v\n", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getStateStore(dbDir string, ctx context.Context) (db.Db, error) {
|
||||
store := gdbmdb.NewGdbmDb()
|
||||
storeFile := path.Join(dbDir, "state.gdbm")
|
||||
store.Connect(ctx, storeFile)
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func getUserdataDb(dbDir string, ctx context.Context) db.Db {
|
||||
store := gdbmdb.NewGdbmDb()
|
||||
storeFile := path.Join(dbDir, "userdata.gdbm")
|
||||
store.Connect(ctx, storeFile)
|
||||
|
||||
return store
|
||||
}
|
||||
|
||||
func getResource(resourceDir string, ctx context.Context) (resource.Resource, error) {
|
||||
store := fsdb.NewFsDb()
|
||||
err := store.Connect(ctx, resourceDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rfs := resource.NewDbResource(store)
|
||||
return rfs, nil
|
||||
}
|
||||
|
||||
|
||||
func main() {
|
||||
var dbDir string
|
||||
var resourceDir string
|
||||
@@ -81,10 +175,15 @@ func main() {
|
||||
flag.UintVar(&port, "p", 7123, "http port")
|
||||
flag.Parse()
|
||||
|
||||
logg.Infof("start command", "dbdir", dbDir, "resourcedir", resourceDir, "outputsize", size)
|
||||
logg.Infof("start command", "dbdir", dbDir, "resourcedir", resourceDir, "outputsize", size)
|
||||
|
||||
ctx := context.Background()
|
||||
pfp := path.Join(scriptDir, "pp.csv")
|
||||
flagParser, err := getFlags(pfp, true)
|
||||
|
||||
if err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
cfg := engine.Config{
|
||||
Root: "root",
|
||||
@@ -98,20 +197,19 @@ func main() {
|
||||
cfg.EngineDebug = true
|
||||
}
|
||||
|
||||
menuStorageService := storage.NewMenuStorageService(dbDir, resourceDir)
|
||||
rs, err := menuStorageService.GetResource(ctx)
|
||||
rs, err := getResource(resourceDir, ctx)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
err = menuStorageService.EnsureDbDir()
|
||||
err = ensureDbDir(dbDir)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
userdataStore, err := menuStorageService.GetUserdataDb(ctx)
|
||||
userdataStore := getUserdataDb(dbDir, ctx)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, err.Error())
|
||||
os.Exit(1)
|
||||
@@ -123,21 +221,13 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
lhs, err := handlers.NewLocalHandlerService(pfp, true, dbResource, cfg, rs)
|
||||
lhs.SetDataStore(&userdataStore)
|
||||
|
||||
hl, err := getHandler(flagParser, dbResource, userdataStore)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
hl, err := lhs.GetHandler()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
stateStore, err := menuStorageService.GetStateStore(ctx)
|
||||
stateStore, err := getStateStore(dbDir, ctx)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, err.Error())
|
||||
os.Exit(1)
|
||||
@@ -148,7 +238,7 @@ func main() {
|
||||
bsh := handlers.NewBaseSessionHandler(cfg, rs, stateStore, userdataStore, rp, hl)
|
||||
sh := httpserver.NewATSessionHandler(bsh)
|
||||
s := &http.Server{
|
||||
Addr: fmt.Sprintf("%s:%s", host, strconv.Itoa(int(port))),
|
||||
Addr: fmt.Sprintf("%s:%s", host, strconv.Itoa(int(port))),
|
||||
Handler: sh,
|
||||
}
|
||||
s.RegisterOnShutdown(sh.Shutdown)
|
||||
|
||||
@@ -9,32 +9,125 @@ import (
|
||||
"path"
|
||||
"syscall"
|
||||
|
||||
"git.defalsify.org/vise.git/asm"
|
||||
"git.defalsify.org/vise.git/db"
|
||||
fsdb "git.defalsify.org/vise.git/db/fs"
|
||||
gdbmdb "git.defalsify.org/vise.git/db/gdbm"
|
||||
"git.defalsify.org/vise.git/engine"
|
||||
"git.defalsify.org/vise.git/logging"
|
||||
"git.defalsify.org/vise.git/resource"
|
||||
"git.defalsify.org/vise.git/logging"
|
||||
|
||||
"git.grassecon.net/urdt/ussd/internal/handlers/ussd"
|
||||
"git.grassecon.net/urdt/ussd/internal/handlers"
|
||||
"git.grassecon.net/urdt/ussd/internal/storage"
|
||||
)
|
||||
|
||||
var (
|
||||
logg = logging.NewVanilla()
|
||||
logg = logging.NewVanilla()
|
||||
scriptDir = path.Join("services", "registration")
|
||||
)
|
||||
|
||||
type asyncRequestParser struct {
|
||||
sessionId string
|
||||
input []byte
|
||||
input []byte
|
||||
}
|
||||
|
||||
func (p *asyncRequestParser) GetSessionId(r any) (string, error) {
|
||||
func(p *asyncRequestParser) GetSessionId(r any) (string, error) {
|
||||
return p.sessionId, nil
|
||||
}
|
||||
|
||||
func (p *asyncRequestParser) GetInput(r any) ([]byte, error) {
|
||||
func(p *asyncRequestParser) GetInput(r any) ([]byte, error) {
|
||||
return p.input, nil
|
||||
}
|
||||
|
||||
func getFlags(fp string, debug bool) (*asm.FlagParser, error) {
|
||||
flagParser := asm.NewFlagParser().WithDebug()
|
||||
_, err := flagParser.Load(fp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return flagParser, nil
|
||||
}
|
||||
|
||||
func getHandler(appFlags *asm.FlagParser, rs *resource.DbResource, userdataStore db.Db) (*ussd.Handlers, error) {
|
||||
|
||||
ussdHandlers, err := ussd.NewHandlers(appFlags, userdataStore)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rs.AddLocalFunc("set_language", ussdHandlers.SetLanguage)
|
||||
rs.AddLocalFunc("create_account", ussdHandlers.CreateAccount)
|
||||
rs.AddLocalFunc("save_pin", ussdHandlers.SavePin)
|
||||
rs.AddLocalFunc("verify_pin", ussdHandlers.VerifyPin)
|
||||
rs.AddLocalFunc("check_identifier", ussdHandlers.CheckIdentifier)
|
||||
rs.AddLocalFunc("check_account_status", ussdHandlers.CheckAccountStatus)
|
||||
rs.AddLocalFunc("authorize_account", ussdHandlers.Authorize)
|
||||
rs.AddLocalFunc("quit", ussdHandlers.Quit)
|
||||
rs.AddLocalFunc("check_balance", ussdHandlers.CheckBalance)
|
||||
rs.AddLocalFunc("validate_recipient", ussdHandlers.ValidateRecipient)
|
||||
rs.AddLocalFunc("transaction_reset", ussdHandlers.TransactionReset)
|
||||
rs.AddLocalFunc("max_amount", ussdHandlers.MaxAmount)
|
||||
rs.AddLocalFunc("validate_amount", ussdHandlers.ValidateAmount)
|
||||
rs.AddLocalFunc("reset_transaction_amount", ussdHandlers.ResetTransactionAmount)
|
||||
rs.AddLocalFunc("get_recipient", ussdHandlers.GetRecipient)
|
||||
rs.AddLocalFunc("get_sender", ussdHandlers.GetSender)
|
||||
rs.AddLocalFunc("get_amount", ussdHandlers.GetAmount)
|
||||
rs.AddLocalFunc("reset_incorrect", ussdHandlers.ResetIncorrectPin)
|
||||
rs.AddLocalFunc("save_firstname", ussdHandlers.SaveFirstname)
|
||||
rs.AddLocalFunc("save_familyname", ussdHandlers.SaveFamilyname)
|
||||
rs.AddLocalFunc("save_gender", ussdHandlers.SaveGender)
|
||||
rs.AddLocalFunc("save_location", ussdHandlers.SaveLocation)
|
||||
rs.AddLocalFunc("save_yob", ussdHandlers.SaveYob)
|
||||
rs.AddLocalFunc("save_offerings", ussdHandlers.SaveOfferings)
|
||||
rs.AddLocalFunc("quit_with_balance", ussdHandlers.QuitWithBalance)
|
||||
rs.AddLocalFunc("reset_account_authorized", ussdHandlers.ResetAccountAuthorized)
|
||||
rs.AddLocalFunc("reset_allow_update", ussdHandlers.ResetAllowUpdate)
|
||||
rs.AddLocalFunc("get_profile_info", ussdHandlers.GetProfileInfo)
|
||||
rs.AddLocalFunc("verify_yob", ussdHandlers.VerifyYob)
|
||||
rs.AddLocalFunc("reset_incorrect_date_format", ussdHandlers.ResetIncorrectYob)
|
||||
rs.AddLocalFunc("set_reset_single_edit", ussdHandlers.SetResetSingleEdit)
|
||||
rs.AddLocalFunc("initiate_transaction", ussdHandlers.InitiateTransaction)
|
||||
rs.AddLocalFunc("save_temporary_pin", ussdHandlers.SaveTemporaryPin)
|
||||
rs.AddLocalFunc("verify_new_pin", ussdHandlers.VerifyNewPin)
|
||||
rs.AddLocalFunc("confirm_pin_change", ussdHandlers.ConfirmPinChange)
|
||||
rs.AddLocalFunc("quit_with_help",ussdHandlers.QuitWithHelp)
|
||||
|
||||
return ussdHandlers, nil
|
||||
}
|
||||
|
||||
func ensureDbDir(dbDir string) error {
|
||||
err := os.MkdirAll(dbDir, 0700)
|
||||
if err != nil {
|
||||
return fmt.Errorf("state dir create exited with error: %v\n", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getStateStore(dbDir string, ctx context.Context) (db.Db, error) {
|
||||
store := gdbmdb.NewGdbmDb()
|
||||
storeFile := path.Join(dbDir, "state.gdbm")
|
||||
store.Connect(ctx, storeFile)
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func getUserdataDb(dbDir string, ctx context.Context) db.Db {
|
||||
store := gdbmdb.NewGdbmDb()
|
||||
storeFile := path.Join(dbDir, "userdata.gdbm")
|
||||
store.Connect(ctx, storeFile)
|
||||
|
||||
return store
|
||||
}
|
||||
|
||||
func getResource(resourceDir string, ctx context.Context) (resource.Resource, error) {
|
||||
store := fsdb.NewFsDb()
|
||||
err := store.Connect(ctx, resourceDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rfs := resource.NewDbResource(store)
|
||||
return rfs, nil
|
||||
}
|
||||
|
||||
|
||||
func main() {
|
||||
var sessionId string
|
||||
var dbDir string
|
||||
@@ -54,10 +147,15 @@ func main() {
|
||||
flag.UintVar(&port, "p", 7123, "http port")
|
||||
flag.Parse()
|
||||
|
||||
logg.Infof("start command", "dbdir", dbDir, "resourcedir", resourceDir, "outputsize", size, "sessionId", sessionId)
|
||||
logg.Infof("start command", "dbdir", dbDir, "resourcedir", resourceDir, "outputsize", size, "sessionId", sessionId)
|
||||
|
||||
ctx := context.Background()
|
||||
pfp := path.Join(scriptDir, "pp.csv")
|
||||
flagParser, err := getFlags(pfp, true)
|
||||
|
||||
if err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
cfg := engine.Config{
|
||||
Root: "root",
|
||||
@@ -71,20 +169,19 @@ func main() {
|
||||
cfg.EngineDebug = true
|
||||
}
|
||||
|
||||
menuStorageService := storage.NewMenuStorageService(dbDir, resourceDir)
|
||||
rs, err := menuStorageService.GetResource(ctx)
|
||||
rs, err := getResource(resourceDir, ctx)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
err = menuStorageService.EnsureDbDir()
|
||||
err = ensureDbDir(dbDir)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
userdataStore, err := menuStorageService.GetUserdataDb(ctx)
|
||||
userdataStore := getUserdataDb(dbDir, ctx)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, err.Error())
|
||||
os.Exit(1)
|
||||
@@ -96,16 +193,13 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
lhs, err := handlers.NewLocalHandlerService(pfp, true, dbResource, cfg, rs)
|
||||
lhs.SetDataStore(&userdataStore)
|
||||
|
||||
hl, err := lhs.GetHandler()
|
||||
hl, err := getHandler(flagParser, dbResource, userdataStore)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
stateStore, err := menuStorageService.GetStateStore(ctx)
|
||||
stateStore, err := getStateStore(dbDir, ctx)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, err.Error())
|
||||
os.Exit(1)
|
||||
@@ -118,7 +212,7 @@ func main() {
|
||||
sh := handlers.NewBaseSessionHandler(cfg, rs, stateStore, userdataStore, rp, hl)
|
||||
cfg.SessionId = sessionId
|
||||
rqs := handlers.RequestSession{
|
||||
Ctx: ctx,
|
||||
Ctx: ctx,
|
||||
Writer: os.Stdout,
|
||||
Config: cfg,
|
||||
}
|
||||
@@ -154,7 +248,7 @@ func main() {
|
||||
fmt.Println("")
|
||||
_, err = fmt.Scanln(&rqs.Input)
|
||||
if err != nil {
|
||||
fmt.Errorf("error in input: %v", err)
|
||||
fmt.Errorf("error in input: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
127
cmd/http/main.go
127
cmd/http/main.go
@@ -11,20 +11,113 @@ import (
|
||||
"strconv"
|
||||
"syscall"
|
||||
|
||||
"git.defalsify.org/vise.git/asm"
|
||||
"git.defalsify.org/vise.git/db"
|
||||
fsdb "git.defalsify.org/vise.git/db/fs"
|
||||
gdbmdb "git.defalsify.org/vise.git/db/gdbm"
|
||||
"git.defalsify.org/vise.git/engine"
|
||||
"git.defalsify.org/vise.git/logging"
|
||||
"git.defalsify.org/vise.git/resource"
|
||||
"git.defalsify.org/vise.git/logging"
|
||||
|
||||
"git.grassecon.net/urdt/ussd/internal/handlers/ussd"
|
||||
"git.grassecon.net/urdt/ussd/internal/handlers"
|
||||
httpserver "git.grassecon.net/urdt/ussd/internal/http"
|
||||
"git.grassecon.net/urdt/ussd/internal/storage"
|
||||
)
|
||||
|
||||
var (
|
||||
logg = logging.NewVanilla()
|
||||
logg = logging.NewVanilla()
|
||||
scriptDir = path.Join("services", "registration")
|
||||
)
|
||||
|
||||
func getFlags(fp string, debug bool) (*asm.FlagParser, error) {
|
||||
flagParser := asm.NewFlagParser().WithDebug()
|
||||
_, err := flagParser.Load(fp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return flagParser, nil
|
||||
}
|
||||
|
||||
func getHandler(appFlags *asm.FlagParser, rs *resource.DbResource, userdataStore db.Db) (*ussd.Handlers, error) {
|
||||
|
||||
ussdHandlers, err := ussd.NewHandlers(appFlags, userdataStore)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rs.AddLocalFunc("set_language", ussdHandlers.SetLanguage)
|
||||
rs.AddLocalFunc("create_account", ussdHandlers.CreateAccount)
|
||||
rs.AddLocalFunc("save_pin", ussdHandlers.SavePin)
|
||||
rs.AddLocalFunc("verify_pin", ussdHandlers.VerifyPin)
|
||||
rs.AddLocalFunc("check_identifier", ussdHandlers.CheckIdentifier)
|
||||
rs.AddLocalFunc("check_account_status", ussdHandlers.CheckAccountStatus)
|
||||
rs.AddLocalFunc("authorize_account", ussdHandlers.Authorize)
|
||||
rs.AddLocalFunc("quit", ussdHandlers.Quit)
|
||||
rs.AddLocalFunc("check_balance", ussdHandlers.CheckBalance)
|
||||
rs.AddLocalFunc("validate_recipient", ussdHandlers.ValidateRecipient)
|
||||
rs.AddLocalFunc("transaction_reset", ussdHandlers.TransactionReset)
|
||||
rs.AddLocalFunc("max_amount", ussdHandlers.MaxAmount)
|
||||
rs.AddLocalFunc("validate_amount", ussdHandlers.ValidateAmount)
|
||||
rs.AddLocalFunc("reset_transaction_amount", ussdHandlers.ResetTransactionAmount)
|
||||
rs.AddLocalFunc("get_recipient", ussdHandlers.GetRecipient)
|
||||
rs.AddLocalFunc("get_sender", ussdHandlers.GetSender)
|
||||
rs.AddLocalFunc("get_amount", ussdHandlers.GetAmount)
|
||||
rs.AddLocalFunc("reset_incorrect", ussdHandlers.ResetIncorrectPin)
|
||||
rs.AddLocalFunc("save_firstname", ussdHandlers.SaveFirstname)
|
||||
rs.AddLocalFunc("save_familyname", ussdHandlers.SaveFamilyname)
|
||||
rs.AddLocalFunc("save_gender", ussdHandlers.SaveGender)
|
||||
rs.AddLocalFunc("save_location", ussdHandlers.SaveLocation)
|
||||
rs.AddLocalFunc("save_yob", ussdHandlers.SaveYob)
|
||||
rs.AddLocalFunc("save_offerings", ussdHandlers.SaveOfferings)
|
||||
rs.AddLocalFunc("quit_with_balance", ussdHandlers.QuitWithBalance)
|
||||
rs.AddLocalFunc("reset_account_authorized", ussdHandlers.ResetAccountAuthorized)
|
||||
rs.AddLocalFunc("reset_allow_update", ussdHandlers.ResetAllowUpdate)
|
||||
rs.AddLocalFunc("get_profile_info", ussdHandlers.GetProfileInfo)
|
||||
rs.AddLocalFunc("verify_yob", ussdHandlers.VerifyYob)
|
||||
rs.AddLocalFunc("reset_incorrect_date_format", ussdHandlers.ResetIncorrectYob)
|
||||
rs.AddLocalFunc("set_reset_single_edit", ussdHandlers.SetResetSingleEdit)
|
||||
rs.AddLocalFunc("initiate_transaction", ussdHandlers.InitiateTransaction)
|
||||
rs.AddLocalFunc("save_temporary_pin", ussdHandlers.SaveTemporaryPin)
|
||||
rs.AddLocalFunc("verify_new_pin", ussdHandlers.VerifyNewPin)
|
||||
rs.AddLocalFunc("confirm_pin_change", ussdHandlers.ConfirmPinChange)
|
||||
rs.AddLocalFunc("quit_with_help",ussdHandlers.QuitWithHelp)
|
||||
|
||||
return ussdHandlers, nil
|
||||
}
|
||||
|
||||
func ensureDbDir(dbDir string) error {
|
||||
err := os.MkdirAll(dbDir, 0700)
|
||||
if err != nil {
|
||||
return fmt.Errorf("state dir create exited with error: %v\n", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getStateStore(dbDir string, ctx context.Context) (db.Db, error) {
|
||||
store := gdbmdb.NewGdbmDb()
|
||||
storeFile := path.Join(dbDir, "state.gdbm")
|
||||
store.Connect(ctx, storeFile)
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func getUserdataDb(dbDir string, ctx context.Context) db.Db {
|
||||
store := gdbmdb.NewGdbmDb()
|
||||
storeFile := path.Join(dbDir, "userdata.gdbm")
|
||||
store.Connect(ctx, storeFile)
|
||||
|
||||
return store
|
||||
}
|
||||
|
||||
func getResource(resourceDir string, ctx context.Context) (resource.Resource, error) {
|
||||
store := fsdb.NewFsDb()
|
||||
err := store.Connect(ctx, resourceDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rfs := resource.NewDbResource(store)
|
||||
return rfs, nil
|
||||
}
|
||||
|
||||
|
||||
func main() {
|
||||
var dbDir string
|
||||
var resourceDir string
|
||||
@@ -42,10 +135,15 @@ func main() {
|
||||
flag.UintVar(&port, "p", 7123, "http port")
|
||||
flag.Parse()
|
||||
|
||||
logg.Infof("start command", "dbdir", dbDir, "resourcedir", resourceDir, "outputsize", size)
|
||||
logg.Infof("start command", "dbdir", dbDir, "resourcedir", resourceDir, "outputsize", size)
|
||||
|
||||
ctx := context.Background()
|
||||
pfp := path.Join(scriptDir, "pp.csv")
|
||||
flagParser, err := getFlags(pfp, true)
|
||||
|
||||
if err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
cfg := engine.Config{
|
||||
Root: "root",
|
||||
@@ -59,20 +157,19 @@ func main() {
|
||||
cfg.EngineDebug = true
|
||||
}
|
||||
|
||||
menuStorageService := storage.NewMenuStorageService(dbDir, resourceDir)
|
||||
rs, err := menuStorageService.GetResource(ctx)
|
||||
rs, err := getResource(resourceDir, ctx)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
err = menuStorageService.EnsureDbDir()
|
||||
err = ensureDbDir(dbDir)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
userdataStore, err := menuStorageService.GetUserdataDb(ctx)
|
||||
userdataStore := getUserdataDb(dbDir, ctx)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, err.Error())
|
||||
os.Exit(1)
|
||||
@@ -84,21 +181,13 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
lhs, err := handlers.NewLocalHandlerService(pfp, true, dbResource, cfg, rs)
|
||||
lhs.SetDataStore(&userdataStore)
|
||||
|
||||
hl, err := getHandler(flagParser, dbResource, userdataStore)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
hl, err := lhs.GetHandler()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
stateStore, err := menuStorageService.GetStateStore(ctx)
|
||||
stateStore, err := getStateStore(dbDir, ctx)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, err.Error())
|
||||
os.Exit(1)
|
||||
@@ -109,7 +198,7 @@ func main() {
|
||||
bsh := handlers.NewBaseSessionHandler(cfg, rs, stateStore, userdataStore, rp, hl)
|
||||
sh := httpserver.ToSessionHandler(bsh)
|
||||
s := &http.Server{
|
||||
Addr: fmt.Sprintf("%s:%s", host, strconv.Itoa(int(port))),
|
||||
Addr: fmt.Sprintf("%s:%s", host, strconv.Itoa(int(port))),
|
||||
Handler: sh,
|
||||
}
|
||||
s.RegisterOnShutdown(sh.Shutdown)
|
||||
|
||||
149
cmd/main.go
149
cmd/main.go
@@ -7,11 +7,15 @@ import (
|
||||
"os"
|
||||
"path"
|
||||
|
||||
"git.defalsify.org/vise.git/asm"
|
||||
"git.defalsify.org/vise.git/db"
|
||||
fsdb "git.defalsify.org/vise.git/db/fs"
|
||||
gdbmdb "git.defalsify.org/vise.git/db/gdbm"
|
||||
"git.defalsify.org/vise.git/engine"
|
||||
"git.defalsify.org/vise.git/logging"
|
||||
"git.defalsify.org/vise.git/persist"
|
||||
"git.defalsify.org/vise.git/resource"
|
||||
"git.grassecon.net/urdt/ussd/internal/handlers"
|
||||
"git.grassecon.net/urdt/ussd/internal/storage"
|
||||
"git.grassecon.net/urdt/ussd/internal/handlers/ussd"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -19,6 +23,106 @@ var (
|
||||
scriptDir = path.Join("services", "registration")
|
||||
)
|
||||
|
||||
func getParser(fp string, debug bool) (*asm.FlagParser, error) {
|
||||
flagParser := asm.NewFlagParser().WithDebug()
|
||||
_, err := flagParser.Load(fp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return flagParser, nil
|
||||
}
|
||||
|
||||
func getHandler(appFlags *asm.FlagParser, rs *resource.DbResource, pe *persist.Persister, userdataStore db.Db) (*ussd.Handlers, error) {
|
||||
|
||||
ussdHandlers, err := ussd.NewHandlers(appFlags, userdataStore)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ussdHandlers = ussdHandlers.WithPersister(pe)
|
||||
rs.AddLocalFunc("set_language", ussdHandlers.SetLanguage)
|
||||
rs.AddLocalFunc("create_account", ussdHandlers.CreateAccount)
|
||||
rs.AddLocalFunc("save_pin", ussdHandlers.SavePin)
|
||||
rs.AddLocalFunc("verify_pin", ussdHandlers.VerifyPin)
|
||||
rs.AddLocalFunc("check_identifier", ussdHandlers.CheckIdentifier)
|
||||
rs.AddLocalFunc("check_account_status", ussdHandlers.CheckAccountStatus)
|
||||
rs.AddLocalFunc("authorize_account", ussdHandlers.Authorize)
|
||||
rs.AddLocalFunc("quit", ussdHandlers.Quit)
|
||||
rs.AddLocalFunc("check_balance", ussdHandlers.CheckBalance)
|
||||
rs.AddLocalFunc("validate_recipient", ussdHandlers.ValidateRecipient)
|
||||
rs.AddLocalFunc("transaction_reset", ussdHandlers.TransactionReset)
|
||||
rs.AddLocalFunc("max_amount", ussdHandlers.MaxAmount)
|
||||
rs.AddLocalFunc("validate_amount", ussdHandlers.ValidateAmount)
|
||||
rs.AddLocalFunc("reset_transaction_amount", ussdHandlers.ResetTransactionAmount)
|
||||
rs.AddLocalFunc("get_recipient", ussdHandlers.GetRecipient)
|
||||
rs.AddLocalFunc("get_sender", ussdHandlers.GetSender)
|
||||
rs.AddLocalFunc("get_amount", ussdHandlers.GetAmount)
|
||||
rs.AddLocalFunc("reset_incorrect", ussdHandlers.ResetIncorrectPin)
|
||||
rs.AddLocalFunc("save_firstname", ussdHandlers.SaveFirstname)
|
||||
rs.AddLocalFunc("save_familyname", ussdHandlers.SaveFamilyname)
|
||||
rs.AddLocalFunc("save_gender", ussdHandlers.SaveGender)
|
||||
rs.AddLocalFunc("save_location", ussdHandlers.SaveLocation)
|
||||
rs.AddLocalFunc("save_yob", ussdHandlers.SaveYob)
|
||||
rs.AddLocalFunc("save_offerings", ussdHandlers.SaveOfferings)
|
||||
rs.AddLocalFunc("quit_with_balance", ussdHandlers.QuitWithBalance)
|
||||
rs.AddLocalFunc("reset_account_authorized", ussdHandlers.ResetAccountAuthorized)
|
||||
rs.AddLocalFunc("reset_allow_update", ussdHandlers.ResetAllowUpdate)
|
||||
rs.AddLocalFunc("get_profile_info", ussdHandlers.GetProfileInfo)
|
||||
rs.AddLocalFunc("verify_yob", ussdHandlers.VerifyYob)
|
||||
rs.AddLocalFunc("reset_incorrect_date_format", ussdHandlers.ResetIncorrectYob)
|
||||
rs.AddLocalFunc("set_reset_single_edit", ussdHandlers.SetResetSingleEdit)
|
||||
rs.AddLocalFunc("initiate_transaction", ussdHandlers.InitiateTransaction)
|
||||
rs.AddLocalFunc("save_temporary_pin", ussdHandlers.SaveTemporaryPin)
|
||||
rs.AddLocalFunc("verify_new_pin", ussdHandlers.VerifyNewPin)
|
||||
rs.AddLocalFunc("confirm_pin_change", ussdHandlers.ConfirmPinChange)
|
||||
rs.AddLocalFunc("quit_with_help",ussdHandlers.QuitWithHelp)
|
||||
|
||||
return ussdHandlers, nil
|
||||
}
|
||||
|
||||
func ensureDbDir(dbDir string) error {
|
||||
err := os.MkdirAll(dbDir, 0700)
|
||||
if err != nil {
|
||||
return fmt.Errorf("state dir create exited with error: %v\n", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getPersister(dbDir string, ctx context.Context) (*persist.Persister, error) {
|
||||
err := ensureDbDir(dbDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
store := gdbmdb.NewGdbmDb()
|
||||
storeFile := path.Join(dbDir, "state.gdbm")
|
||||
store.Connect(ctx, storeFile)
|
||||
pr := persist.NewPersister(store)
|
||||
return pr, nil
|
||||
}
|
||||
|
||||
func getUserdataDb(dbDir string, ctx context.Context) db.Db {
|
||||
store := gdbmdb.NewGdbmDb()
|
||||
storeFile := path.Join(dbDir, "userdata.gdbm")
|
||||
store.Connect(ctx, storeFile)
|
||||
|
||||
return store
|
||||
}
|
||||
|
||||
func getResource(resourceDir string, ctx context.Context) (resource.Resource, error) {
|
||||
store := fsdb.NewFsDb()
|
||||
err := store.Connect(ctx, resourceDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rfs := resource.NewDbResource(store)
|
||||
return rfs, nil
|
||||
}
|
||||
|
||||
func getEngine(cfg engine.Config, rs resource.Resource, pr *persist.Persister) *engine.DefaultEngine {
|
||||
en := engine.NewEngine(cfg, rs)
|
||||
en = en.WithPersister(pr)
|
||||
return en
|
||||
}
|
||||
|
||||
func main() {
|
||||
var dbDir string
|
||||
var size uint
|
||||
@@ -35,6 +139,11 @@ func main() {
|
||||
ctx := context.Background()
|
||||
ctx = context.WithValue(ctx, "SessionId", sessionId)
|
||||
pfp := path.Join(scriptDir, "pp.csv")
|
||||
flagParser, err := getParser(pfp, true)
|
||||
|
||||
if err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
cfg := engine.Config{
|
||||
Root: "root",
|
||||
@@ -43,28 +152,19 @@ func main() {
|
||||
FlagCount: uint32(16),
|
||||
}
|
||||
|
||||
resourceDir := scriptDir
|
||||
menuStorageService := storage.NewMenuStorageService(dbDir, resourceDir)
|
||||
|
||||
err := menuStorageService.EnsureDbDir()
|
||||
rs, err := getResource(scriptDir, ctx)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
rs, err := menuStorageService.GetResource(ctx)
|
||||
pe, err := getPersister(dbDir, ctx)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
pe, err := menuStorageService.GetPersister(ctx)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
userdatastore, err := menuStorageService.GetUserdataDb(ctx)
|
||||
store := getUserdataDb(dbDir, ctx)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, err.Error())
|
||||
os.Exit(1)
|
||||
@@ -76,28 +176,25 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
lhs, err := handlers.NewLocalHandlerService(pfp, true, dbResource, cfg, rs)
|
||||
lhs.SetDataStore(&userdatastore)
|
||||
lhs.SetPersister(pe)
|
||||
|
||||
hl, err := getHandler(flagParser, dbResource, pe, store)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
hl, err := lhs.GetHandler()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
en := lhs.GetEngine()
|
||||
en := getEngine(cfg, rs, pe)
|
||||
en = en.WithFirst(hl.Init)
|
||||
if debug {
|
||||
en = en.WithDebug(nil)
|
||||
}
|
||||
|
||||
err = engine.Loop(ctx, en, os.Stdin, os.Stdout, nil)
|
||||
_, err = en.Init(ctx)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "engine init exited with error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
err = engine.Loop(ctx, en, os.Stdin, os.Stdout)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "loop exited with error: %v\n", err)
|
||||
os.Exit(1)
|
||||
|
||||
359
cmd/ssh/main.go
359
cmd/ssh/main.go
@@ -1,359 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"net"
|
||||
"path"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"git.defalsify.org/vise.git/db"
|
||||
"git.defalsify.org/vise.git/engine"
|
||||
"git.defalsify.org/vise.git/logging"
|
||||
"git.defalsify.org/vise.git/resource"
|
||||
"git.defalsify.org/vise.git/state"
|
||||
|
||||
"git.grassecon.net/urdt/ussd/internal/handlers"
|
||||
"git.grassecon.net/urdt/ussd/internal/storage"
|
||||
)
|
||||
|
||||
var (
|
||||
wg sync.WaitGroup
|
||||
keyStore db.Db
|
||||
logg = logging.NewVanilla()
|
||||
scriptDir = path.Join("services", "registration")
|
||||
)
|
||||
|
||||
type auther struct {
|
||||
Ctx context.Context
|
||||
auth map[string]string
|
||||
}
|
||||
|
||||
func NewAuther(ctx context.Context) *auther {
|
||||
return &auther{
|
||||
Ctx: ctx,
|
||||
auth: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
func(a *auther) Check(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
|
||||
keyStore.SetLanguage(nil)
|
||||
keyStore.SetPrefix(storage.DATATYPE_CUSTOM)
|
||||
k := append([]byte{0x01}, pubKey.Marshal()...)
|
||||
v, err := keyStore.Get(a.Ctx, k)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ka := hex.EncodeToString(conn.SessionID())
|
||||
va := string(v)
|
||||
a.auth[ka] = va
|
||||
fmt.Fprintf(os.Stderr, "connect: %s -> %s\n", ka, v)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func(a *auther) FromConn(c *ssh.ServerConn) (string, error) {
|
||||
if c == nil {
|
||||
return "", errors.New("nil server conn")
|
||||
}
|
||||
if c.Conn == nil {
|
||||
return "", errors.New("nil underlying conn")
|
||||
}
|
||||
return a.Get(c.Conn.SessionID())
|
||||
}
|
||||
|
||||
|
||||
func(a *auther) Get(k []byte) (string, error) {
|
||||
ka := hex.EncodeToString(k)
|
||||
v, ok := a.auth[ka]
|
||||
if !ok {
|
||||
return "", errors.New("not found")
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
//func serve(ctx context.Context, sessionId string, ch ssh.NewChannel, mss *storage.MenuStorageService, lhs *handlers.LocalHandlerService) error {
|
||||
func serve(ctx context.Context, sessionId string, ch ssh.NewChannel, en engine.Engine) error {
|
||||
if ch == nil {
|
||||
return errors.New("nil channel")
|
||||
}
|
||||
if ch.ChannelType() != "session" {
|
||||
ch.Reject(ssh.UnknownChannelType, "that is not the channel you are looking for")
|
||||
return errors.New("not a session")
|
||||
}
|
||||
channel, requests, err := ch.Accept()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer channel.Close()
|
||||
wg.Add(1)
|
||||
go func(reqIn <-chan *ssh.Request) {
|
||||
defer wg.Done()
|
||||
for req := range reqIn {
|
||||
req.Reply(req.Type == "shell", nil)
|
||||
}
|
||||
_ = requests
|
||||
}(requests)
|
||||
|
||||
cont, err := en.Exec(ctx, []byte{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("initial engine exec err: %v", err)
|
||||
}
|
||||
|
||||
var input [state.INPUT_LIMIT]byte
|
||||
for cont {
|
||||
c, err := en.Flush(ctx, channel)
|
||||
if err != nil {
|
||||
return fmt.Errorf("flush err: %v", err)
|
||||
}
|
||||
_, err = channel.Write([]byte{0x0a})
|
||||
if err != nil {
|
||||
return fmt.Errorf("newline err: %v", err)
|
||||
}
|
||||
c, err = channel.Read(input[:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("read input fail: %v", err)
|
||||
}
|
||||
logg.TraceCtxf(ctx, "input read", "c", c, "input", input[:c-1])
|
||||
cont, err = en.Exec(ctx, input[:c-1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("engine exec err: %v", err)
|
||||
}
|
||||
logg.TraceCtxf(ctx, "exec cont", "cont", cont, "en", en)
|
||||
_ = c
|
||||
}
|
||||
c, err := en.Flush(ctx, channel)
|
||||
if err != nil {
|
||||
return fmt.Errorf("last flush err: %v", err)
|
||||
}
|
||||
_ = c
|
||||
return nil
|
||||
}
|
||||
|
||||
type sshRunner struct {
|
||||
Ctx context.Context
|
||||
Cfg engine.Config
|
||||
FlagFile string
|
||||
DbDir string
|
||||
ResourceDir string
|
||||
Debug bool
|
||||
KeyFile string
|
||||
Host string
|
||||
Port uint
|
||||
}
|
||||
|
||||
func(s *sshRunner) GetEngine(sessionId string) (engine.Engine, func(), error) {
|
||||
ctx := s.Ctx
|
||||
menuStorageService := storage.NewMenuStorageService(s.DbDir, s.ResourceDir)
|
||||
|
||||
err := menuStorageService.EnsureDbDir()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
rs, err := menuStorageService.GetResource(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pe, err := menuStorageService.GetPersister(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
userdatastore, err := menuStorageService.GetUserdataDb(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
dbResource, ok := rs.(*resource.DbResource)
|
||||
if !ok {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
lhs, err := handlers.NewLocalHandlerService(s.FlagFile, true, dbResource, s.Cfg, rs)
|
||||
lhs.SetDataStore(&userdatastore)
|
||||
lhs.SetPersister(pe)
|
||||
lhs.Cfg.SessionId = sessionId
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
hl, err := lhs.GetHandler()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
en := lhs.GetEngine()
|
||||
en = en.WithFirst(hl.Init)
|
||||
if s.Debug {
|
||||
en = en.WithDebug(nil)
|
||||
}
|
||||
// TODO: this is getting very hacky!
|
||||
closer := func() {
|
||||
err := menuStorageService.Close()
|
||||
if err != nil {
|
||||
logg.ErrorCtxf(ctx, "menu storage service cleanup fail", "err", err)
|
||||
}
|
||||
}
|
||||
return en, closer, nil
|
||||
}
|
||||
|
||||
// adapted example from crypto/ssh package, NewServerConn doc
|
||||
func(s *sshRunner) Run(ctx context.Context) {
|
||||
running := true
|
||||
|
||||
// TODO: waitgroup should probably not be global
|
||||
defer wg.Wait()
|
||||
|
||||
auth := NewAuther(ctx)
|
||||
cfg := ssh.ServerConfig{
|
||||
PublicKeyCallback: auth.Check,
|
||||
}
|
||||
|
||||
privateBytes, err := os.ReadFile(s.KeyFile)
|
||||
if err != nil {
|
||||
logg.ErrorCtxf(ctx, "Failed to load private key", "err", err)
|
||||
}
|
||||
private, err := ssh.ParsePrivateKey(privateBytes)
|
||||
if err != nil {
|
||||
logg.ErrorCtxf(ctx, "Failed to parse private key", "err", err)
|
||||
}
|
||||
cfg.AddHostKey(private)
|
||||
|
||||
lst, err := net.Listen("tcp", fmt.Sprintf("%s:%d", s.Host, s.Port))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
for running {
|
||||
conn, err := lst.Accept()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
|
||||
go func(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
for true {
|
||||
srvConn, nC, rC, err := ssh.NewServerConn(conn, &cfg)
|
||||
if err != nil {
|
||||
logg.InfoCtxf(ctx, "rejected client", "err", err)
|
||||
return
|
||||
}
|
||||
logg.DebugCtxf(ctx, "ssh client connected", "conn", srvConn)
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
ssh.DiscardRequests(rC)
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
sessionId, err := auth.FromConn(srvConn)
|
||||
if err != nil {
|
||||
logg.ErrorCtxf(ctx, "Cannot find authentication")
|
||||
return
|
||||
}
|
||||
en, closer, err := s.GetEngine(sessionId)
|
||||
if err != nil {
|
||||
logg.ErrorCtxf(ctx, "engine won't start", "err", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
err := en.Finish()
|
||||
if err != nil {
|
||||
logg.ErrorCtxf(ctx, "engine won't stop", "err", err)
|
||||
}
|
||||
closer()
|
||||
}()
|
||||
for ch := range nC {
|
||||
err = serve(ctx, sessionId, ch, en)
|
||||
logg.ErrorCtxf(ctx, "ssh server finish", "err", err)
|
||||
}
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: This is test code, move to external tool for adding and removing keys
|
||||
func sshLoadKeys(ctx context.Context, dbDir string) error {
|
||||
keyStoreFile := path.Join(dbDir, "ssh_authorized_keys.gdbm")
|
||||
keyStore = storage.NewThreadGdbmDb()
|
||||
err := keyStore.Connect(ctx, keyStoreFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte("ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCu5rYCxMBsVAL1TEkMQgmElAYEZj5zYDdyHjUxZ6qzHBOZD9GAzdxx9GyQDx2vdYm3329tLH/69ky1YA3nUz8SnJGBD6hC5XrqwN6zo9R9oOHAKTwiPGhey2NTVmheP+9XNHukBnOlkkWOQlpDDvMbWOztaZOWDaA8OIeP0t6qzFqLyelyg65lxzM3BKd7bCmmfzl/64BcP1MotAmB9DUxmY0Wb4Q2hYZfNYBx50Z4xthTgKV+Xoo8CbTduKotIz6hluQGvWdtxlCJQEiZ2f4RYY87JSA6/BAH2fhxuLHMXRpzocJNqARqCWpdcTGSg7bzxbKvTFH9OU4wZtr9ie40OR4zsc1lOBZL0rnp8GLkG8ZmeBQrgEDlmR9TTlz4okgtL+c5TCS37rjZYVjmtGwihws0EL9+wyv2dSQibirklC4wK5eWHKXl5vab19qzw/qRLdoRBK40DxbRKggxA7gqSsKrmrf+z7CuLIz/kxF+169FBLbh1MfBOGdx1awm6aU= lash@furioso"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
k := append([]byte{0x01}, pubKey.Marshal()...)
|
||||
keyStore.SetPrefix(storage.DATATYPE_CUSTOM)
|
||||
return keyStore.Put(ctx, k, []byte("+25113243546"))
|
||||
}
|
||||
|
||||
func main() {
|
||||
var dbDir string
|
||||
var resourceDir string
|
||||
var size uint
|
||||
var engineDebug bool
|
||||
var stateDebug bool
|
||||
var host string
|
||||
var port uint
|
||||
flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
|
||||
flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir")
|
||||
flag.BoolVar(&engineDebug, "engine-debug", false, "use engine debug output")
|
||||
flag.BoolVar(&stateDebug, "state-debug", false, "use engine debug output")
|
||||
flag.UintVar(&size, "s", 160, "max size of output")
|
||||
flag.StringVar(&host, "h", "127.0.0.1", "http host")
|
||||
flag.UintVar(&port, "p", 7122, "http port")
|
||||
flag.Parse()
|
||||
|
||||
sshKeyFile := flag.Arg(0)
|
||||
_, err := os.Stat(sshKeyFile)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "cannot open ssh server private key file: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
logg.Infof("start command", "dbdir", dbDir, "resourcedir", resourceDir, "outputsize", size, "keyfile", sshKeyFile, "host", host, "port", port)
|
||||
|
||||
ctx := context.Background()
|
||||
pfp := path.Join(scriptDir, "pp.csv")
|
||||
|
||||
cfg := engine.Config{
|
||||
Root: "root",
|
||||
OutputSize: uint32(size),
|
||||
FlagCount: uint32(16),
|
||||
}
|
||||
if stateDebug {
|
||||
cfg.StateDebug = true
|
||||
}
|
||||
if engineDebug {
|
||||
cfg.EngineDebug = true
|
||||
}
|
||||
|
||||
err = sshLoadKeys(ctx, dbDir)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
runner := &sshRunner{
|
||||
Cfg: cfg,
|
||||
Debug: engineDebug,
|
||||
FlagFile: pfp,
|
||||
DbDir: dbDir,
|
||||
ResourceDir: resourceDir,
|
||||
KeyFile: sshKeyFile,
|
||||
Host: host,
|
||||
Port: port,
|
||||
}
|
||||
runner.Run(ctx)
|
||||
}
|
||||
2
go.mod
2
go.mod
@@ -3,7 +3,7 @@ module git.grassecon.net/urdt/ussd
|
||||
go 1.22.6
|
||||
|
||||
require (
|
||||
git.defalsify.org/vise.git v0.1.0-rc.3.0.20240920144308-b2d2c5f18f38
|
||||
git.defalsify.org/vise.git v0.1.0-rc.3.0.20240911231817-0d23e0dbb57f
|
||||
github.com/alecthomas/assert/v2 v2.2.2
|
||||
github.com/peteole/testdata-loader v0.3.0
|
||||
gopkg.in/leonelquinteros/gotext.v1 v1.3.1
|
||||
|
||||
4
go.sum
4
go.sum
@@ -1,9 +1,5 @@
|
||||
git.defalsify.org/vise.git v0.1.0-rc.3.0.20240911231817-0d23e0dbb57f h1:CuJvG3NyMoRtHUim4aZdrfjjJBg2AId7z0yp7Q97bRM=
|
||||
git.defalsify.org/vise.git v0.1.0-rc.3.0.20240911231817-0d23e0dbb57f/go.mod h1:JDguWmcoWBdsnpw7PUjVZAEpdC/ubBmjdUBy3tjP63M=
|
||||
git.defalsify.org/vise.git v0.1.0-rc.3.0.20240914163514-577f56f43bea h1:6ZYT+dIjd/f5vn9y5AJDZ7SQQckA6w5ZfUoKygyI11o=
|
||||
git.defalsify.org/vise.git v0.1.0-rc.3.0.20240914163514-577f56f43bea/go.mod h1:JDguWmcoWBdsnpw7PUjVZAEpdC/ubBmjdUBy3tjP63M=
|
||||
git.defalsify.org/vise.git v0.1.0-rc.3.0.20240920144308-b2d2c5f18f38 h1:4aAZijIcq33ixnZ+U48ckDIkwSfZL3St/CqoXZcC5K8=
|
||||
git.defalsify.org/vise.git v0.1.0-rc.3.0.20240920144308-b2d2c5f18f38/go.mod h1:JDguWmcoWBdsnpw7PUjVZAEpdC/ubBmjdUBy3tjP63M=
|
||||
github.com/alecthomas/assert/v2 v2.2.2 h1:Z/iVC0xZfWTaFNE6bA3z07T86hd45Xe2eLt6WVy2bbk=
|
||||
github.com/alecthomas/assert/v2 v2.2.2/go.mod h1:pXcQ2Asjp247dahGEmsZ6ru0UVwnkhktn7S0bBDLxvQ=
|
||||
github.com/alecthomas/participle/v2 v2.0.0 h1:Fgrq+MbuSsJwIkw3fEj9h75vDP0Er5JzepJ0/HNHv0g=
|
||||
|
||||
@@ -71,7 +71,19 @@ func(f *BaseSessionHandler) Process(rqs RequestSession) (RequestSession, error)
|
||||
}
|
||||
rqs.Engine = en
|
||||
|
||||
r, err = rqs.Engine.Exec(rqs.Ctx, rqs.Input)
|
||||
r, err = rqs.Engine.Init(rqs.Ctx)
|
||||
if err != nil {
|
||||
perr := f.provider.Put(rqs.Config.SessionId, rqs.Storage)
|
||||
rqs.Storage = nil
|
||||
if perr != nil {
|
||||
logg.ErrorCtxf(rqs.Ctx, "", "storage put error", perr)
|
||||
}
|
||||
return rqs, err
|
||||
}
|
||||
|
||||
if r && len(rqs.Input) > 0 {
|
||||
r, err = rqs.Engine.Exec(rqs.Ctx, rqs.Input)
|
||||
}
|
||||
if err != nil {
|
||||
perr := f.provider.Put(rqs.Config.SessionId, rqs.Storage)
|
||||
rqs.Storage = nil
|
||||
@@ -87,7 +99,7 @@ func(f *BaseSessionHandler) Process(rqs RequestSession) (RequestSession, error)
|
||||
|
||||
func(f *BaseSessionHandler) Output(rqs RequestSession) (RequestSession, error) {
|
||||
var err error
|
||||
_, err = rqs.Engine.Flush(rqs.Ctx, rqs.Writer)
|
||||
_, err = rqs.Engine.WriteResult(rqs.Ctx, rqs.Writer)
|
||||
return rqs, err
|
||||
}
|
||||
|
||||
|
||||
@@ -1,106 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"git.defalsify.org/vise.git/asm"
|
||||
"git.defalsify.org/vise.git/db"
|
||||
"git.defalsify.org/vise.git/engine"
|
||||
"git.defalsify.org/vise.git/persist"
|
||||
"git.defalsify.org/vise.git/resource"
|
||||
"git.grassecon.net/urdt/ussd/internal/handlers/ussd"
|
||||
)
|
||||
|
||||
type HandlerService interface {
|
||||
GetHandler() (*ussd.Handlers, error)
|
||||
}
|
||||
|
||||
func getParser(fp string, debug bool) (*asm.FlagParser, error) {
|
||||
flagParser := asm.NewFlagParser().WithDebug()
|
||||
_, err := flagParser.Load(fp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return flagParser, nil
|
||||
}
|
||||
|
||||
type LocalHandlerService struct {
|
||||
Parser *asm.FlagParser
|
||||
DbRs *resource.DbResource
|
||||
Pe *persist.Persister
|
||||
UserdataStore *db.Db
|
||||
Cfg engine.Config
|
||||
Rs resource.Resource
|
||||
}
|
||||
|
||||
func NewLocalHandlerService(fp string, debug bool, dbResource *resource.DbResource, cfg engine.Config, rs resource.Resource) (*LocalHandlerService, error) {
|
||||
parser, err := getParser(fp, debug)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &LocalHandlerService{
|
||||
Parser: parser,
|
||||
DbRs: dbResource,
|
||||
Cfg: cfg,
|
||||
Rs: rs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (ls *LocalHandlerService) SetPersister(Pe *persist.Persister) {
|
||||
ls.Pe = Pe
|
||||
}
|
||||
|
||||
func (ls *LocalHandlerService) SetDataStore(db *db.Db) {
|
||||
ls.UserdataStore = db
|
||||
}
|
||||
|
||||
func (ls *LocalHandlerService) GetHandler() (*ussd.Handlers, error) {
|
||||
ussdHandlers, err := ussd.NewHandlers(ls.Parser, *ls.UserdataStore)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ussdHandlers = ussdHandlers.WithPersister(ls.Pe)
|
||||
ls.DbRs.AddLocalFunc("set_language", ussdHandlers.SetLanguage)
|
||||
ls.DbRs.AddLocalFunc("create_account", ussdHandlers.CreateAccount)
|
||||
ls.DbRs.AddLocalFunc("save_pin", ussdHandlers.SavePin)
|
||||
ls.DbRs.AddLocalFunc("verify_pin", ussdHandlers.VerifyPin)
|
||||
ls.DbRs.AddLocalFunc("check_identifier", ussdHandlers.CheckIdentifier)
|
||||
ls.DbRs.AddLocalFunc("check_account_status", ussdHandlers.CheckAccountStatus)
|
||||
ls.DbRs.AddLocalFunc("authorize_account", ussdHandlers.Authorize)
|
||||
ls.DbRs.AddLocalFunc("quit", ussdHandlers.Quit)
|
||||
ls.DbRs.AddLocalFunc("check_balance", ussdHandlers.CheckBalance)
|
||||
ls.DbRs.AddLocalFunc("validate_recipient", ussdHandlers.ValidateRecipient)
|
||||
ls.DbRs.AddLocalFunc("transaction_reset", ussdHandlers.TransactionReset)
|
||||
ls.DbRs.AddLocalFunc("max_amount", ussdHandlers.MaxAmount)
|
||||
ls.DbRs.AddLocalFunc("validate_amount", ussdHandlers.ValidateAmount)
|
||||
ls.DbRs.AddLocalFunc("reset_transaction_amount", ussdHandlers.ResetTransactionAmount)
|
||||
ls.DbRs.AddLocalFunc("get_recipient", ussdHandlers.GetRecipient)
|
||||
ls.DbRs.AddLocalFunc("get_sender", ussdHandlers.GetSender)
|
||||
ls.DbRs.AddLocalFunc("get_amount", ussdHandlers.GetAmount)
|
||||
ls.DbRs.AddLocalFunc("reset_incorrect", ussdHandlers.ResetIncorrectPin)
|
||||
ls.DbRs.AddLocalFunc("save_firstname", ussdHandlers.SaveFirstname)
|
||||
ls.DbRs.AddLocalFunc("save_familyname", ussdHandlers.SaveFamilyname)
|
||||
ls.DbRs.AddLocalFunc("save_gender", ussdHandlers.SaveGender)
|
||||
ls.DbRs.AddLocalFunc("save_location", ussdHandlers.SaveLocation)
|
||||
ls.DbRs.AddLocalFunc("save_yob", ussdHandlers.SaveYob)
|
||||
ls.DbRs.AddLocalFunc("save_offerings", ussdHandlers.SaveOfferings)
|
||||
ls.DbRs.AddLocalFunc("quit_with_balance", ussdHandlers.QuitWithBalance)
|
||||
ls.DbRs.AddLocalFunc("reset_account_authorized", ussdHandlers.ResetAccountAuthorized)
|
||||
ls.DbRs.AddLocalFunc("reset_allow_update", ussdHandlers.ResetAllowUpdate)
|
||||
ls.DbRs.AddLocalFunc("get_profile_info", ussdHandlers.GetProfileInfo)
|
||||
ls.DbRs.AddLocalFunc("verify_yob", ussdHandlers.VerifyYob)
|
||||
ls.DbRs.AddLocalFunc("reset_incorrect_date_format", ussdHandlers.ResetIncorrectYob)
|
||||
ls.DbRs.AddLocalFunc("set_reset_single_edit", ussdHandlers.SetResetSingleEdit)
|
||||
ls.DbRs.AddLocalFunc("initiate_transaction", ussdHandlers.InitiateTransaction)
|
||||
ls.DbRs.AddLocalFunc("save_temporary_pin", ussdHandlers.SaveTemporaryPin)
|
||||
ls.DbRs.AddLocalFunc("verify_new_pin", ussdHandlers.VerifyNewPin)
|
||||
ls.DbRs.AddLocalFunc("confirm_pin_change", ussdHandlers.ConfirmPinChange)
|
||||
ls.DbRs.AddLocalFunc("quit_with_help", ussdHandlers.QuitWithHelp)
|
||||
|
||||
return ussdHandlers, nil
|
||||
}
|
||||
|
||||
// TODO: enable setting of sessionId on engine init time
|
||||
func (ls *LocalHandlerService) GetEngine() *engine.DefaultEngine {
|
||||
en := engine.NewEngine(ls.Cfg, ls.Rs)
|
||||
en = en.WithPersister(ls.Pe)
|
||||
return en
|
||||
}
|
||||
@@ -37,6 +37,8 @@ type RequestSession struct {
|
||||
Continue bool
|
||||
}
|
||||
|
||||
type engineMaker func(cfg engine.Config, rs resource.Resource, pr *persist.Persister) engine.Engine
|
||||
|
||||
// TODO: seems like can remove this.
|
||||
type RequestParser interface {
|
||||
GetSessionId(rq any) (string, error)
|
||||
|
||||
@@ -29,6 +29,11 @@ var (
|
||||
translationDir = path.Join(scriptDir, "locale")
|
||||
)
|
||||
|
||||
type FSData struct {
|
||||
Path string
|
||||
St *state.State
|
||||
}
|
||||
|
||||
// FlagManager handles centralized flag management
|
||||
type FlagManager struct {
|
||||
parser *asm.FlagParser
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"git.defalsify.org/vise.git/db"
|
||||
"git.defalsify.org/vise.git/resource"
|
||||
"git.defalsify.org/vise.git/state"
|
||||
"git.grassecon.net/urdt/ussd/internal/mocks"
|
||||
"git.grassecon.net/urdt/ussd/internal/handlers/ussd/mocks"
|
||||
"git.grassecon.net/urdt/ussd/internal/models"
|
||||
"git.grassecon.net/urdt/ussd/internal/utils"
|
||||
"github.com/alecthomas/assert/v2"
|
||||
@@ -349,7 +349,7 @@ func TestSaveGender(t *testing.T) {
|
||||
}
|
||||
|
||||
// Call the method
|
||||
_, err := h.SaveGender(ctx, "someSym", tt.input)
|
||||
_, err := h.SaveGender(ctx, "save_gender", tt.input)
|
||||
|
||||
// Assert no error
|
||||
assert.NoError(t, err)
|
||||
@@ -538,9 +538,9 @@ func TestSetLanguage(t *testing.T) {
|
||||
}
|
||||
// Define test cases
|
||||
tests := []struct {
|
||||
name string
|
||||
execPath []string
|
||||
expectedResult resource.Result
|
||||
name string
|
||||
execPath []string
|
||||
expectedResult resource.Result
|
||||
}{
|
||||
{
|
||||
name: "Set Default Language (English)",
|
||||
@@ -1101,18 +1101,26 @@ func TestCheckAccountStatus(t *testing.T) {
|
||||
FlagReset: []uint32{flag_account_pending},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Test when account status is not a success",
|
||||
input: []byte("TrackingId12"),
|
||||
status: "REVERTED",
|
||||
expectedResult: resource.Result{
|
||||
FlagSet: []uint32{flag_account_success},
|
||||
FlagReset: []uint32{flag_account_pending},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
typ := utils.DATA_TRACKING_ID
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockCreateAccountService.On("CheckAccountStatus", string(tt.input)).Return(tt.status, nil)
|
||||
|
||||
mockDataStore.On("WriteEntry", ctx, sessionId, utils.DATA_ACCOUNT_STATUS, []byte(tt.status)).Return(nil).Maybe()
|
||||
// Define expected interactions with the mock
|
||||
mockDataStore.On("ReadEntry", ctx, sessionId, typ).Return(tt.input, nil)
|
||||
|
||||
mockCreateAccountService.On("CheckAccountStatus", string(tt.input)).Return(tt.status, nil)
|
||||
mockDataStore.On("WriteEntry", ctx, sessionId, utils.DATA_ACCOUNT_STATUS, []byte(tt.status)).Return(nil)
|
||||
|
||||
// Call the method under test
|
||||
res, _ := h.CheckAccountStatus(ctx, "check_status", tt.input)
|
||||
|
||||
@@ -1289,7 +1297,6 @@ func TestInitiateTransaction(t *testing.T) {
|
||||
mockDataStore.On("ReadEntry", ctx, sessionId, utils.DATA_PUBLIC_KEY).Return(tt.PublicKey, nil)
|
||||
mockDataStore.On("ReadEntry", ctx, sessionId, utils.DATA_AMOUNT).Return(tt.Amount, nil)
|
||||
mockDataStore.On("ReadEntry", ctx, sessionId, utils.DATA_RECIPIENT).Return(tt.Recipient, nil)
|
||||
//mockDataStore.On("WriteEntry", ctx, sessionId, utils.DATA_AMOUNT, []byte("")).Return(nil)
|
||||
|
||||
// Call the method under test
|
||||
res, _ := h.InitiateTransaction(ctx, "transaction_reset_amount", tt.input)
|
||||
@@ -1480,7 +1487,7 @@ func TestValidateAmount(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Logf(err.Error())
|
||||
}
|
||||
//flag_invalid_amount, _ := fm.parser.GetFlag("flag_invalid_amount")
|
||||
flag_invalid_amount, _ := fm.parser.GetFlag("flag_invalid_amount")
|
||||
mockDataStore := new(mocks.MockUserDataStore)
|
||||
mockCreateAccountService := new(mocks.MockAccountService)
|
||||
|
||||
@@ -1509,26 +1516,26 @@ func TestValidateAmount(t *testing.T) {
|
||||
Content: "0.001",
|
||||
},
|
||||
},
|
||||
// {
|
||||
// name: "Test with amount larger than balance",
|
||||
// input: []byte("0.02"),
|
||||
// balance: "0.003 CELO",
|
||||
// publicKey: []byte("0xrqeqrequuq"),
|
||||
// expectedResult: resource.Result{
|
||||
// FlagSet: []uint32{flag_invalid_amount},
|
||||
// Content: "0.02",
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// name: "Test with invalid amount",
|
||||
// input: []byte("0.02ms"),
|
||||
// balance: "0.003 CELO",
|
||||
// publicKey: []byte("0xrqeqrequuq"),
|
||||
// expectedResult: resource.Result{
|
||||
// FlagSet: []uint32{flag_invalid_amount},
|
||||
// Content: "0.02ms",
|
||||
// },
|
||||
// },
|
||||
{
|
||||
name: "Test with amount larger than balance",
|
||||
input: []byte("0.02"),
|
||||
balance: "0.003 CELO",
|
||||
publicKey: []byte("0xrqeqrequuq"),
|
||||
expectedResult: resource.Result{
|
||||
FlagSet: []uint32{flag_invalid_amount},
|
||||
Content: "0.02",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Test with invalid amount",
|
||||
input: []byte("0.02ms"),
|
||||
balance: "0.003 CELO",
|
||||
publicKey: []byte("0xrqeqrequuq"),
|
||||
expectedResult: resource.Result{
|
||||
FlagSet: []uint32{flag_invalid_amount},
|
||||
Content: "0.02ms",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -1536,7 +1543,7 @@ func TestValidateAmount(t *testing.T) {
|
||||
|
||||
mockDataStore.On("ReadEntry", ctx, sessionId, utils.DATA_PUBLIC_KEY).Return(tt.publicKey, nil)
|
||||
mockCreateAccountService.On("CheckBalance", string(tt.publicKey)).Return(tt.balance, nil)
|
||||
mockDataStore.On("WriteEntry", ctx, sessionId, utils.DATA_AMOUNT, tt.input).Return(nil)
|
||||
mockDataStore.On("WriteEntry", ctx, sessionId, utils.DATA_AMOUNT, tt.input).Return(nil).Maybe()
|
||||
|
||||
// Call the method under test
|
||||
res, _ := h.ValidateAmount(ctx, "test_validate_amount", tt.input)
|
||||
@@ -1812,11 +1819,19 @@ func TestConfirmPin(t *testing.T) {
|
||||
FlagReset: []uint32{flag_pin_mismatch},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Test with different pin confirmation",
|
||||
input: []byte("1234"),
|
||||
temporarypin: []byte("12345"),
|
||||
expectedResult: resource.Result{
|
||||
FlagSet: []uint32{flag_pin_mismatch},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Set up the expected behavior of the mock
|
||||
mockDataStore.On("WriteEntry", ctx, sessionId, utils.DATA_ACCOUNT_PIN, []byte(tt.temporarypin)).Return(nil)
|
||||
mockDataStore.On("WriteEntry", ctx, sessionId, utils.DATA_ACCOUNT_PIN, []byte(tt.temporarypin)).Return(nil).Maybe()
|
||||
|
||||
mockDataStore.On("ReadEntry", ctx, sessionId, utils.DATA_TEMPORARY_PIN).Return(tt.temporarypin, nil)
|
||||
|
||||
|
||||
@@ -32,7 +32,6 @@ func (ash *ATSessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request)
|
||||
if err != nil {
|
||||
logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err)
|
||||
ash.writeError(w, 400, err)
|
||||
return
|
||||
}
|
||||
rqs.Config = cfg
|
||||
rqs.Input, err = rp.GetInput(req)
|
||||
@@ -42,14 +41,16 @@ func (ash *ATSessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
rqs, err = ash.Process(rqs)
|
||||
rqs, err = ash.Process(rqs)
|
||||
switch err {
|
||||
case nil: // set code to 200 if no err
|
||||
code = 200
|
||||
case handlers.ErrStorage, handlers.ErrEngineInit, handlers.ErrEngineExec, handlers.ErrEngineType:
|
||||
case handlers.ErrStorage:
|
||||
code = 500
|
||||
case handlers.ErrEngineInit:
|
||||
code = 500
|
||||
case handlers.ErrEngineExec:
|
||||
code = 500
|
||||
default:
|
||||
code = 500
|
||||
code = 200
|
||||
}
|
||||
|
||||
if code != 200 {
|
||||
@@ -87,6 +88,6 @@ func (ash *ATSessionHandler) Output(rqs handlers.RequestSession) (handlers.Reque
|
||||
return rqs, err
|
||||
}
|
||||
|
||||
_, err = rqs.Engine.Flush(rqs.Ctx, rqs.Writer)
|
||||
_, err = rqs.Engine.WriteResult(rqs.Ctx, rqs.Writer)
|
||||
return rqs, err
|
||||
}
|
||||
@@ -1,449 +0,0 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.defalsify.org/vise.git/engine"
|
||||
"git.grassecon.net/urdt/ussd/internal/handlers"
|
||||
"git.grassecon.net/urdt/ussd/internal/mocks/httpmocks"
|
||||
)
|
||||
|
||||
// invalidRequestType is a custom type to test invalid request scenarios
|
||||
type invalidRequestType struct{}
|
||||
|
||||
// errorReader is a helper type that always returns an error when Read is called
|
||||
type errorReader struct{}
|
||||
|
||||
func (e *errorReader) Read(p []byte) (n int, err error) {
|
||||
return 0, errors.New("read error")
|
||||
}
|
||||
|
||||
func TestNewATSessionHandler(t *testing.T) {
|
||||
mockHandler := &httpmocks.MockRequestHandler{}
|
||||
ash := NewATSessionHandler(mockHandler)
|
||||
|
||||
if ash == nil {
|
||||
t.Fatal("NewATSessionHandler returned nil")
|
||||
}
|
||||
|
||||
if ash.SessionHandler == nil {
|
||||
t.Fatal("SessionHandler is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestATSessionHandler_ServeHTTP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupMocks func(*httpmocks.MockRequestHandler, *httpmocks.MockRequestParser, *httpmocks.MockEngine)
|
||||
formData url.Values
|
||||
expectedStatus int
|
||||
expectedBody string
|
||||
}{
|
||||
{
|
||||
name: "Successful request",
|
||||
setupMocks: func(mh *httpmocks.MockRequestHandler, mrp *httpmocks.MockRequestParser, me *httpmocks.MockEngine) {
|
||||
mrp.GetSessionIdFunc = func(rq any) (string, error) {
|
||||
req := rq.(*http.Request)
|
||||
return req.FormValue("phoneNumber"), nil
|
||||
}
|
||||
mrp.GetInputFunc = func(rq any) ([]byte, error) {
|
||||
req := rq.(*http.Request)
|
||||
text := req.FormValue("text")
|
||||
parts := strings.Split(text, "*")
|
||||
return []byte(parts[len(parts)-1]), nil
|
||||
}
|
||||
mh.ProcessFunc = func(rqs handlers.RequestSession) (handlers.RequestSession, error) {
|
||||
rqs.Continue = true
|
||||
rqs.Engine = me
|
||||
return rqs, nil
|
||||
}
|
||||
mh.GetConfigFunc = func() engine.Config { return engine.Config{} }
|
||||
mh.GetRequestParserFunc = func() handlers.RequestParser { return mrp }
|
||||
mh.OutputFunc = func(rs handlers.RequestSession) (handlers.RequestSession, error) { return rs, nil }
|
||||
mh.ResetFunc = func(rs handlers.RequestSession) (handlers.RequestSession, error) { return rs, nil }
|
||||
me.FlushFunc = func(context.Context, io.Writer) (int, error) { return 0, nil }
|
||||
},
|
||||
formData: url.Values{
|
||||
"phoneNumber": []string{"+1234567890"},
|
||||
"text": []string{"1*2*3"},
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: "CON ",
|
||||
},
|
||||
{
|
||||
name: "GetSessionId error",
|
||||
setupMocks: func(mh *httpmocks.MockRequestHandler, mrp *httpmocks.MockRequestParser, me *httpmocks.MockEngine) {
|
||||
mrp.GetSessionIdFunc = func(rq any) (string, error) {
|
||||
return "", errors.New("no phone number found")
|
||||
}
|
||||
mh.GetConfigFunc = func() engine.Config { return engine.Config{} }
|
||||
mh.GetRequestParserFunc = func() handlers.RequestParser { return mrp }
|
||||
},
|
||||
formData: url.Values{
|
||||
"text": []string{"1*2*3"},
|
||||
},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedBody: "",
|
||||
},
|
||||
{
|
||||
name: "GetInput error",
|
||||
setupMocks: func(mh *httpmocks.MockRequestHandler, mrp *httpmocks.MockRequestParser, me *httpmocks.MockEngine) {
|
||||
mrp.GetSessionIdFunc = func(rq any) (string, error) {
|
||||
req := rq.(*http.Request)
|
||||
return req.FormValue("phoneNumber"), nil
|
||||
}
|
||||
mrp.GetInputFunc = func(rq any) ([]byte, error) {
|
||||
return nil, errors.New("no input found")
|
||||
}
|
||||
mh.GetConfigFunc = func() engine.Config { return engine.Config{} }
|
||||
mh.GetRequestParserFunc = func() handlers.RequestParser { return mrp }
|
||||
},
|
||||
formData: url.Values{
|
||||
"phoneNumber": []string{"+1234567890"},
|
||||
},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedBody: "",
|
||||
},
|
||||
{
|
||||
name: "Process error",
|
||||
setupMocks: func(mh *httpmocks.MockRequestHandler, mrp *httpmocks.MockRequestParser, me *httpmocks.MockEngine) {
|
||||
mrp.GetSessionIdFunc = func(rq any) (string, error) {
|
||||
req := rq.(*http.Request)
|
||||
return req.FormValue("phoneNumber"), nil
|
||||
}
|
||||
mrp.GetInputFunc = func(rq any) ([]byte, error) {
|
||||
req := rq.(*http.Request)
|
||||
text := req.FormValue("text")
|
||||
parts := strings.Split(text, "*")
|
||||
return []byte(parts[len(parts)-1]), nil
|
||||
}
|
||||
mh.ProcessFunc = func(rqs handlers.RequestSession) (handlers.RequestSession, error) {
|
||||
return rqs, handlers.ErrStorage
|
||||
}
|
||||
mh.GetConfigFunc = func() engine.Config { return engine.Config{} }
|
||||
mh.GetRequestParserFunc = func() handlers.RequestParser { return mrp }
|
||||
},
|
||||
formData: url.Values{
|
||||
"phoneNumber": []string{"+1234567890"},
|
||||
"text": []string{"1*2*3"},
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockHandler := &httpmocks.MockRequestHandler{}
|
||||
mockRequestParser := &httpmocks.MockRequestParser{}
|
||||
mockEngine := &httpmocks.MockEngine{}
|
||||
tt.setupMocks(mockHandler, mockRequestParser, mockEngine)
|
||||
|
||||
ash := NewATSessionHandler(mockHandler)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(tt.formData.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
ash.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
|
||||
}
|
||||
|
||||
if tt.expectedBody != "" && w.Body.String() != tt.expectedBody {
|
||||
t.Errorf("Expected body %q, got %q", tt.expectedBody, w.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestATSessionHandler_Output(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input handlers.RequestSession
|
||||
expectedPrefix string
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "Continue true",
|
||||
input: handlers.RequestSession{
|
||||
Continue: true,
|
||||
Engine: &httpmocks.MockEngine{
|
||||
FlushFunc: func(context.Context, io.Writer) (int, error) {
|
||||
return 0, nil
|
||||
},
|
||||
},
|
||||
Writer: &httpmocks.MockWriter{},
|
||||
},
|
||||
expectedPrefix: "CON ",
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "Continue false",
|
||||
input: handlers.RequestSession{
|
||||
Continue: false,
|
||||
Engine: &httpmocks.MockEngine{
|
||||
FlushFunc: func(context.Context, io.Writer) (int, error) {
|
||||
return 0, nil
|
||||
},
|
||||
},
|
||||
Writer: &httpmocks.MockWriter{},
|
||||
},
|
||||
expectedPrefix: "END ",
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "Flush error",
|
||||
input: handlers.RequestSession{
|
||||
Continue: true,
|
||||
Engine: &httpmocks.MockEngine{
|
||||
FlushFunc: func(context.Context, io.Writer) (int, error) {
|
||||
return 0, errors.New("write error")
|
||||
},
|
||||
},
|
||||
Writer: &httpmocks.MockWriter{},
|
||||
},
|
||||
expectedPrefix: "CON ",
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ash := &ATSessionHandler{}
|
||||
_, err := ash.Output(tt.input)
|
||||
|
||||
if tt.expectedError && err == nil {
|
||||
t.Error("Expected an error, but got nil")
|
||||
}
|
||||
|
||||
if !tt.expectedError && err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
mw := tt.input.Writer.(*httpmocks.MockWriter)
|
||||
if !mw.WriteStringCalled {
|
||||
t.Error("WriteString was not called")
|
||||
}
|
||||
|
||||
if mw.WrittenString != tt.expectedPrefix {
|
||||
t.Errorf("Expected prefix %q, got %q", tt.expectedPrefix, mw.WrittenString)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionHandler_ServeHTTP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sessionID string
|
||||
input []byte
|
||||
parserErr error
|
||||
processErr error
|
||||
outputErr error
|
||||
resetErr error
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
sessionID: "123",
|
||||
input: []byte("test input"),
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Missing Session ID",
|
||||
sessionID: "",
|
||||
parserErr: handlers.ErrSessionMissing,
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "Process Error",
|
||||
sessionID: "123",
|
||||
input: []byte("test input"),
|
||||
processErr: handlers.ErrStorage,
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
name: "Output Error",
|
||||
sessionID: "123",
|
||||
input: []byte("test input"),
|
||||
outputErr: errors.New("output error"),
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Reset Error",
|
||||
sessionID: "123",
|
||||
input: []byte("test input"),
|
||||
resetErr: errors.New("reset error"),
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockRequestParser := &httpmocks.MockRequestParser{
|
||||
GetSessionIdFunc: func(any) (string, error) {
|
||||
return tt.sessionID, tt.parserErr
|
||||
},
|
||||
GetInputFunc: func(any) ([]byte, error) {
|
||||
return tt.input, nil
|
||||
},
|
||||
}
|
||||
|
||||
mockRequestHandler := &httpmocks.MockRequestHandler{
|
||||
ProcessFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) {
|
||||
return rs, tt.processErr
|
||||
},
|
||||
OutputFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) {
|
||||
return rs, tt.outputErr
|
||||
},
|
||||
ResetFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) {
|
||||
return rs, tt.resetErr
|
||||
},
|
||||
GetRequestParserFunc: func() handlers.RequestParser {
|
||||
return mockRequestParser
|
||||
},
|
||||
GetConfigFunc: func() engine.Config {
|
||||
return engine.Config{}
|
||||
},
|
||||
}
|
||||
|
||||
sessionHandler := ToSessionHandler(mockRequestHandler)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBuffer(tt.input))
|
||||
req.Header.Set("X-Vise-Session", tt.sessionID)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
sessionHandler.ServeHTTP(rr, req)
|
||||
|
||||
if status := rr.Code; status != tt.expectedStatus {
|
||||
t.Errorf("handler returned wrong status code: got %v want %v",
|
||||
status, tt.expectedStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionHandler_writeError(t *testing.T) {
|
||||
handler := &SessionHandler{}
|
||||
mockWriter := &httpmocks.MockWriter{}
|
||||
err := errors.New("test error")
|
||||
|
||||
handler.writeError(mockWriter, http.StatusBadRequest, err)
|
||||
|
||||
if mockWriter.WrittenString != "" {
|
||||
t.Errorf("Expected empty body, got %s", mockWriter.WrittenString)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultRequestParser_GetSessionId(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
request any
|
||||
expectedID string
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
name: "Valid Session ID",
|
||||
request: func() *http.Request {
|
||||
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
req.Header.Set("X-Vise-Session", "123456")
|
||||
return req
|
||||
}(),
|
||||
expectedID: "123456",
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "Missing Session ID",
|
||||
request: httptest.NewRequest(http.MethodPost, "/", nil),
|
||||
expectedID: "",
|
||||
expectedError: handlers.ErrSessionMissing,
|
||||
},
|
||||
{
|
||||
name: "Invalid Request Type",
|
||||
request: invalidRequestType{},
|
||||
expectedID: "",
|
||||
expectedError: handlers.ErrInvalidRequest,
|
||||
},
|
||||
}
|
||||
|
||||
parser := &DefaultRequestParser{}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
id, err := parser.GetSessionId(tt.request)
|
||||
|
||||
if id != tt.expectedID {
|
||||
t.Errorf("Expected session ID %s, got %s", tt.expectedID, id)
|
||||
}
|
||||
|
||||
if err != tt.expectedError {
|
||||
t.Errorf("Expected error %v, got %v", tt.expectedError, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultRequestParser_GetInput(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
request any
|
||||
expectedInput []byte
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
name: "Valid Input",
|
||||
request: func() *http.Request {
|
||||
return httptest.NewRequest(http.MethodPost, "/", bytes.NewBufferString("test input"))
|
||||
}(),
|
||||
expectedInput: []byte("test input"),
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "Empty Input",
|
||||
request: httptest.NewRequest(http.MethodPost, "/", nil),
|
||||
expectedInput: []byte{},
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "Invalid Request Type",
|
||||
request: invalidRequestType{},
|
||||
expectedInput: nil,
|
||||
expectedError: handlers.ErrInvalidRequest,
|
||||
},
|
||||
{
|
||||
name: "Read Error",
|
||||
request: func() *http.Request {
|
||||
return httptest.NewRequest(http.MethodPost, "/", &errorReader{})
|
||||
}(),
|
||||
expectedInput: nil,
|
||||
expectedError: errors.New("read error"),
|
||||
},
|
||||
}
|
||||
|
||||
parser := &DefaultRequestParser{}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
input, err := parser.GetInput(tt.request)
|
||||
|
||||
if !bytes.Equal(input, tt.expectedInput) {
|
||||
t.Errorf("Expected input %s, got %s", tt.expectedInput, input)
|
||||
}
|
||||
|
||||
if err != tt.expectedError && (err == nil || err.Error() != tt.expectedError.Error()) {
|
||||
t.Errorf("Expected error %v, got %v", tt.expectedError, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
package httpmocks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
)
|
||||
|
||||
// MockEngine implements the engine.Engine interface for testing
|
||||
type MockEngine struct {
|
||||
InitFunc func(context.Context) (bool, error)
|
||||
ExecFunc func(context.Context, []byte) (bool, error)
|
||||
FlushFunc func(context.Context, io.Writer) (int, error)
|
||||
FinishFunc func() error
|
||||
}
|
||||
|
||||
func (m *MockEngine) Init(ctx context.Context) (bool, error) {
|
||||
return m.InitFunc(ctx)
|
||||
}
|
||||
|
||||
func (m *MockEngine) Exec(ctx context.Context, input []byte) (bool, error) {
|
||||
return m.ExecFunc(ctx, input)
|
||||
}
|
||||
|
||||
func (m *MockEngine) Flush(ctx context.Context, w io.Writer) (int, error) {
|
||||
return m.FlushFunc(ctx, w)
|
||||
}
|
||||
|
||||
func (m *MockEngine) Finish() error {
|
||||
return m.FinishFunc()
|
||||
}
|
||||
@@ -1,47 +0,0 @@
|
||||
package httpmocks
|
||||
|
||||
import (
|
||||
"git.defalsify.org/vise.git/engine"
|
||||
"git.defalsify.org/vise.git/persist"
|
||||
"git.defalsify.org/vise.git/resource"
|
||||
"git.grassecon.net/urdt/ussd/internal/handlers"
|
||||
)
|
||||
|
||||
// MockRequestHandler implements handlers.RequestHandler interface for testing
|
||||
type MockRequestHandler struct {
|
||||
ProcessFunc func(handlers.RequestSession) (handlers.RequestSession, error)
|
||||
GetConfigFunc func() engine.Config
|
||||
GetEngineFunc func(cfg engine.Config, rs resource.Resource, pe *persist.Persister) engine.Engine
|
||||
OutputFunc func(rs handlers.RequestSession) (handlers.RequestSession, error)
|
||||
ResetFunc func(rs handlers.RequestSession) (handlers.RequestSession, error)
|
||||
ShutdownFunc func()
|
||||
GetRequestParserFunc func() handlers.RequestParser
|
||||
}
|
||||
|
||||
func (m *MockRequestHandler) Process(rqs handlers.RequestSession) (handlers.RequestSession, error) {
|
||||
return m.ProcessFunc(rqs)
|
||||
}
|
||||
|
||||
func (m *MockRequestHandler) GetConfig() engine.Config {
|
||||
return m.GetConfigFunc()
|
||||
}
|
||||
|
||||
func (m *MockRequestHandler) GetEngine(cfg engine.Config, rs resource.Resource, pe *persist.Persister) engine.Engine {
|
||||
return m.GetEngineFunc(cfg, rs, pe)
|
||||
}
|
||||
|
||||
func (m *MockRequestHandler) Output(rs handlers.RequestSession) (handlers.RequestSession, error) {
|
||||
return m.OutputFunc(rs)
|
||||
}
|
||||
|
||||
func (m *MockRequestHandler) Reset(rs handlers.RequestSession) (handlers.RequestSession, error) {
|
||||
return m.ResetFunc(rs)
|
||||
}
|
||||
|
||||
func (m *MockRequestHandler) Shutdown() {
|
||||
m.ShutdownFunc()
|
||||
}
|
||||
|
||||
func (m *MockRequestHandler) GetRequestParser() handlers.RequestParser {
|
||||
return m.GetRequestParserFunc()
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
package httpmocks
|
||||
|
||||
// MockRequestParser implements the handlers.RequestParser interface for testing
|
||||
type MockRequestParser struct {
|
||||
GetSessionIdFunc func(any) (string, error)
|
||||
GetInputFunc func(any) ([]byte, error)
|
||||
}
|
||||
|
||||
func (m *MockRequestParser) GetSessionId(rq any) (string, error) {
|
||||
return m.GetSessionIdFunc(rq)
|
||||
}
|
||||
|
||||
func (m *MockRequestParser) GetInput(rq any) ([]byte, error) {
|
||||
return m.GetInputFunc(rq)
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
package httpmocks
|
||||
|
||||
import "net/http"
|
||||
|
||||
// MockWriter implements a mock io.Writer for testing
|
||||
type MockWriter struct {
|
||||
WriteStringCalled bool
|
||||
WrittenString string
|
||||
}
|
||||
|
||||
func (m *MockWriter) Write(p []byte) (n int, err error) {
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (m *MockWriter) WriteString(s string) (n int, err error) {
|
||||
m.WriteStringCalled = true
|
||||
m.WrittenString = s
|
||||
return len(s), nil
|
||||
}
|
||||
|
||||
func (m *MockWriter) Header() http.Header {
|
||||
return http.Header{}
|
||||
}
|
||||
|
||||
func (m *MockWriter) WriteHeader(statusCode int) {}
|
||||
@@ -1,116 +0,0 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"git.defalsify.org/vise.git/db"
|
||||
"git.defalsify.org/vise.git/lang"
|
||||
gdbmdb "git.defalsify.org/vise.git/db/gdbm"
|
||||
)
|
||||
|
||||
var (
|
||||
dbC map[string]chan db.Db
|
||||
)
|
||||
|
||||
type ThreadGdbmDb struct {
|
||||
db db.Db
|
||||
registered bool
|
||||
connStr string
|
||||
}
|
||||
|
||||
func NewThreadGdbmDb() *ThreadGdbmDb {
|
||||
if dbC == nil {
|
||||
dbC = make(map[string]chan db.Db)
|
||||
}
|
||||
return &ThreadGdbmDb{}
|
||||
}
|
||||
|
||||
func(tdb *ThreadGdbmDb) Connect(ctx context.Context, connStr string) error {
|
||||
var ok bool
|
||||
_, ok = dbC[connStr]
|
||||
if ok {
|
||||
logg.WarnCtxf(ctx, "already registered thread gdbm, skipping", "connStr", connStr)
|
||||
}
|
||||
gdb := gdbmdb.NewGdbmDb()
|
||||
err := gdb.Connect(ctx, connStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dbC[connStr] = make(chan db.Db, 1)
|
||||
dbC[connStr]<- gdb
|
||||
tdb.connStr = connStr
|
||||
tdb.registered = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func(tdb *ThreadGdbmDb) reserve() {
|
||||
if tdb.db == nil {
|
||||
tdb.db = <-dbC[tdb.connStr]
|
||||
}
|
||||
}
|
||||
|
||||
func(tdb *ThreadGdbmDb) release() {
|
||||
if tdb.db == nil {
|
||||
return
|
||||
}
|
||||
dbC[tdb.connStr] <- tdb.db
|
||||
tdb.db = nil
|
||||
}
|
||||
|
||||
func(tdb *ThreadGdbmDb) SetPrefix(pfx uint8) {
|
||||
tdb.reserve()
|
||||
tdb.db.SetPrefix(pfx)
|
||||
}
|
||||
|
||||
func(tdb *ThreadGdbmDb) SetSession(sessionId string) {
|
||||
tdb.reserve()
|
||||
tdb.db.SetSession(sessionId)
|
||||
}
|
||||
|
||||
func(tdb *ThreadGdbmDb) SetLanguage(lng *lang.Language) {
|
||||
tdb.reserve()
|
||||
tdb.db.SetLanguage(lng)
|
||||
}
|
||||
|
||||
func(tdb *ThreadGdbmDb) Safe() bool {
|
||||
tdb.reserve()
|
||||
v := tdb.db.Safe()
|
||||
tdb.release()
|
||||
return v
|
||||
}
|
||||
|
||||
func(tdb *ThreadGdbmDb) Prefix() uint8 {
|
||||
tdb.reserve()
|
||||
v := tdb.db.Prefix()
|
||||
tdb.release()
|
||||
return v
|
||||
}
|
||||
|
||||
func(tdb *ThreadGdbmDb) SetLock(typ uint8, locked bool) error {
|
||||
tdb.reserve()
|
||||
err := tdb.db.SetLock(typ, locked)
|
||||
tdb.release()
|
||||
return err
|
||||
}
|
||||
|
||||
func(tdb *ThreadGdbmDb) Put(ctx context.Context, key []byte, val []byte) error {
|
||||
tdb.reserve()
|
||||
err := tdb.db.Put(ctx, key, val)
|
||||
tdb.release()
|
||||
return err
|
||||
}
|
||||
|
||||
func(tdb *ThreadGdbmDb) Get(ctx context.Context, key []byte) ([]byte, error) {
|
||||
tdb.reserve()
|
||||
v, err := tdb.db.Get(ctx, key)
|
||||
tdb.release()
|
||||
return v, err
|
||||
}
|
||||
|
||||
func(tdb *ThreadGdbmDb) Close() error {
|
||||
tdb.reserve()
|
||||
close(dbC[tdb.connStr])
|
||||
err := tdb.db.Close()
|
||||
tdb.db = nil
|
||||
return err
|
||||
}
|
||||
@@ -5,10 +5,6 @@ import (
|
||||
"git.defalsify.org/vise.git/persist"
|
||||
)
|
||||
|
||||
const (
|
||||
DATATYPE_CUSTOM = 128
|
||||
)
|
||||
|
||||
type Storage struct {
|
||||
Persister *persist.Persister
|
||||
UserdataDb db.Db
|
||||
|
||||
@@ -1,103 +0,0 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
|
||||
"git.defalsify.org/vise.git/db"
|
||||
fsdb "git.defalsify.org/vise.git/db/fs"
|
||||
"git.defalsify.org/vise.git/persist"
|
||||
"git.defalsify.org/vise.git/resource"
|
||||
"git.defalsify.org/vise.git/logging"
|
||||
)
|
||||
|
||||
var (
|
||||
logg = logging.NewVanilla().WithDomain("storage")
|
||||
)
|
||||
|
||||
type StorageService interface {
|
||||
GetPersister(ctx context.Context) (*persist.Persister, error)
|
||||
GetUserdataDb(ctx context.Context) db.Db
|
||||
GetResource(ctx context.Context) (resource.Resource, error)
|
||||
EnsureDbDir() error
|
||||
}
|
||||
|
||||
type MenuStorageService struct{
|
||||
dbDir string
|
||||
resourceDir string
|
||||
resourceStore db.Db
|
||||
stateStore db.Db
|
||||
userDataStore db.Db
|
||||
}
|
||||
|
||||
func NewMenuStorageService(dbDir string, resourceDir string) *MenuStorageService {
|
||||
return &MenuStorageService{
|
||||
dbDir: dbDir,
|
||||
resourceDir: resourceDir,
|
||||
}
|
||||
}
|
||||
|
||||
func (ms *MenuStorageService) GetPersister(ctx context.Context) (*persist.Persister, error) {
|
||||
ms.stateStore = NewThreadGdbmDb()
|
||||
storeFile := path.Join(ms.dbDir, "state.gdbm")
|
||||
err := ms.stateStore.Connect(ctx, storeFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pr := persist.NewPersister(ms.stateStore)
|
||||
logg.TraceCtxf(ctx, "menu storage service", "persist", pr, "store", ms.stateStore)
|
||||
return pr, nil
|
||||
}
|
||||
|
||||
func (ms *MenuStorageService) GetUserdataDb(ctx context.Context) (db.Db, error) {
|
||||
ms.userDataStore = NewThreadGdbmDb()
|
||||
storeFile := path.Join(ms.dbDir, "userdata.gdbm")
|
||||
err := ms.userDataStore.Connect(ctx, storeFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ms.userDataStore, nil
|
||||
}
|
||||
|
||||
func (ms *MenuStorageService) GetResource(ctx context.Context) (resource.Resource, error) {
|
||||
ms.resourceStore = fsdb.NewFsDb()
|
||||
err := ms.resourceStore.Connect(ctx, ms.resourceDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rfs := resource.NewDbResource(ms.resourceStore)
|
||||
return rfs, nil
|
||||
}
|
||||
|
||||
func (ms *MenuStorageService) GetStateStore(ctx context.Context) (db.Db, error) {
|
||||
if ms.stateStore != nil {
|
||||
panic("set up store when already exists")
|
||||
}
|
||||
ms.stateStore = NewThreadGdbmDb()
|
||||
storeFile := path.Join(ms.dbDir, "state.gdbm")
|
||||
err := ms.stateStore.Connect(ctx, storeFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ms.stateStore, nil
|
||||
}
|
||||
|
||||
func (ms *MenuStorageService) EnsureDbDir() error {
|
||||
err := os.MkdirAll(ms.dbDir, 0700)
|
||||
if err != nil {
|
||||
return fmt.Errorf("state dir create exited with error: %v\n", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ms *MenuStorageService) Close() error {
|
||||
errA := ms.stateStore.Close()
|
||||
errB := ms.userDataStore.Close()
|
||||
errC := ms.resourceStore.Close()
|
||||
if errA != nil || errB != nil || errC != nil {
|
||||
return fmt.Errorf("%v %v %v", errA, errB, errC)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user