diff --git a/cmd/africastalking/main.go b/cmd/africastalking/main.go index ca88978..0019239 100644 --- a/cmd/africastalking/main.go +++ b/cmd/africastalking/main.go @@ -1,35 +1,31 @@ package main import ( - "bytes" "context" - "encoding/json" "flag" "fmt" - "io" "net/http" "os" "os/signal" "path" "strconv" - "strings" "syscall" "git.defalsify.org/vise.git/engine" "git.defalsify.org/vise.git/logging" "git.defalsify.org/vise.git/resource" - "git.grassecon.net/urdt/ussd/common" "git.grassecon.net/urdt/ussd/config" "git.grassecon.net/urdt/ussd/initializers" "git.grassecon.net/urdt/ussd/internal/handlers" - httpserver "git.grassecon.net/urdt/ussd/internal/http" + "git.grassecon.net/urdt/ussd/internal/http/at" + httpserver "git.grassecon.net/urdt/ussd/internal/http/at" "git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/remote" ) var ( - logg = logging.NewVanilla() + logg = logging.NewVanilla().WithDomain("AfricasTalking").WithContextKey("at-session-id") scriptDir = path.Join("services", "registration") build = "dev" menuSeparator = ": " @@ -38,72 +34,6 @@ var ( func init() { initializers.LoadEnvVariables() } - -type atRequestParser struct{} - -func (arp *atRequestParser) GetSessionId(rq any) (string, error) { - rqv, ok := rq.(*http.Request) - if !ok { - logg.Warnf("got an invalid request", "req", rq) - return "", handlers.ErrInvalidRequest - } - - // Capture body (if any) for logging - body, err := io.ReadAll(rqv.Body) - if err != nil { - logg.Warnf("failed to read request body", "err", err) - return "", fmt.Errorf("failed to read request body: %v", err) - } - // Reset the body for further reading - rqv.Body = io.NopCloser(bytes.NewReader(body)) - - // Log the body as JSON - bodyLog := map[string]string{"body": string(body)} - logBytes, err := json.Marshal(bodyLog) - if err != nil { - logg.Warnf("failed to marshal request body", "err", err) - } else { - logg.Debugf("received request", "bytes", logBytes) - } - - if err := rqv.ParseForm(); err != nil { - logg.Warnf("failed to parse form data", "err", err) - return "", fmt.Errorf("failed to parse form data: %v", err) - } - - phoneNumber := rqv.FormValue("phoneNumber") - if phoneNumber == "" { - return "", fmt.Errorf("no phone number found") - } - - formattedNumber, err := common.FormatPhoneNumber(phoneNumber) - if err != nil { - logg.Warnf("failed to format phone number", "err", err) - return "", fmt.Errorf("failed to format number") - } - - return formattedNumber, nil -} - -func (arp *atRequestParser) GetInput(rq any) ([]byte, error) { - rqv, ok := rq.(*http.Request) - if !ok { - return nil, handlers.ErrInvalidRequest - } - if err := rqv.ParseForm(); err != nil { - return nil, fmt.Errorf("failed to parse form data: %v", err) - } - - text := rqv.FormValue("text") - - parts := strings.Split(text, "*") - if len(parts) == 0 { - return nil, fmt.Errorf("no input found") - } - - return []byte(parts[len(parts)-1]), nil -} - func main() { config.LoadConfig() @@ -191,7 +121,9 @@ func main() { } defer stateStore.Close() - rp := &atRequestParser{} + rp := &at.ATRequestParser{ + Context: ctx, + } bsh := handlers.NewBaseSessionHandler(cfg, rs, stateStore, userdataStore, rp, hl) sh := httpserver.NewATSessionHandler(bsh) diff --git a/common/storage.go b/common/storage.go index dff4774..d37bce3 100644 --- a/common/storage.go +++ b/common/storage.go @@ -8,14 +8,15 @@ import ( "git.defalsify.org/vise.git/resource" "git.defalsify.org/vise.git/persist" "git.grassecon.net/urdt/ussd/internal/storage" + dbstorage "git.grassecon.net/urdt/ussd/internal/storage/db" ) func StoreToDb(store *UserDataStore) db.Db { return store.Db } -func StoreToPrefixDb(store *UserDataStore, pfx []byte) storage.PrefixDb { - return storage.NewSubPrefixDb(store.Db, pfx) +func StoreToPrefixDb(store *UserDataStore, pfx []byte) dbstorage.PrefixDb { + return dbstorage.NewSubPrefixDb(store.Db, pfx) } type StorageServices interface { diff --git a/common/transfer_statements.go b/common/transfer_statements.go index 243ef4c..e97437f 100644 --- a/common/transfer_statements.go +++ b/common/transfer_statements.go @@ -6,7 +6,7 @@ import ( "strings" "time" - "git.grassecon.net/urdt/ussd/internal/storage" + dbstorage "git.grassecon.net/urdt/ussd/internal/storage/db" dataserviceapi "github.com/grassrootseconomics/ussd-data-service/pkg/api" ) @@ -56,7 +56,7 @@ func ProcessTransfers(transfers []dataserviceapi.Last10TxResponse) TransferMetad // GetTransferData retrieves and matches transfer data // returns a formatted string of the full transaction/statement -func GetTransferData(ctx context.Context, db storage.PrefixDb, publicKey string, index int) (string, error) { +func GetTransferData(ctx context.Context, db dbstorage.PrefixDb, publicKey string, index int) (string, error) { keys := []DataTyp{DATA_TX_SENDERS, DATA_TX_RECIPIENTS, DATA_TX_VALUES, DATA_TX_ADDRESSES, DATA_TX_HASHES, DATA_TX_DATES, DATA_TX_SYMBOLS} data := make(map[DataTyp]string) diff --git a/common/vouchers.go b/common/vouchers.go index 6cff91d..5dbdb71 100644 --- a/common/vouchers.go +++ b/common/vouchers.go @@ -6,7 +6,7 @@ import ( "math/big" "strings" - "git.grassecon.net/urdt/ussd/internal/storage" + dbstorage "git.grassecon.net/urdt/ussd/internal/storage/db" dataserviceapi "github.com/grassrootseconomics/ussd-data-service/pkg/api" ) @@ -63,7 +63,7 @@ func ScaleDownBalance(balance, decimals string) string { } // GetVoucherData retrieves and matches voucher data -func GetVoucherData(ctx context.Context, db storage.PrefixDb, input string) (*dataserviceapi.TokenHoldings, error) { +func GetVoucherData(ctx context.Context, db dbstorage.PrefixDb, input string) (*dataserviceapi.TokenHoldings, error) { keys := []DataTyp{DATA_VOUCHER_SYMBOLS, DATA_VOUCHER_BALANCES, DATA_VOUCHER_DECIMALS, DATA_VOUCHER_ADDRESSES} data := make(map[DataTyp]string) diff --git a/common/vouchers_test.go b/common/vouchers_test.go index ba6cd60..8b04e4a 100644 --- a/common/vouchers_test.go +++ b/common/vouchers_test.go @@ -10,7 +10,7 @@ import ( visedb "git.defalsify.org/vise.git/db" memdb "git.defalsify.org/vise.git/db/mem" - "git.grassecon.net/urdt/ussd/internal/storage" + dbstorage "git.grassecon.net/urdt/ussd/internal/storage/db" dataserviceapi "github.com/grassrootseconomics/ussd-data-service/pkg/api" ) @@ -86,7 +86,7 @@ func TestGetVoucherData(t *testing.T) { } prefix := ToBytes(visedb.DATATYPE_USERDATA) - spdb := storage.NewSubPrefixDb(db, prefix) + spdb := dbstorage.NewSubPrefixDb(db, prefix) // Test voucher data mockData := map[DataTyp][]byte{ diff --git a/go.mod b/go.mod index 16ccdc3..41c6700 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module git.grassecon.net/urdt/ussd go 1.23.0 require ( - git.defalsify.org/vise.git v0.2.3-0.20241231085136-8582c7e157d9 + git.defalsify.org/vise.git v0.2.3-0.20250103172917-3e190a44568d github.com/alecthomas/assert/v2 v2.2.2 github.com/gofrs/uuid v4.4.0+incompatible github.com/grassrootseconomics/eth-custodial v1.3.0-beta @@ -11,6 +11,7 @@ require ( github.com/joho/godotenv v1.5.1 github.com/peteole/testdata-loader v0.3.0 github.com/stretchr/testify v1.9.0 + golang.org/x/crypto v0.27.0 gopkg.in/leonelquinteros/gotext.v1 v1.3.1 ) @@ -32,7 +33,6 @@ require ( github.com/rogpeppe/go-internal v1.13.1 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/x448/float16 v0.8.4 // indirect - golang.org/x/crypto v0.27.0 // indirect golang.org/x/sync v0.8.0 // indirect golang.org/x/text v0.18.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 9086cd8..6bef621 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -git.defalsify.org/vise.git v0.2.3-0.20241231085136-8582c7e157d9 h1:O3m+NgWDWtJm8OculT99c4bDMAO4xLe2c8hpCKpsd9g= -git.defalsify.org/vise.git v0.2.3-0.20241231085136-8582c7e157d9/go.mod h1:jyBMe1qTYUz3mmuoC9JQ/TvFeW0vTanCUcPu3H8p4Ck= +git.defalsify.org/vise.git v0.2.3-0.20250103172917-3e190a44568d h1:bPAOVZOX4frSGhfOdcj7kc555f8dc9DmMd2YAyC2AMw= +git.defalsify.org/vise.git v0.2.3-0.20250103172917-3e190a44568d/go.mod h1:jyBMe1qTYUz3mmuoC9JQ/TvFeW0vTanCUcPu3H8p4Ck= github.com/alecthomas/assert/v2 v2.2.2 h1:Z/iVC0xZfWTaFNE6bA3z07T86hd45Xe2eLt6WVy2bbk= github.com/alecthomas/assert/v2 v2.2.2/go.mod h1:pXcQ2Asjp247dahGEmsZ6ru0UVwnkhktn7S0bBDLxvQ= github.com/alecthomas/participle/v2 v2.0.0 h1:Fgrq+MbuSsJwIkw3fEj9h75vDP0Er5JzepJ0/HNHv0g= diff --git a/internal/handlers/ussd/menuhandler.go b/internal/handlers/ussd/menuhandler.go index 3919595..095d77b 100644 --- a/internal/handlers/ussd/menuhandler.go +++ b/internal/handlers/ussd/menuhandler.go @@ -23,12 +23,12 @@ import ( "git.grassecon.net/urdt/ussd/remote" "gopkg.in/leonelquinteros/gotext.v1" - "git.grassecon.net/urdt/ussd/internal/storage" + dbstorage "git.grassecon.net/urdt/ussd/internal/storage/db" dataserviceapi "github.com/grassrootseconomics/ussd-data-service/pkg/api" ) var ( - logg = logging.NewVanilla().WithDomain("ussdmenuhandler") + logg = logging.NewVanilla().WithDomain("ussdmenuhandler").WithContextKey("session-id") scriptDir = path.Join("services", "registration") translationDir = path.Join(scriptDir, "locale") ) @@ -64,7 +64,7 @@ type Handlers struct { adminstore *utils.AdminStore flagManager *asm.FlagParser accountService remote.AccountServiceInterface - prefixDb storage.PrefixDb + prefixDb dbstorage.PrefixDb profile *models.Profile ReplaceSeparatorFunc func(string) string } @@ -80,7 +80,7 @@ func NewHandlers(appFlags *asm.FlagParser, userdataStore db.Db, adminstore *util // Instantiate the SubPrefixDb with "DATATYPE_USERDATA" prefix prefix := common.ToBytes(db.DATATYPE_USERDATA) - prefixDb := storage.NewSubPrefixDb(userdataStore, prefix) + prefixDb := dbstorage.NewSubPrefixDb(userdataStore, prefix) h := &Handlers{ userdataStore: userDb, @@ -122,9 +122,12 @@ func (h *Handlers) Init(ctx context.Context, sym string, input []byte) (resource h.st.Code = []byte{} } - sessionId, _ := ctx.Value("SessionId").(string) - flag_admin_privilege, _ := h.flagManager.GetFlag("flag_admin_privilege") + sessionId, ok := ctx.Value("SessionId").(string) + if ok { + context.WithValue(ctx, "session-id", sessionId) + } + flag_admin_privilege, _ := h.flagManager.GetFlag("flag_admin_privilege") isAdmin, _ := h.adminstore.IsAdmin(sessionId) if isAdmin { diff --git a/internal/handlers/ussd/menuhandler_test.go b/internal/handlers/ussd/menuhandler_test.go index 12ed5c2..914dffc 100644 --- a/internal/handlers/ussd/menuhandler_test.go +++ b/internal/handlers/ussd/menuhandler_test.go @@ -13,7 +13,7 @@ import ( "git.defalsify.org/vise.git/persist" "git.defalsify.org/vise.git/resource" "git.defalsify.org/vise.git/state" - "git.grassecon.net/urdt/ussd/internal/storage" + dbstorage "git.grassecon.net/urdt/ussd/internal/storage/db" "git.grassecon.net/urdt/ussd/internal/testutil/mocks" "git.grassecon.net/urdt/ussd/internal/testutil/testservice" "git.grassecon.net/urdt/ussd/internal/utils" @@ -59,14 +59,14 @@ func InitializeTestStore(t *testing.T) (context.Context, *common.UserDataStore) return ctx, store } -func InitializeTestSubPrefixDb(t *testing.T, ctx context.Context) *storage.SubPrefixDb { +func InitializeTestSubPrefixDb(t *testing.T, ctx context.Context) *dbstorage.SubPrefixDb { db := memdb.NewMemDb() err := db.Connect(ctx, "") if err != nil { t.Fatal(err) } prefix := common.ToBytes(visedb.DATATYPE_USERDATA) - spdb := storage.NewSubPrefixDb(db, prefix) + spdb := dbstorage.NewSubPrefixDb(db, prefix) return spdb } diff --git a/internal/http/at/parse.go b/internal/http/at/parse.go new file mode 100644 index 0000000..d2696ed --- /dev/null +++ b/internal/http/at/parse.go @@ -0,0 +1,121 @@ +package at + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + + "git.grassecon.net/urdt/ussd/common" + "git.grassecon.net/urdt/ussd/internal/handlers" +) + +type ATRequestParser struct { + Context context.Context +} + +func (arp *ATRequestParser) GetSessionId(rq any) (string, error) { + rqv, ok := rq.(*http.Request) + if !ok { + logg.Warnf("got an invalid request", "req", rq) + return "", handlers.ErrInvalidRequest + } + + // Capture body (if any) for logging + body, err := io.ReadAll(rqv.Body) + if err != nil { + logg.Warnf("failed to read request body", "err", err) + return "", fmt.Errorf("failed to read request body: %v", err) + } + // Reset the body for further reading + rqv.Body = io.NopCloser(bytes.NewReader(body)) + + // Log the body as JSON + bodyLog := map[string]string{"body": string(body)} + logBytes, err := json.Marshal(bodyLog) + if err != nil { + logg.Warnf("failed to marshal request body", "err", err) + } else { + decodedStr := string(logBytes) + sessionId, err := extractATSessionId(decodedStr) + if err != nil { + context.WithValue(arp.Context, "at-session-id", sessionId) + } + logg.Debugf("Received request:", decodedStr) + } + + if err := rqv.ParseForm(); err != nil { + logg.Warnf("failed to parse form data", "err", err) + return "", fmt.Errorf("failed to parse form data: %v", err) + } + + phoneNumber := rqv.FormValue("phoneNumber") + if phoneNumber == "" { + return "", fmt.Errorf("no phone number found") + } + + formattedNumber, err := common.FormatPhoneNumber(phoneNumber) + if err != nil { + logg.Warnf("failed to format phone number", "err", err) + return "", fmt.Errorf("failed to format number") + } + + return formattedNumber, nil +} + +func (arp *ATRequestParser) GetInput(rq any) ([]byte, error) { + rqv, ok := rq.(*http.Request) + if !ok { + return nil, handlers.ErrInvalidRequest + } + if err := rqv.ParseForm(); err != nil { + return nil, fmt.Errorf("failed to parse form data: %v", err) + } + + text := rqv.FormValue("text") + + parts := strings.Split(text, "*") + if len(parts) == 0 { + return nil, fmt.Errorf("no input found") + } + + return []byte(parts[len(parts)-1]), nil +} + +func parseQueryParams(query string) map[string]string { + params := make(map[string]string) + + queryParams := strings.Split(query, "&") + for _, param := range queryParams { + // Split each key-value pair by '=' + parts := strings.SplitN(param, "=", 2) + if len(parts) == 2 { + params[parts[0]] = parts[1] + } + } + return params +} + +func extractATSessionId(decodedStr string) (string, error) { + var data map[string]string + err := json.Unmarshal([]byte(decodedStr), &data) + + if err != nil { + logg.Errorf("Error unmarshalling JSON: %v", err) + return "", nil + } + decodedBody, err := url.QueryUnescape(data["body"]) + if err != nil { + logg.Errorf("Error URL-decoding body: %v", err) + return "", nil + } + params := parseQueryParams(decodedBody) + + sessionId := params["sessionId"] + return sessionId, nil + +} diff --git a/internal/http/at_session_handler.go b/internal/http/at/server.go similarity index 79% rename from internal/http/at_session_handler.go rename to internal/http/at/server.go index 25da954..705ff76 100644 --- a/internal/http/at_session_handler.go +++ b/internal/http/at/server.go @@ -1,19 +1,25 @@ -package http +package at import ( "io" "net/http" + "git.defalsify.org/vise.git/logging" "git.grassecon.net/urdt/ussd/internal/handlers" + httpserver "git.grassecon.net/urdt/ussd/internal/http" +) + +var ( + logg = logging.NewVanilla().WithDomain("atserver") ) type ATSessionHandler struct { - *SessionHandler + *httpserver.SessionHandler } func NewATSessionHandler(h handlers.RequestHandler) *ATSessionHandler { return &ATSessionHandler{ - SessionHandler: ToSessionHandler(h), + SessionHandler: httpserver.ToSessionHandler(h), } } @@ -31,14 +37,14 @@ func (ash *ATSessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) cfg.SessionId, err = rp.GetSessionId(req) if err != nil { logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err) - ash.writeError(w, 400, err) + ash.WriteError(w, 400, err) return } rqs.Config = cfg rqs.Input, err = rp.GetInput(req) if err != nil { logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err) - ash.writeError(w, 400, err) + ash.WriteError(w, 400, err) return } @@ -53,7 +59,7 @@ func (ash *ATSessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) } if code != 200 { - ash.writeError(w, 500, err) + ash.WriteError(w, 500, err) return } @@ -61,13 +67,13 @@ func (ash *ATSessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) w.Header().Set("Content-Type", "text/plain") rqs, err = ash.Output(rqs) if err != nil { - ash.writeError(w, 500, err) + ash.WriteError(w, 500, err) return } rqs, err = ash.Reset(rqs) if err != nil { - ash.writeError(w, 500, err) + ash.WriteError(w, 500, err) return } } @@ -89,4 +95,4 @@ func (ash *ATSessionHandler) Output(rqs handlers.RequestSession) (handlers.Reque _, err = rqs.Engine.Flush(rqs.Ctx, rqs.Writer) return rqs, err -} \ No newline at end of file +} diff --git a/internal/http/http_test.go b/internal/http/at/server_test.go similarity index 54% rename from internal/http/http_test.go rename to internal/http/at/server_test.go index 14bb90a..dd45c25 100644 --- a/internal/http/http_test.go +++ b/internal/http/at/server_test.go @@ -1,7 +1,6 @@ -package http +package at import ( - "bytes" "context" "errors" "io" @@ -16,16 +15,6 @@ import ( "git.grassecon.net/urdt/ussd/internal/testutil/mocks/httpmocks" ) -// invalidRequestType is a custom type to test invalid request scenarios -type invalidRequestType struct{} - -// errorReader is a helper type that always returns an error when Read is called -type errorReader struct{} - -func (e *errorReader) Read(p []byte) (n int, err error) { - return 0, errors.New("read error") -} - func TestNewATSessionHandler(t *testing.T) { mockHandler := &httpmocks.MockRequestHandler{} ash := NewATSessionHandler(mockHandler) @@ -242,208 +231,4 @@ func TestATSessionHandler_Output(t *testing.T) { } } -func TestSessionHandler_ServeHTTP(t *testing.T) { - tests := []struct { - name string - sessionID string - input []byte - parserErr error - processErr error - outputErr error - resetErr error - expectedStatus int - }{ - { - name: "Success", - sessionID: "123", - input: []byte("test input"), - expectedStatus: http.StatusOK, - }, - { - name: "Missing Session ID", - sessionID: "", - parserErr: handlers.ErrSessionMissing, - expectedStatus: http.StatusBadRequest, - }, - { - name: "Process Error", - sessionID: "123", - input: []byte("test input"), - processErr: handlers.ErrStorage, - expectedStatus: http.StatusInternalServerError, - }, - { - name: "Output Error", - sessionID: "123", - input: []byte("test input"), - outputErr: errors.New("output error"), - expectedStatus: http.StatusOK, - }, - { - name: "Reset Error", - sessionID: "123", - input: []byte("test input"), - resetErr: errors.New("reset error"), - expectedStatus: http.StatusOK, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockRequestParser := &httpmocks.MockRequestParser{ - GetSessionIdFunc: func(any) (string, error) { - return tt.sessionID, tt.parserErr - }, - GetInputFunc: func(any) ([]byte, error) { - return tt.input, nil - }, - } - - mockRequestHandler := &httpmocks.MockRequestHandler{ - ProcessFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) { - return rs, tt.processErr - }, - OutputFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) { - return rs, tt.outputErr - }, - ResetFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) { - return rs, tt.resetErr - }, - GetRequestParserFunc: func() handlers.RequestParser { - return mockRequestParser - }, - GetConfigFunc: func() engine.Config { - return engine.Config{} - }, - } - - sessionHandler := ToSessionHandler(mockRequestHandler) - - req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBuffer(tt.input)) - req.Header.Set("X-Vise-Session", tt.sessionID) - - rr := httptest.NewRecorder() - - sessionHandler.ServeHTTP(rr, req) - - if status := rr.Code; status != tt.expectedStatus { - t.Errorf("handler returned wrong status code: got %v want %v", - status, tt.expectedStatus) - } - }) - } -} - -func TestSessionHandler_writeError(t *testing.T) { - handler := &SessionHandler{} - mockWriter := &httpmocks.MockWriter{} - err := errors.New("test error") - - handler.writeError(mockWriter, http.StatusBadRequest, err) - - if mockWriter.WrittenString != "" { - t.Errorf("Expected empty body, got %s", mockWriter.WrittenString) - } -} - -func TestDefaultRequestParser_GetSessionId(t *testing.T) { - tests := []struct { - name string - request any - expectedID string - expectedError error - }{ - { - name: "Valid Session ID", - request: func() *http.Request { - req := httptest.NewRequest(http.MethodPost, "/", nil) - req.Header.Set("X-Vise-Session", "123456") - return req - }(), - expectedID: "123456", - expectedError: nil, - }, - { - name: "Missing Session ID", - request: httptest.NewRequest(http.MethodPost, "/", nil), - expectedID: "", - expectedError: handlers.ErrSessionMissing, - }, - { - name: "Invalid Request Type", - request: invalidRequestType{}, - expectedID: "", - expectedError: handlers.ErrInvalidRequest, - }, - } - - parser := &DefaultRequestParser{} - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - id, err := parser.GetSessionId(tt.request) - - if id != tt.expectedID { - t.Errorf("Expected session ID %s, got %s", tt.expectedID, id) - } - - if err != tt.expectedError { - t.Errorf("Expected error %v, got %v", tt.expectedError, err) - } - }) - } -} - -func TestDefaultRequestParser_GetInput(t *testing.T) { - tests := []struct { - name string - request any - expectedInput []byte - expectedError error - }{ - { - name: "Valid Input", - request: func() *http.Request { - return httptest.NewRequest(http.MethodPost, "/", bytes.NewBufferString("test input")) - }(), - expectedInput: []byte("test input"), - expectedError: nil, - }, - { - name: "Empty Input", - request: httptest.NewRequest(http.MethodPost, "/", nil), - expectedInput: []byte{}, - expectedError: nil, - }, - { - name: "Invalid Request Type", - request: invalidRequestType{}, - expectedInput: nil, - expectedError: handlers.ErrInvalidRequest, - }, - { - name: "Read Error", - request: func() *http.Request { - return httptest.NewRequest(http.MethodPost, "/", &errorReader{}) - }(), - expectedInput: nil, - expectedError: errors.New("read error"), - }, - } - - parser := &DefaultRequestParser{} - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - input, err := parser.GetInput(tt.request) - - if !bytes.Equal(input, tt.expectedInput) { - t.Errorf("Expected input %s, got %s", tt.expectedInput, input) - } - - if err != tt.expectedError && (err == nil || err.Error() != tt.expectedError.Error()) { - t.Errorf("Expected error %v, got %v", tt.expectedError, err) - } - }) - } -} diff --git a/internal/http/parse.go b/internal/http/parse.go new file mode 100644 index 0000000..ec8e00b --- /dev/null +++ b/internal/http/parse.go @@ -0,0 +1,38 @@ +package http + +import ( + "io/ioutil" + "net/http" + + "git.grassecon.net/urdt/ussd/internal/handlers" +) + +type DefaultRequestParser struct { +} + +func (rp *DefaultRequestParser) GetSessionId(rq any) (string, error) { + rqv, ok := rq.(*http.Request) + if !ok { + return "", handlers.ErrInvalidRequest + } + v := rqv.Header.Get("X-Vise-Session") + if v == "" { + return "", handlers.ErrSessionMissing + } + return v, nil +} + +func (rp *DefaultRequestParser) GetInput(rq any) ([]byte, error) { + rqv, ok := rq.(*http.Request) + if !ok { + return nil, handlers.ErrInvalidRequest + } + defer rqv.Body.Close() + v, err := ioutil.ReadAll(rqv.Body) + if err != nil { + return nil, err + } + return v, nil +} + + diff --git a/internal/http/server.go b/internal/http/server.go index a6239c4..9cadfa3 100644 --- a/internal/http/server.go +++ b/internal/http/server.go @@ -1,7 +1,6 @@ package http import ( - "io/ioutil" "net/http" "strconv" @@ -14,34 +13,6 @@ var ( logg = logging.NewVanilla().WithDomain("httpserver") ) -type DefaultRequestParser struct { -} - -func (rp *DefaultRequestParser) GetSessionId(rq any) (string, error) { - rqv, ok := rq.(*http.Request) - if !ok { - return "", handlers.ErrInvalidRequest - } - v := rqv.Header.Get("X-Vise-Session") - if v == "" { - return "", handlers.ErrSessionMissing - } - return v, nil -} - -func (rp *DefaultRequestParser) GetInput(rq any) ([]byte, error) { - rqv, ok := rq.(*http.Request) - if !ok { - return nil, handlers.ErrInvalidRequest - } - defer rqv.Body.Close() - v, err := ioutil.ReadAll(rqv.Body) - if err != nil { - return nil, err - } - return v, nil -} - type SessionHandler struct { handlers.RequestHandler } @@ -52,7 +23,7 @@ func ToSessionHandler(h handlers.RequestHandler) *SessionHandler { } } -func (f *SessionHandler) writeError(w http.ResponseWriter, code int, err error) { +func (f *SessionHandler) WriteError(w http.ResponseWriter, code int, err error) { s := err.Error() w.Header().Set("Content-Length", strconv.Itoa(len(s))) w.WriteHeader(code) @@ -78,13 +49,13 @@ func (f *SessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { cfg.SessionId, err = rp.GetSessionId(req) if err != nil { logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err) - f.writeError(w, 400, err) + f.WriteError(w, 400, err) } rqs.Config = cfg rqs.Input, err = rp.GetInput(req) if err != nil { logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err) - f.writeError(w, 400, err) + f.WriteError(w, 400, err) return } @@ -101,7 +72,7 @@ func (f *SessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { } if code != 200 { - f.writeError(w, 500, err) + f.WriteError(w, 500, err) return } @@ -110,11 +81,11 @@ func (f *SessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { rqs, err = f.Output(rqs) rqs, perr = f.Reset(rqs) if err != nil { - f.writeError(w, 500, err) + f.WriteError(w, 500, err) return } if perr != nil { - f.writeError(w, 500, perr) + f.WriteError(w, 500, perr) return } } diff --git a/internal/http/server_test.go b/internal/http/server_test.go new file mode 100644 index 0000000..a46f98e --- /dev/null +++ b/internal/http/server_test.go @@ -0,0 +1,229 @@ +package http + +import ( + "bytes" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "git.defalsify.org/vise.git/engine" + "git.grassecon.net/urdt/ussd/internal/handlers" + "git.grassecon.net/urdt/ussd/internal/testutil/mocks/httpmocks" +) + +// invalidRequestType is a custom type to test invalid request scenarios +type invalidRequestType struct{} + +// errorReader is a helper type that always returns an error when Read is called +type errorReader struct{} + +func (e *errorReader) Read(p []byte) (n int, err error) { + return 0, errors.New("read error") +} + +func TestSessionHandler_ServeHTTP(t *testing.T) { + tests := []struct { + name string + sessionID string + input []byte + parserErr error + processErr error + outputErr error + resetErr error + expectedStatus int + }{ + { + name: "Success", + sessionID: "123", + input: []byte("test input"), + expectedStatus: http.StatusOK, + }, + { + name: "Missing Session ID", + sessionID: "", + parserErr: handlers.ErrSessionMissing, + expectedStatus: http.StatusBadRequest, + }, + { + name: "Process Error", + sessionID: "123", + input: []byte("test input"), + processErr: handlers.ErrStorage, + expectedStatus: http.StatusInternalServerError, + }, + { + name: "Output Error", + sessionID: "123", + input: []byte("test input"), + outputErr: errors.New("output error"), + expectedStatus: http.StatusOK, + }, + { + name: "Reset Error", + sessionID: "123", + input: []byte("test input"), + resetErr: errors.New("reset error"), + expectedStatus: http.StatusOK, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRequestParser := &httpmocks.MockRequestParser{ + GetSessionIdFunc: func(any) (string, error) { + return tt.sessionID, tt.parserErr + }, + GetInputFunc: func(any) ([]byte, error) { + return tt.input, nil + }, + } + + mockRequestHandler := &httpmocks.MockRequestHandler{ + ProcessFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) { + return rs, tt.processErr + }, + OutputFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) { + return rs, tt.outputErr + }, + ResetFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) { + return rs, tt.resetErr + }, + GetRequestParserFunc: func() handlers.RequestParser { + return mockRequestParser + }, + GetConfigFunc: func() engine.Config { + return engine.Config{} + }, + } + + sessionHandler := ToSessionHandler(mockRequestHandler) + + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBuffer(tt.input)) + req.Header.Set("X-Vise-Session", tt.sessionID) + + rr := httptest.NewRecorder() + + sessionHandler.ServeHTTP(rr, req) + + if status := rr.Code; status != tt.expectedStatus { + t.Errorf("handler returned wrong status code: got %v want %v", + status, tt.expectedStatus) + } + }) + } +} + +func TestSessionHandler_WriteError(t *testing.T) { + handler := &SessionHandler{} + mockWriter := &httpmocks.MockWriter{} + err := errors.New("test error") + + handler.WriteError(mockWriter, http.StatusBadRequest, err) + + if mockWriter.WrittenString != "" { + t.Errorf("Expected empty body, got %s", mockWriter.WrittenString) + } +} + +func TestDefaultRequestParser_GetSessionId(t *testing.T) { + tests := []struct { + name string + request any + expectedID string + expectedError error + }{ + { + name: "Valid Session ID", + request: func() *http.Request { + req := httptest.NewRequest(http.MethodPost, "/", nil) + req.Header.Set("X-Vise-Session", "123456") + return req + }(), + expectedID: "123456", + expectedError: nil, + }, + { + name: "Missing Session ID", + request: httptest.NewRequest(http.MethodPost, "/", nil), + expectedID: "", + expectedError: handlers.ErrSessionMissing, + }, + { + name: "Invalid Request Type", + request: invalidRequestType{}, + expectedID: "", + expectedError: handlers.ErrInvalidRequest, + }, + } + + parser := &DefaultRequestParser{} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + id, err := parser.GetSessionId(tt.request) + + if id != tt.expectedID { + t.Errorf("Expected session ID %s, got %s", tt.expectedID, id) + } + + if err != tt.expectedError { + t.Errorf("Expected error %v, got %v", tt.expectedError, err) + } + }) + } +} + +func TestDefaultRequestParser_GetInput(t *testing.T) { + tests := []struct { + name string + request any + expectedInput []byte + expectedError error + }{ + { + name: "Valid Input", + request: func() *http.Request { + return httptest.NewRequest(http.MethodPost, "/", bytes.NewBufferString("test input")) + }(), + expectedInput: []byte("test input"), + expectedError: nil, + }, + { + name: "Empty Input", + request: httptest.NewRequest(http.MethodPost, "/", nil), + expectedInput: []byte{}, + expectedError: nil, + }, + { + name: "Invalid Request Type", + request: invalidRequestType{}, + expectedInput: nil, + expectedError: handlers.ErrInvalidRequest, + }, + { + name: "Read Error", + request: func() *http.Request { + return httptest.NewRequest(http.MethodPost, "/", &errorReader{}) + }(), + expectedInput: nil, + expectedError: errors.New("read error"), + }, + } + + parser := &DefaultRequestParser{} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + input, err := parser.GetInput(tt.request) + + if !bytes.Equal(input, tt.expectedInput) { + t.Errorf("Expected input %s, got %s", tt.expectedInput, input) + } + + if err != tt.expectedError && (err == nil || err.Error() != tt.expectedError.Error()) { + t.Errorf("Expected error %v, got %v", tt.expectedError, err) + } + }) + } +} diff --git a/internal/storage/gdbm.go b/internal/storage/db/gdbm/gdbm.go similarity index 95% rename from internal/storage/gdbm.go rename to internal/storage/db/gdbm/gdbm.go index 31ebf47..dab767a 100644 --- a/internal/storage/gdbm.go +++ b/internal/storage/db/gdbm/gdbm.go @@ -6,6 +6,11 @@ import ( "git.defalsify.org/vise.git/db" gdbmdb "git.defalsify.org/vise.git/db/gdbm" "git.defalsify.org/vise.git/lang" + "git.defalsify.org/vise.git/logging" +) + +var ( + logg = logging.NewVanilla().WithDomain("gdbmstorage") ) var ( diff --git a/internal/storage/sub_prefix_db.go b/internal/storage/db/sub_prefix_db.go similarity index 100% rename from internal/storage/sub_prefix_db.go rename to internal/storage/db/sub_prefix_db.go diff --git a/internal/storage/sub_prefix_db_test.go b/internal/storage/db/sub_prefix_db_test.go similarity index 100% rename from internal/storage/sub_prefix_db_test.go rename to internal/storage/db/sub_prefix_db_test.go diff --git a/internal/storage/storageservice.go b/internal/storage/storageservice.go index ca28bbb..04e75ce 100644 --- a/internal/storage/storageservice.go +++ b/internal/storage/storageservice.go @@ -13,6 +13,7 @@ import ( "git.defalsify.org/vise.git/persist" "git.defalsify.org/vise.git/resource" "git.grassecon.net/urdt/ussd/initializers" + gdbmstorage "git.grassecon.net/urdt/ussd/internal/storage/db/gdbm" ) var ( @@ -75,7 +76,7 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D connStr := buildConnStr() err = newDb.Connect(ctx, connStr) } else { - newDb = NewThreadGdbmDb() + newDb = gdbmstorage.NewThreadGdbmDb() storeFile := path.Join(ms.dbDir, fileName) err = newDb.Connect(ctx, storeFile) }