From bd604219b877c761d898adf689cbc3224a93c815 Mon Sep 17 00:00:00 2001 From: lash Date: Sat, 4 Jan 2025 22:27:46 +0000 Subject: [PATCH] WIP Factor out request, errors --- cmd/africastalking/main.go | 163 ------------ cmd/async/main.go | 3 +- cmd/http/main.go | 3 +- common/storage.go | 10 + errors/errors.go | 15 ++ internal/handlers/base.go | 13 +- internal/handlers/single.go | 38 +-- internal/http/at/parse.go | 121 --------- internal/http/at/server.go | 98 -------- internal/http/at/server_test.go | 234 ------------------ internal/http/server.go | 14 +- internal/http/server_test.go | 11 +- internal/storage/storageservice.go | 5 + .../mocks/httpmocks/requesthandlermock.go | 20 +- request/request.go | 70 ++++++ 15 files changed, 134 insertions(+), 684 deletions(-) delete mode 100644 cmd/africastalking/main.go create mode 100644 errors/errors.go delete mode 100644 internal/http/at/parse.go delete mode 100644 internal/http/at/server.go delete mode 100644 internal/http/at/server_test.go create mode 100644 request/request.go diff --git a/cmd/africastalking/main.go b/cmd/africastalking/main.go deleted file mode 100644 index 40c3609..0000000 --- a/cmd/africastalking/main.go +++ /dev/null @@ -1,163 +0,0 @@ -package main - -import ( - "context" - "flag" - "fmt" - "net/http" - "os" - "os/signal" - "path" - "strconv" - "syscall" - - "git.defalsify.org/vise.git/engine" - "git.defalsify.org/vise.git/logging" - "git.defalsify.org/vise.git/resource" - - "git.grassecon.net/urdt/ussd/config" - "git.grassecon.net/urdt/ussd/initializers" - "git.grassecon.net/urdt/ussd/internal/handlers" - "git.grassecon.net/urdt/ussd/internal/http/at" - "git.grassecon.net/urdt/ussd/internal/storage" - "git.grassecon.net/urdt/ussd/remote" -) - -var ( - logg = logging.NewVanilla().WithDomain("AfricasTalking").WithContextKey("at-session-id") - scriptDir = path.Join("services", "registration") - build = "dev" - menuSeparator = ": " -) - -func init() { - initializers.LoadEnvVariables() -} - -func main() { - config.LoadConfig() - - var connStr string - var resourceDir string - var size uint - var database string - var engineDebug bool - var host string - var port uint - var err error - - flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir") - flag.StringVar(&connStr, "c", "", "connection string") - flag.BoolVar(&engineDebug, "d", false, "use engine debug output") - flag.UintVar(&size, "s", 160, "max size of output") - flag.StringVar(&host, "h", initializers.GetEnv("HOST", "127.0.0.1"), "http host") - flag.UintVar(&port, "p", initializers.GetEnvUint("PORT", 7123), "http port") - flag.Parse() - - if connStr != "" { - connStr = config.DbConn - } - connData, err := storage.ToConnData(config.DbConn) - if err != nil { - fmt.Fprintf(os.Stderr, "connstr err: %v", err) - os.Exit(1) - } - - logg.Infof("start command", "build", build, "conn", connData, "resourcedir", resourceDir, "outputsize", size) - - ctx := context.Background() - ctx = context.WithValue(ctx, "Database", database) - pfp := path.Join(scriptDir, "pp.csv") - - cfg := engine.Config{ - Root: "root", - OutputSize: uint32(size), - FlagCount: uint32(128), - MenuSeparator: menuSeparator, - } - - if engineDebug { - cfg.EngineDebug = true - } - - menuStorageService := storage.NewMenuStorageService(connData, resourceDir) - if err != nil { - fmt.Fprintf(os.Stderr, err.Error()) - os.Exit(1) - } - - rs, err := menuStorageService.GetResource(ctx) - if err != nil { - fmt.Fprintf(os.Stderr, err.Error()) - os.Exit(1) - } - - userdataStore, err := menuStorageService.GetUserdataDb(ctx) - if err != nil { - fmt.Fprintf(os.Stderr, err.Error()) - os.Exit(1) - } - defer userdataStore.Close() - - dbResource, ok := rs.(*resource.DbResource) - if !ok { - os.Exit(1) - } - - lhs, err := handlers.NewLocalHandlerService(ctx, pfp, true, dbResource, cfg, rs) - if err != nil { - fmt.Fprintf(os.Stderr, err.Error()) - os.Exit(1) - } - lhs.SetDataStore(&userdataStore) - - if err != nil { - fmt.Fprintf(os.Stderr, err.Error()) - os.Exit(1) - } - - accountService := remote.AccountService{} - hl, err := lhs.GetHandler(&accountService) - if err != nil { - fmt.Fprintf(os.Stderr, err.Error()) - os.Exit(1) - } - - stateStore, err := menuStorageService.GetStateStore(ctx) - if err != nil { - fmt.Fprintf(os.Stderr, err.Error()) - os.Exit(1) - } - defer stateStore.Close() - - rp := &at.ATRequestParser{ - Context: ctx, - } - bsh := handlers.NewBaseSessionHandler(cfg, rs, stateStore, userdataStore, rp, hl) - sh := at.NewATSessionHandler(bsh) - - mux := http.NewServeMux() - mux.Handle(initializers.GetEnv("AT_ENDPOINT", "/"), sh) - - s := &http.Server{ - Addr: fmt.Sprintf("%s:%s", host, strconv.Itoa(int(port))), - Handler: mux, - } - s.RegisterOnShutdown(sh.Shutdown) - - cint := make(chan os.Signal) - cterm := make(chan os.Signal) - signal.Notify(cint, os.Interrupt, syscall.SIGINT) - signal.Notify(cterm, os.Interrupt, syscall.SIGTERM) - go func() { - select { - case _ = <-cint: - case _ = <-cterm: - } - s.Shutdown(ctx) - }() - err = s.ListenAndServe() - if err != nil { - logg.Infof("Server closed with error", "err", err) - } -} diff --git a/cmd/async/main.go b/cmd/async/main.go index b0c7caa..1e06029 100644 --- a/cmd/async/main.go +++ b/cmd/async/main.go @@ -18,6 +18,7 @@ import ( "git.grassecon.net/urdt/ussd/internal/handlers" "git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/remote" + "git.grassecon.net/urdt/ussd/request" ) var ( @@ -138,7 +139,7 @@ func main() { } sh := handlers.NewBaseSessionHandler(cfg, rs, stateStore, userdataStore, rp, hl) cfg.SessionId = sessionId - rqs := handlers.RequestSession{ + rqs := request.RequestSession{ Ctx: ctx, Writer: os.Stdout, Config: cfg, diff --git a/cmd/http/main.go b/cmd/http/main.go index d744afc..1fd9574 100644 --- a/cmd/http/main.go +++ b/cmd/http/main.go @@ -21,6 +21,7 @@ import ( httpserver "git.grassecon.net/urdt/ussd/internal/http" "git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/remote" + "git.grassecon.net/urdt/ussd/request" ) var ( @@ -123,7 +124,7 @@ func main() { rp := &httpserver.DefaultRequestParser{} bsh := handlers.NewBaseSessionHandler(cfg, rs, stateStore, userdataStore, rp, hl) - sh := httpserver.ToSessionHandler(bsh) + sh := request.ToSessionHandler(bsh) s := &http.Server{ Addr: fmt.Sprintf("%s:%s", host, strconv.Itoa(int(port))), Handler: sh, diff --git a/common/storage.go b/common/storage.go index 2960578..7fefcc3 100644 --- a/common/storage.go +++ b/common/storage.go @@ -11,6 +11,10 @@ import ( dbstorage "git.grassecon.net/urdt/ussd/internal/storage/db" ) +var ( + ToConnData = storage.ToConnData +) + func StoreToDb(store *UserDataStore) db.Db { return store.Db } @@ -36,6 +40,12 @@ func NewStorageService(conn storage.ConnData) (*StorageService, error) { return svc, nil } +// TODO: simplify enable poresource, conndata instead +func(ss *StorageService) SetResourceDir(resourceDir string) error { + ss.svc = ss.svc.WithResourceDir(resourceDir) + return nil +} + func(ss *StorageService) GetPersister(ctx context.Context) (*persist.Persister, error) { return ss.svc.GetPersister(ctx) } diff --git a/errors/errors.go b/errors/errors.go new file mode 100644 index 0000000..7dba1de --- /dev/null +++ b/errors/errors.go @@ -0,0 +1,15 @@ +package common + +import ( + "git.grassecon.net/urdt/ussd/internal/handlers" +) + +var ( + ErrInvalidRequest = handlers.ErrInvalidRequest + ErrSessionMissing = handlers.ErrSessionMissing + ErrInvalidInput = handlers.ErrInvalidInput + ErrStorage = handlers.ErrStorage + ErrEngineType = handlers.ErrEngineType + ErrEngineInit = handlers.ErrEngineInit + ErrEngineExec = handlers.ErrEngineExec +) diff --git a/internal/handlers/base.go b/internal/handlers/base.go index 755cca4..ed3d63d 100644 --- a/internal/handlers/base.go +++ b/internal/handlers/base.go @@ -6,19 +6,20 @@ import ( "git.defalsify.org/vise.git/persist" "git.defalsify.org/vise.git/resource" + "git.grassecon.net/urdt/ussd/request" "git.grassecon.net/urdt/ussd/internal/handlers/ussd" "git.grassecon.net/urdt/ussd/internal/storage" ) type BaseSessionHandler struct { cfgTemplate engine.Config - rp RequestParser + rp request.RequestParser rs resource.Resource hn *ussd.Handlers provider storage.StorageProvider } -func NewBaseSessionHandler(cfg engine.Config, rs resource.Resource, stateDb db.Db, userdataDb db.Db, rp RequestParser, hn *ussd.Handlers) *BaseSessionHandler { +func NewBaseSessionHandler(cfg engine.Config, rs resource.Resource, stateDb db.Db, userdataDb db.Db, rp request.RequestParser, hn *ussd.Handlers) *BaseSessionHandler { return &BaseSessionHandler{ cfgTemplate: cfg, rs: rs, @@ -41,7 +42,7 @@ func(f *BaseSessionHandler) GetEngine(cfg engine.Config, rs resource.Resource, p return en } -func(f *BaseSessionHandler) Process(rqs RequestSession) (RequestSession, error) { +func(f *BaseSessionHandler) Process(rqs request.RequestSession) (request.RequestSession, error) { var r bool var err error var ok bool @@ -88,13 +89,13 @@ func(f *BaseSessionHandler) Process(rqs RequestSession) (RequestSession, error) return rqs, nil } -func(f *BaseSessionHandler) Output(rqs RequestSession) (RequestSession, error) { +func(f *BaseSessionHandler) Output(rqs request.RequestSession) (request.RequestSession, error) { var err error _, err = rqs.Engine.Flush(rqs.Ctx, rqs.Writer) return rqs, err } -func(f *BaseSessionHandler) Reset(rqs RequestSession) (RequestSession, error) { +func(f *BaseSessionHandler) Reset(rqs request.RequestSession) (request.RequestSession, error) { defer f.provider.Put(rqs.Config.SessionId, rqs.Storage) return rqs, rqs.Engine.Finish() } @@ -103,6 +104,6 @@ func(f *BaseSessionHandler) GetConfig() engine.Config { return f.cfgTemplate } -func(f *BaseSessionHandler) GetRequestParser() RequestParser { +func(f *BaseSessionHandler) GetRequestParser() request.RequestParser { return f.rp } diff --git a/internal/handlers/single.go b/internal/handlers/single.go index 6929617..19079dd 100644 --- a/internal/handlers/single.go +++ b/internal/handlers/single.go @@ -1,20 +1,9 @@ package handlers import ( - "context" "errors" - "io" - "git.defalsify.org/vise.git/engine" - "git.defalsify.org/vise.git/resource" - "git.defalsify.org/vise.git/persist" "git.defalsify.org/vise.git/logging" - - "git.grassecon.net/urdt/ussd/internal/storage" -) - -var ( - logg = logging.NewVanilla().WithDomain("handlers") ) var ( @@ -27,28 +16,7 @@ var ( ErrEngineExec = errors.New("engine exec fail") ) -type RequestSession struct { - Ctx context.Context - Config engine.Config - Engine engine.Engine - Input []byte - Storage *storage.Storage - Writer io.Writer - Continue bool -} +var ( + logg = logging.NewVanilla().WithDomain("handlers") +) -// TODO: seems like can remove this. -type RequestParser interface { - GetSessionId(rq any) (string, error) - GetInput(rq any) ([]byte, error) -} - -type RequestHandler interface { - GetConfig() engine.Config - GetRequestParser() RequestParser - GetEngine(cfg engine.Config, rs resource.Resource, pe *persist.Persister) engine.Engine - Process(rs RequestSession) (RequestSession, error) - Output(rs RequestSession) (RequestSession, error) - Reset(rs RequestSession) (RequestSession, error) - Shutdown() -} diff --git a/internal/http/at/parse.go b/internal/http/at/parse.go deleted file mode 100644 index d2696ed..0000000 --- a/internal/http/at/parse.go +++ /dev/null @@ -1,121 +0,0 @@ -package at - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "net/url" - "strings" - - "git.grassecon.net/urdt/ussd/common" - "git.grassecon.net/urdt/ussd/internal/handlers" -) - -type ATRequestParser struct { - Context context.Context -} - -func (arp *ATRequestParser) GetSessionId(rq any) (string, error) { - rqv, ok := rq.(*http.Request) - if !ok { - logg.Warnf("got an invalid request", "req", rq) - return "", handlers.ErrInvalidRequest - } - - // Capture body (if any) for logging - body, err := io.ReadAll(rqv.Body) - if err != nil { - logg.Warnf("failed to read request body", "err", err) - return "", fmt.Errorf("failed to read request body: %v", err) - } - // Reset the body for further reading - rqv.Body = io.NopCloser(bytes.NewReader(body)) - - // Log the body as JSON - bodyLog := map[string]string{"body": string(body)} - logBytes, err := json.Marshal(bodyLog) - if err != nil { - logg.Warnf("failed to marshal request body", "err", err) - } else { - decodedStr := string(logBytes) - sessionId, err := extractATSessionId(decodedStr) - if err != nil { - context.WithValue(arp.Context, "at-session-id", sessionId) - } - logg.Debugf("Received request:", decodedStr) - } - - if err := rqv.ParseForm(); err != nil { - logg.Warnf("failed to parse form data", "err", err) - return "", fmt.Errorf("failed to parse form data: %v", err) - } - - phoneNumber := rqv.FormValue("phoneNumber") - if phoneNumber == "" { - return "", fmt.Errorf("no phone number found") - } - - formattedNumber, err := common.FormatPhoneNumber(phoneNumber) - if err != nil { - logg.Warnf("failed to format phone number", "err", err) - return "", fmt.Errorf("failed to format number") - } - - return formattedNumber, nil -} - -func (arp *ATRequestParser) GetInput(rq any) ([]byte, error) { - rqv, ok := rq.(*http.Request) - if !ok { - return nil, handlers.ErrInvalidRequest - } - if err := rqv.ParseForm(); err != nil { - return nil, fmt.Errorf("failed to parse form data: %v", err) - } - - text := rqv.FormValue("text") - - parts := strings.Split(text, "*") - if len(parts) == 0 { - return nil, fmt.Errorf("no input found") - } - - return []byte(parts[len(parts)-1]), nil -} - -func parseQueryParams(query string) map[string]string { - params := make(map[string]string) - - queryParams := strings.Split(query, "&") - for _, param := range queryParams { - // Split each key-value pair by '=' - parts := strings.SplitN(param, "=", 2) - if len(parts) == 2 { - params[parts[0]] = parts[1] - } - } - return params -} - -func extractATSessionId(decodedStr string) (string, error) { - var data map[string]string - err := json.Unmarshal([]byte(decodedStr), &data) - - if err != nil { - logg.Errorf("Error unmarshalling JSON: %v", err) - return "", nil - } - decodedBody, err := url.QueryUnescape(data["body"]) - if err != nil { - logg.Errorf("Error URL-decoding body: %v", err) - return "", nil - } - params := parseQueryParams(decodedBody) - - sessionId := params["sessionId"] - return sessionId, nil - -} diff --git a/internal/http/at/server.go b/internal/http/at/server.go deleted file mode 100644 index 705ff76..0000000 --- a/internal/http/at/server.go +++ /dev/null @@ -1,98 +0,0 @@ -package at - -import ( - "io" - "net/http" - - "git.defalsify.org/vise.git/logging" - "git.grassecon.net/urdt/ussd/internal/handlers" - httpserver "git.grassecon.net/urdt/ussd/internal/http" -) - -var ( - logg = logging.NewVanilla().WithDomain("atserver") -) - -type ATSessionHandler struct { - *httpserver.SessionHandler -} - -func NewATSessionHandler(h handlers.RequestHandler) *ATSessionHandler { - return &ATSessionHandler{ - SessionHandler: httpserver.ToSessionHandler(h), - } -} - -func (ash *ATSessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { - var code int - var err error - - rqs := handlers.RequestSession{ - Ctx: req.Context(), - Writer: w, - } - - rp := ash.GetRequestParser() - cfg := ash.GetConfig() - cfg.SessionId, err = rp.GetSessionId(req) - if err != nil { - logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err) - ash.WriteError(w, 400, err) - return - } - rqs.Config = cfg - rqs.Input, err = rp.GetInput(req) - if err != nil { - logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err) - ash.WriteError(w, 400, err) - return - } - - rqs, err = ash.Process(rqs) - switch err { - case nil: // set code to 200 if no err - code = 200 - case handlers.ErrStorage, handlers.ErrEngineInit, handlers.ErrEngineExec, handlers.ErrEngineType: - code = 500 - default: - code = 500 - } - - if code != 200 { - ash.WriteError(w, 500, err) - return - } - - w.WriteHeader(200) - w.Header().Set("Content-Type", "text/plain") - rqs, err = ash.Output(rqs) - if err != nil { - ash.WriteError(w, 500, err) - return - } - - rqs, err = ash.Reset(rqs) - if err != nil { - ash.WriteError(w, 500, err) - return - } -} - -func (ash *ATSessionHandler) Output(rqs handlers.RequestSession) (handlers.RequestSession, error) { - var err error - var prefix string - - if rqs.Continue { - prefix = "CON " - } else { - prefix = "END " - } - - _, err = io.WriteString(rqs.Writer, prefix) - if err != nil { - return rqs, err - } - - _, err = rqs.Engine.Flush(rqs.Ctx, rqs.Writer) - return rqs, err -} diff --git a/internal/http/at/server_test.go b/internal/http/at/server_test.go deleted file mode 100644 index dd45c25..0000000 --- a/internal/http/at/server_test.go +++ /dev/null @@ -1,234 +0,0 @@ -package at - -import ( - "context" - "errors" - "io" - "net/http" - "net/http/httptest" - "net/url" - "strings" - "testing" - - "git.defalsify.org/vise.git/engine" - "git.grassecon.net/urdt/ussd/internal/handlers" - "git.grassecon.net/urdt/ussd/internal/testutil/mocks/httpmocks" -) - -func TestNewATSessionHandler(t *testing.T) { - mockHandler := &httpmocks.MockRequestHandler{} - ash := NewATSessionHandler(mockHandler) - - if ash == nil { - t.Fatal("NewATSessionHandler returned nil") - } - - if ash.SessionHandler == nil { - t.Fatal("SessionHandler is nil") - } -} - -func TestATSessionHandler_ServeHTTP(t *testing.T) { - tests := []struct { - name string - setupMocks func(*httpmocks.MockRequestHandler, *httpmocks.MockRequestParser, *httpmocks.MockEngine) - formData url.Values - expectedStatus int - expectedBody string - }{ - { - name: "Successful request", - setupMocks: func(mh *httpmocks.MockRequestHandler, mrp *httpmocks.MockRequestParser, me *httpmocks.MockEngine) { - mrp.GetSessionIdFunc = func(rq any) (string, error) { - req := rq.(*http.Request) - return req.FormValue("phoneNumber"), nil - } - mrp.GetInputFunc = func(rq any) ([]byte, error) { - req := rq.(*http.Request) - text := req.FormValue("text") - parts := strings.Split(text, "*") - return []byte(parts[len(parts)-1]), nil - } - mh.ProcessFunc = func(rqs handlers.RequestSession) (handlers.RequestSession, error) { - rqs.Continue = true - rqs.Engine = me - return rqs, nil - } - mh.GetConfigFunc = func() engine.Config { return engine.Config{} } - mh.GetRequestParserFunc = func() handlers.RequestParser { return mrp } - mh.OutputFunc = func(rs handlers.RequestSession) (handlers.RequestSession, error) { return rs, nil } - mh.ResetFunc = func(rs handlers.RequestSession) (handlers.RequestSession, error) { return rs, nil } - me.FlushFunc = func(context.Context, io.Writer) (int, error) { return 0, nil } - }, - formData: url.Values{ - "phoneNumber": []string{"+1234567890"}, - "text": []string{"1*2*3"}, - }, - expectedStatus: http.StatusOK, - expectedBody: "CON ", - }, - { - name: "GetSessionId error", - setupMocks: func(mh *httpmocks.MockRequestHandler, mrp *httpmocks.MockRequestParser, me *httpmocks.MockEngine) { - mrp.GetSessionIdFunc = func(rq any) (string, error) { - return "", errors.New("no phone number found") - } - mh.GetConfigFunc = func() engine.Config { return engine.Config{} } - mh.GetRequestParserFunc = func() handlers.RequestParser { return mrp } - }, - formData: url.Values{ - "text": []string{"1*2*3"}, - }, - expectedStatus: http.StatusBadRequest, - expectedBody: "", - }, - { - name: "GetInput error", - setupMocks: func(mh *httpmocks.MockRequestHandler, mrp *httpmocks.MockRequestParser, me *httpmocks.MockEngine) { - mrp.GetSessionIdFunc = func(rq any) (string, error) { - req := rq.(*http.Request) - return req.FormValue("phoneNumber"), nil - } - mrp.GetInputFunc = func(rq any) ([]byte, error) { - return nil, errors.New("no input found") - } - mh.GetConfigFunc = func() engine.Config { return engine.Config{} } - mh.GetRequestParserFunc = func() handlers.RequestParser { return mrp } - }, - formData: url.Values{ - "phoneNumber": []string{"+1234567890"}, - }, - expectedStatus: http.StatusBadRequest, - expectedBody: "", - }, - { - name: "Process error", - setupMocks: func(mh *httpmocks.MockRequestHandler, mrp *httpmocks.MockRequestParser, me *httpmocks.MockEngine) { - mrp.GetSessionIdFunc = func(rq any) (string, error) { - req := rq.(*http.Request) - return req.FormValue("phoneNumber"), nil - } - mrp.GetInputFunc = func(rq any) ([]byte, error) { - req := rq.(*http.Request) - text := req.FormValue("text") - parts := strings.Split(text, "*") - return []byte(parts[len(parts)-1]), nil - } - mh.ProcessFunc = func(rqs handlers.RequestSession) (handlers.RequestSession, error) { - return rqs, handlers.ErrStorage - } - mh.GetConfigFunc = func() engine.Config { return engine.Config{} } - mh.GetRequestParserFunc = func() handlers.RequestParser { return mrp } - }, - formData: url.Values{ - "phoneNumber": []string{"+1234567890"}, - "text": []string{"1*2*3"}, - }, - expectedStatus: http.StatusInternalServerError, - expectedBody: "", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockHandler := &httpmocks.MockRequestHandler{} - mockRequestParser := &httpmocks.MockRequestParser{} - mockEngine := &httpmocks.MockEngine{} - tt.setupMocks(mockHandler, mockRequestParser, mockEngine) - - ash := NewATSessionHandler(mockHandler) - - req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(tt.formData.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - w := httptest.NewRecorder() - - ash.ServeHTTP(w, req) - - if w.Code != tt.expectedStatus { - t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code) - } - - if tt.expectedBody != "" && w.Body.String() != tt.expectedBody { - t.Errorf("Expected body %q, got %q", tt.expectedBody, w.Body.String()) - } - }) - } -} - -func TestATSessionHandler_Output(t *testing.T) { - tests := []struct { - name string - input handlers.RequestSession - expectedPrefix string - expectedError bool - }{ - { - name: "Continue true", - input: handlers.RequestSession{ - Continue: true, - Engine: &httpmocks.MockEngine{ - FlushFunc: func(context.Context, io.Writer) (int, error) { - return 0, nil - }, - }, - Writer: &httpmocks.MockWriter{}, - }, - expectedPrefix: "CON ", - expectedError: false, - }, - { - name: "Continue false", - input: handlers.RequestSession{ - Continue: false, - Engine: &httpmocks.MockEngine{ - FlushFunc: func(context.Context, io.Writer) (int, error) { - return 0, nil - }, - }, - Writer: &httpmocks.MockWriter{}, - }, - expectedPrefix: "END ", - expectedError: false, - }, - { - name: "Flush error", - input: handlers.RequestSession{ - Continue: true, - Engine: &httpmocks.MockEngine{ - FlushFunc: func(context.Context, io.Writer) (int, error) { - return 0, errors.New("write error") - }, - }, - Writer: &httpmocks.MockWriter{}, - }, - expectedPrefix: "CON ", - expectedError: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ash := &ATSessionHandler{} - _, err := ash.Output(tt.input) - - if tt.expectedError && err == nil { - t.Error("Expected an error, but got nil") - } - - if !tt.expectedError && err != nil { - t.Errorf("Unexpected error: %v", err) - } - - mw := tt.input.Writer.(*httpmocks.MockWriter) - if !mw.WriteStringCalled { - t.Error("WriteString was not called") - } - - if mw.WrittenString != tt.expectedPrefix { - t.Errorf("Expected prefix %q, got %q", tt.expectedPrefix, mw.WrittenString) - } - }) - } -} - - diff --git a/internal/http/server.go b/internal/http/server.go index 9cadfa3..7c0d378 100644 --- a/internal/http/server.go +++ b/internal/http/server.go @@ -7,22 +7,16 @@ import ( "git.defalsify.org/vise.git/logging" "git.grassecon.net/urdt/ussd/internal/handlers" + "git.grassecon.net/urdt/ussd/request" ) var ( logg = logging.NewVanilla().WithDomain("httpserver") ) -type SessionHandler struct { - handlers.RequestHandler -} - -func ToSessionHandler(h handlers.RequestHandler) *SessionHandler { - return &SessionHandler{ - RequestHandler: h, - } -} +type SessionHandler request.SessionHandler +// TODO: duplicated func (f *SessionHandler) WriteError(w http.ResponseWriter, code int, err error) { s := err.Error() w.Header().Set("Content-Length", strconv.Itoa(len(s))) @@ -39,7 +33,7 @@ func (f *SessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { var err error var perr error - rqs := handlers.RequestSession{ + rqs := request.RequestSession{ Ctx: req.Context(), Writer: w, } diff --git a/internal/http/server_test.go b/internal/http/server_test.go index a46f98e..2a63f9c 100644 --- a/internal/http/server_test.go +++ b/internal/http/server_test.go @@ -10,6 +10,7 @@ import ( "git.defalsify.org/vise.git/engine" "git.grassecon.net/urdt/ussd/internal/handlers" "git.grassecon.net/urdt/ussd/internal/testutil/mocks/httpmocks" + "git.grassecon.net/urdt/ussd/request" ) // invalidRequestType is a custom type to test invalid request scenarios @@ -80,16 +81,16 @@ func TestSessionHandler_ServeHTTP(t *testing.T) { } mockRequestHandler := &httpmocks.MockRequestHandler{ - ProcessFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) { + ProcessFunc: func(rs request.RequestSession) (request.RequestSession, error) { return rs, tt.processErr }, - OutputFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) { + OutputFunc: func(rs request.RequestSession) (request.RequestSession, error) { return rs, tt.outputErr }, - ResetFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) { + ResetFunc: func(rs request.RequestSession) (request.RequestSession, error) { return rs, tt.resetErr }, - GetRequestParserFunc: func() handlers.RequestParser { + GetRequestParserFunc: func() request.RequestParser { return mockRequestParser }, GetConfigFunc: func() engine.Config { @@ -97,7 +98,7 @@ func TestSessionHandler_ServeHTTP(t *testing.T) { }, } - sessionHandler := ToSessionHandler(mockRequestHandler) + sessionHandler := request.ToSessionHandler(mockRequestHandler) req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBuffer(tt.input)) req.Header.Set("X-Vise-Session", tt.sessionID) diff --git a/internal/storage/storageservice.go b/internal/storage/storageservice.go index 83ce051..33cbf5b 100644 --- a/internal/storage/storageservice.go +++ b/internal/storage/storageservice.go @@ -40,6 +40,11 @@ func NewMenuStorageService(conn ConnData, resourceDir string) *MenuStorageServic } } +func (ms *MenuStorageService) WithResourceDir(resourceDir string) *MenuStorageService { + ms.resourceDir = resourceDir + return ms +} + func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.Db, section string) (db.Db, error) { var newDb db.Db var err error diff --git a/internal/testutil/mocks/httpmocks/requesthandlermock.go b/internal/testutil/mocks/httpmocks/requesthandlermock.go index f17abce..e887711 100644 --- a/internal/testutil/mocks/httpmocks/requesthandlermock.go +++ b/internal/testutil/mocks/httpmocks/requesthandlermock.go @@ -4,21 +4,21 @@ import ( "git.defalsify.org/vise.git/engine" "git.defalsify.org/vise.git/persist" "git.defalsify.org/vise.git/resource" - "git.grassecon.net/urdt/ussd/internal/handlers" + "git.grassecon.net/urdt/ussd/request" ) -// MockRequestHandler implements handlers.RequestHandler interface for testing +// MockRequestHandler implements request.RequestHandler interface for testing type MockRequestHandler struct { - ProcessFunc func(handlers.RequestSession) (handlers.RequestSession, error) + ProcessFunc func(request.RequestSession) (request.RequestSession, error) GetConfigFunc func() engine.Config GetEngineFunc func(cfg engine.Config, rs resource.Resource, pe *persist.Persister) engine.Engine - OutputFunc func(rs handlers.RequestSession) (handlers.RequestSession, error) - ResetFunc func(rs handlers.RequestSession) (handlers.RequestSession, error) + OutputFunc func(rs request.RequestSession) (request.RequestSession, error) + ResetFunc func(rs request.RequestSession) (request.RequestSession, error) ShutdownFunc func() - GetRequestParserFunc func() handlers.RequestParser + GetRequestParserFunc func() request.RequestParser } -func (m *MockRequestHandler) Process(rqs handlers.RequestSession) (handlers.RequestSession, error) { +func (m *MockRequestHandler) Process(rqs request.RequestSession) (request.RequestSession, error) { return m.ProcessFunc(rqs) } @@ -30,11 +30,11 @@ func (m *MockRequestHandler) GetEngine(cfg engine.Config, rs resource.Resource, return m.GetEngineFunc(cfg, rs, pe) } -func (m *MockRequestHandler) Output(rs handlers.RequestSession) (handlers.RequestSession, error) { +func (m *MockRequestHandler) Output(rs request.RequestSession) (request.RequestSession, error) { return m.OutputFunc(rs) } -func (m *MockRequestHandler) Reset(rs handlers.RequestSession) (handlers.RequestSession, error) { +func (m *MockRequestHandler) Reset(rs request.RequestSession) (request.RequestSession, error) { return m.ResetFunc(rs) } @@ -42,6 +42,6 @@ func (m *MockRequestHandler) Shutdown() { m.ShutdownFunc() } -func (m *MockRequestHandler) GetRequestParser() handlers.RequestParser { +func (m *MockRequestHandler) GetRequestParser() request.RequestParser { return m.GetRequestParserFunc() } diff --git a/request/request.go b/request/request.go new file mode 100644 index 0000000..dc400c5 --- /dev/null +++ b/request/request.go @@ -0,0 +1,70 @@ +package request + +import ( + "context" + "fmt" + "io" + "net/http" + "strconv" + + "git.defalsify.org/vise.git/resource" + "git.defalsify.org/vise.git/persist" + "git.defalsify.org/vise.git/engine" + "git.defalsify.org/vise.git/logging" + "git.grassecon.net/urdt/ussd/internal/storage" +) + +var ( + logg = logging.NewVanilla().WithDomain("visedriver.request") +) + +type RequestSession struct { + Ctx context.Context + Config engine.Config + Engine engine.Engine + Input []byte + Storage *storage.Storage + Writer io.Writer + Continue bool +} + + +// TODO: seems like can remove this. +type RequestParser interface { + GetSessionId(rq any) (string, error) + GetInput(rq any) ([]byte, error) +} + +type RequestHandler interface { + GetConfig() engine.Config + GetRequestParser() RequestParser + GetEngine(cfg engine.Config, rs resource.Resource, pe *persist.Persister) engine.Engine + Process(rs RequestSession) (RequestSession, error) + Output(rs RequestSession) (RequestSession, error) + Reset(rs RequestSession) (RequestSession, error) + Shutdown() +} +type SessionHandler struct { + RequestHandler +} + +func (f *SessionHandler) WriteError(w http.ResponseWriter, code int, err error) { + s := err.Error() + w.Header().Set("Content-Length", strconv.Itoa(len(s))) + w.WriteHeader(code) + _, err = w.Write([]byte(s)) + if err != nil { + logg.Errorf("error writing error!!", "err", err, "olderr", s) + w.WriteHeader(500) + } +} + +func (f *SessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + f.WriteError(w, 500, fmt.Errorf("not implemented")) +} + +func ToSessionHandler(h RequestHandler) *SessionHandler { + return &SessionHandler{ + RequestHandler: h, + } +}