From d49f866ca4770d18a75f9d02556b6d15d65ed91d Mon Sep 17 00:00:00 2001 From: lash Date: Thu, 12 Sep 2024 04:07:55 +0100 Subject: [PATCH] Factor out methods common to http and async cli --- cmd/async/main.go | 236 ++++++++++++++++++++++++++++++++++++ internal/handlers/base.go | 102 ++++++++++++++++ internal/handlers/single.go | 7 ++ internal/http/server.go | 96 ++------------- 4 files changed, 354 insertions(+), 87 deletions(-) create mode 100644 cmd/async/main.go create mode 100644 internal/handlers/base.go diff --git a/cmd/async/main.go b/cmd/async/main.go new file mode 100644 index 0000000..cd3a926 --- /dev/null +++ b/cmd/async/main.go @@ -0,0 +1,236 @@ +package main + +import ( + "context" + "flag" + "fmt" + "os" + "path" + + "git.defalsify.org/vise.git/asm" + "git.defalsify.org/vise.git/db" + fsdb "git.defalsify.org/vise.git/db/fs" + gdbmdb "git.defalsify.org/vise.git/db/gdbm" + "git.defalsify.org/vise.git/engine" + "git.defalsify.org/vise.git/resource" + "git.defalsify.org/vise.git/logging" + + "git.grassecon.net/urdt/ussd/internal/handlers/ussd" + "git.grassecon.net/urdt/ussd/internal/handlers" +) + +var ( + logg = logging.NewVanilla() + scriptDir = path.Join("services", "registration") +) + +type asyncRequestParser struct { + sessionId string + input []byte +} + +func(p *asyncRequestParser) GetSessionId(r any) (string, error) { + return p.sessionId, nil +} + +func(p *asyncRequestParser) GetInput(r any) ([]byte, error) { + return p.input, nil +} + +func getFlags(fp string, debug bool) (*asm.FlagParser, error) { + flagParser := asm.NewFlagParser().WithDebug() + _, err := flagParser.Load(fp) + if err != nil { + return nil, err + } + return flagParser, nil +} + +func getHandler(appFlags *asm.FlagParser, rs *resource.DbResource, userdataStore db.Db) (*ussd.Handlers, error) { + + ussdHandlers, err := ussd.NewHandlers(appFlags, userdataStore) + if err != nil { + return nil, err + } + rs.AddLocalFunc("select_language", ussdHandlers.SetLanguage) + rs.AddLocalFunc("create_account", ussdHandlers.CreateAccount) + rs.AddLocalFunc("save_pin", ussdHandlers.SavePin) + rs.AddLocalFunc("verify_pin", ussdHandlers.VerifyPin) + rs.AddLocalFunc("check_identifier", ussdHandlers.CheckIdentifier) + rs.AddLocalFunc("check_account_status", ussdHandlers.CheckAccountStatus) + rs.AddLocalFunc("authorize_account", ussdHandlers.Authorize) + rs.AddLocalFunc("quit", ussdHandlers.Quit) + rs.AddLocalFunc("check_balance", ussdHandlers.CheckBalance) + rs.AddLocalFunc("validate_recipient", ussdHandlers.ValidateRecipient) + rs.AddLocalFunc("transaction_reset", ussdHandlers.TransactionReset) + rs.AddLocalFunc("max_amount", ussdHandlers.MaxAmount) + rs.AddLocalFunc("validate_amount", ussdHandlers.ValidateAmount) + rs.AddLocalFunc("reset_transaction_amount", ussdHandlers.ResetTransactionAmount) + rs.AddLocalFunc("get_recipient", ussdHandlers.GetRecipient) + rs.AddLocalFunc("get_sender", ussdHandlers.GetSender) + rs.AddLocalFunc("get_amount", ussdHandlers.GetAmount) + rs.AddLocalFunc("reset_incorrect", ussdHandlers.ResetIncorrectPin) + rs.AddLocalFunc("save_firstname", ussdHandlers.SaveFirstname) + rs.AddLocalFunc("save_familyname", ussdHandlers.SaveFamilyname) + rs.AddLocalFunc("save_gender", ussdHandlers.SaveGender) + rs.AddLocalFunc("save_location", ussdHandlers.SaveLocation) + rs.AddLocalFunc("save_yob", ussdHandlers.SaveYob) + rs.AddLocalFunc("save_offerings", ussdHandlers.SaveOfferings) + rs.AddLocalFunc("quit_with_balance", ussdHandlers.QuitWithBalance) + rs.AddLocalFunc("reset_account_authorized", ussdHandlers.ResetAccountAuthorized) + rs.AddLocalFunc("reset_allow_update", ussdHandlers.ResetAllowUpdate) + rs.AddLocalFunc("get_profile_info", ussdHandlers.GetProfileInfo) + rs.AddLocalFunc("verify_yob", ussdHandlers.VerifyYob) + rs.AddLocalFunc("reset_incorrect_date_format", ussdHandlers.ResetIncorrectYob) + rs.AddLocalFunc("set_reset_single_edit", ussdHandlers.SetResetSingleEdit) + rs.AddLocalFunc("initiate_transaction", ussdHandlers.InitiateTransaction) + + return ussdHandlers, nil +} + +func ensureDbDir(dbDir string) error { + err := os.MkdirAll(dbDir, 0700) + if err != nil { + return fmt.Errorf("state dir create exited with error: %v\n", err) + } + return nil +} + +func getStateStore(dbDir string, ctx context.Context) (db.Db, error) { + store := gdbmdb.NewGdbmDb() + storeFile := path.Join(dbDir, "state.gdbm") + store.Connect(ctx, storeFile) + return store, nil +} + +func getUserdataDb(dbDir string, ctx context.Context) db.Db { + store := gdbmdb.NewGdbmDb() + storeFile := path.Join(dbDir, "userdata.gdbm") + store.Connect(ctx, storeFile) + + return store +} + +func getResource(resourceDir string, ctx context.Context) (resource.Resource, error) { + store := fsdb.NewFsDb() + err := store.Connect(ctx, resourceDir) + if err != nil { + return nil, err + } + rfs := resource.NewDbResource(store) + return rfs, nil +} + + +func main() { + var sessionId string + var dbDir string + var resourceDir string + var size uint + var engineDebug bool + var stateDebug bool + var host string + var port uint + flag.StringVar(&sessionId, "session-id", "075xx2123", "session id") + flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from") + flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir") + flag.BoolVar(&engineDebug, "engine-debug", false, "use engine debug output") + flag.BoolVar(&stateDebug, "state-debug", false, "use engine debug output") + flag.UintVar(&size, "s", 160, "max size of output") + flag.StringVar(&host, "h", "127.0.0.1", "http host") + flag.UintVar(&port, "p", 7123, "http port") + flag.Parse() + + logg.Infof("start command", "dbdir", dbDir, "resourcedir", resourceDir, "outputsize", size, "sessionId", sessionId) + + ctx := context.Background() + pfp := path.Join(scriptDir, "pp.csv") + flagParser, err := getFlags(pfp, true) + + if err != nil { + os.Exit(1) + } + + cfg := engine.Config{ + Root: "root", + OutputSize: uint32(size), + FlagCount: uint32(16), + } + if stateDebug { + cfg.StateDebug = true + } + if engineDebug { + cfg.EngineDebug = true + } + + rs, err := getResource(resourceDir, ctx) + if err != nil { + fmt.Fprintf(os.Stderr, err.Error()) + os.Exit(1) + } + + err = ensureDbDir(dbDir) + if err != nil { + fmt.Fprintf(os.Stderr, err.Error()) + os.Exit(1) + } + + userdataStore := getUserdataDb(dbDir, ctx) + if err != nil { + fmt.Fprintf(os.Stderr, err.Error()) + os.Exit(1) + } + defer userdataStore.Close() + + dbResource, ok := rs.(*resource.DbResource) + if !ok { + os.Exit(1) + } + + hl, err := getHandler(flagParser, dbResource, userdataStore) + if err != nil { + fmt.Fprintf(os.Stderr, err.Error()) + os.Exit(1) + } + + stateStore, err := getStateStore(dbDir, ctx) + if err != nil { + fmt.Fprintf(os.Stderr, err.Error()) + os.Exit(1) + } + defer stateStore.Close() + + rp := &asyncRequestParser{ + sessionId: sessionId, + } + sh := handlers.NewBaseSessionHandler(cfg, rs, stateStore, userdataStore, rp, hl) + cfg.SessionId = sessionId + rqs := handlers.RequestSession{ + Ctx: ctx, + Writer: os.Stdout, + Config: cfg, + } + for true { + rqs, err = sh.Process(rqs) + if err != nil { + fmt.Errorf("error in process: %v", err) + os.Exit(1) + } + rqs, err = sh.Output(rqs) + if err != nil { + fmt.Errorf("error in output: %v", err) + os.Exit(1) + } + rqs, err = sh.Reset(rqs) + if err != nil { + fmt.Errorf("error in reset: %v", err) + os.Exit(1) + } + fmt.Println("") + _, err = fmt.Scanln(&rqs.Input) + if err != nil { + fmt.Errorf("error in input: %v", err) + os.Exit(1) + } + } +} diff --git a/internal/handlers/base.go b/internal/handlers/base.go new file mode 100644 index 0000000..fba62c9 --- /dev/null +++ b/internal/handlers/base.go @@ -0,0 +1,102 @@ +package handlers + +import ( + "git.defalsify.org/vise.git/engine" + "git.defalsify.org/vise.git/resource" + "git.defalsify.org/vise.git/persist" + "git.defalsify.org/vise.git/db" + + "git.grassecon.net/urdt/ussd/internal/storage" + "git.grassecon.net/urdt/ussd/internal/handlers/ussd" +) + +type BaseSessionHandler struct { + cfgTemplate engine.Config + rp RequestParser + rs resource.Resource + hn *ussd.Handlers + provider storage.StorageProvider +} + +func NewBaseSessionHandler(cfg engine.Config, rs resource.Resource, stateDb db.Db, userdataDb db.Db, rp RequestParser, hn *ussd.Handlers) *BaseSessionHandler { + return &BaseSessionHandler{ + cfgTemplate: cfg, + rs: rs, + hn: hn, + rp: rp, + provider: storage.NewSimpleStorageProvider(stateDb, userdataDb), + } +} + +func(f* BaseSessionHandler) Shutdown() { + err := f.provider.Close() + if err != nil { + logg.Errorf("handler shutdown error", "err", err) + } +} + +func(f *BaseSessionHandler) GetEngine(cfg engine.Config, rs resource.Resource, pr *persist.Persister) engine.Engine { + en := engine.NewEngine(cfg, rs) + en = en.WithPersister(pr) + return en +} + +func(f *BaseSessionHandler) Process(rqs RequestSession) (RequestSession, error) { + var r bool + var err error + var ok bool + + logg.InfoCtxf(rqs.Ctx, "new request", rqs) + + rqs.Storage, err = f.provider.Get(rqs.Config.SessionId) + if err != nil { + logg.ErrorCtxf(rqs.Ctx, "", "storage error", "err", err) + return rqs, ErrStorage + } + + f.hn = f.hn.WithPersister(rqs.Storage.Persister) + eni := f.GetEngine(rqs.Config, f.rs, rqs.Storage.Persister) + en, ok := eni.(*engine.DefaultEngine) + if !ok { + return rqs, ErrEngineType + } + en = en.WithFirst(f.hn.Init) + if rqs.Config.EngineDebug { + en = en.WithDebug(nil) + } + rqs.Engine = en + + r, err = rqs.Engine.Init(rqs.Ctx) + if err != nil { + return rqs, err + } + + if r && len(rqs.Input) > 0 { + r, err = rqs.Engine.Exec(rqs.Ctx, rqs.Input) + } + if err != nil { + return rqs, err + } + + _ = r + return rqs, nil +} + +func(f *BaseSessionHandler) Output(rqs RequestSession) (RequestSession, error) { + var err error + _, err = rqs.Engine.WriteResult(rqs.Ctx, rqs.Writer) + return rqs, err +} + +func(f *BaseSessionHandler) Reset(rqs RequestSession) (RequestSession, error) { + defer f.provider.Put(rqs.Config.SessionId, rqs.Storage) + return rqs, rqs.Engine.Finish() +} + +func(f *BaseSessionHandler) GetConfig() engine.Config { + return f.cfgTemplate +} + +func(f *BaseSessionHandler) GetRequestParser() RequestParser { + return f.rp +} diff --git a/internal/handlers/single.go b/internal/handlers/single.go index 7b6c9db..40b0594 100644 --- a/internal/handlers/single.go +++ b/internal/handlers/single.go @@ -8,10 +8,15 @@ import ( "git.defalsify.org/vise.git/engine" "git.defalsify.org/vise.git/resource" "git.defalsify.org/vise.git/persist" + "git.defalsify.org/vise.git/logging" "git.grassecon.net/urdt/ussd/internal/storage" ) +var ( + logg = logging.NewVanilla().WithDomain("handlers") +) + var ( ErrInvalidRequest = errors.New("invalid request for context") ErrSessionMissing = errors.New("missing session") @@ -39,6 +44,8 @@ type RequestParser interface { } type RequestHandler interface { + GetConfig() engine.Config + GetRequestParser() RequestParser GetEngine(cfg engine.Config, rs resource.Resource, pe *persist.Persister) engine.Engine Process(rs RequestSession) (RequestSession, error) Output(rs RequestSession) (RequestSession, error) diff --git a/internal/http/server.go b/internal/http/server.go index 8425302..af5413a 100644 --- a/internal/http/server.go +++ b/internal/http/server.go @@ -5,15 +5,9 @@ import ( "net/http" "strconv" - "git.defalsify.org/vise.git/db" - "git.defalsify.org/vise.git/engine" "git.defalsify.org/vise.git/logging" - "git.defalsify.org/vise.git/persist" - "git.defalsify.org/vise.git/resource" "git.grassecon.net/urdt/ussd/internal/handlers" - "git.grassecon.net/urdt/ussd/internal/handlers/ussd" - "git.grassecon.net/urdt/ussd/internal/storage" ) var ( @@ -50,20 +44,12 @@ func(rp *DefaultRequestParser) GetInput(rq any) ([]byte, error) { } type SessionHandler struct { - cfgTemplate engine.Config - rp handlers.RequestParser - rs resource.Resource - hn *ussd.Handlers - provider storage.StorageProvider + handlers.RequestHandler } -func NewSessionHandler(cfg engine.Config, rs resource.Resource, stateDb db.Db, userdataDb db.Db, rp handlers.RequestParser, hn *ussd.Handlers) *SessionHandler { +func ToSessionHandler(h handlers.RequestHandler) *SessionHandler { return &SessionHandler{ - cfgTemplate: cfg, - rs: rs, - hn: hn, - rp: rp, - provider: storage.NewSimpleStorageProvider(stateDb, userdataDb), + RequestHandler: h, } } @@ -79,71 +65,6 @@ func(f *SessionHandler) writeError(w http.ResponseWriter, code int, err error) { return } -func(f* SessionHandler) Shutdown() { - err := f.provider.Close() - if err != nil { - logg.Errorf("handler shutdown error", "err", err) - } -} - -func(f *SessionHandler) GetEngine(cfg engine.Config, rs resource.Resource, pr *persist.Persister) engine.Engine { - en := engine.NewEngine(cfg, rs) - en = en.WithPersister(pr) - return en -} - -func(f *SessionHandler) Process(rqs handlers.RequestSession) (handlers.RequestSession, error) { - var r bool - var err error - var ok bool - - logg.InfoCtxf(rqs.Ctx, "new request", rqs) - - rqs.Storage, err = f.provider.Get(rqs.Config.SessionId) - if err != nil { - logg.ErrorCtxf(rqs.Ctx, "", "storage error", "err", err) - return rqs, handlers.ErrStorage - } - - f.hn = f.hn.WithPersister(rqs.Storage.Persister) - eni := f.GetEngine(rqs.Config, f.rs, rqs.Storage.Persister) - en, ok := eni.(*engine.DefaultEngine) - if !ok { - return rqs, handlers.ErrEngineType - } - en = en.WithFirst(f.hn.Init) - if rqs.Config.EngineDebug { - en = en.WithDebug(nil) - } - rqs.Engine = en - - r, err = rqs.Engine.Init(rqs.Ctx) - if err != nil { - return rqs, err - } - - if r && len(rqs.Input) > 0 { - r, err = rqs.Engine.Exec(rqs.Ctx, rqs.Input) - } - if err != nil { - return rqs, err - } - - _ = r - return rqs, nil -} - -func(f *SessionHandler) Output(rqs handlers.RequestSession) error { - var err error - _, err = rqs.Engine.WriteResult(rqs.Ctx, rqs.Writer) - return err -} - -func(f *SessionHandler) Reset(rqs handlers.RequestSession) error { - defer f.provider.Put(rqs.Config.SessionId, rqs.Storage) - return rqs.Engine.Finish() -} - func(f *SessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { var code int var err error @@ -153,14 +74,15 @@ func(f *SessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { Writer: w, } - cfg := f.cfgTemplate - cfg.SessionId, err = f.rp.GetSessionId(req) + rp := f.GetRequestParser() + cfg := f.GetConfig() + cfg.SessionId, err = rp.GetSessionId(req) if err != nil { logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err) f.writeError(w, 400, err) } rqs.Config = cfg - rqs.Input, err = f.rp.GetInput(req) + rqs.Input, err = rp.GetInput(req) if err != nil { logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err) f.writeError(w, 400, err) @@ -186,13 +108,13 @@ func(f *SessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { w.WriteHeader(200) w.Header().Set("Content-Type", "text/plain") - err = f.Output(rqs) + rqs, err = f.Output(rqs) if err != nil { f.writeError(w, 500, err) return } - err = f.Reset(rqs) + rqs, err = f.Reset(rqs) if err != nil { f.writeError(w, 500, err) return