diff --git a/cmd/africastalking/main.go b/cmd/africastalking/main.go index a7fa74b..3ac1591 100644 --- a/cmd/africastalking/main.go +++ b/cmd/africastalking/main.go @@ -39,113 +39,6 @@ var ( func init() { initializers.LoadEnvVariables() } - -type atRequestParser struct { - context context.Context -} - -func parseQueryParams(query string) map[string]string { - params := make(map[string]string) - - queryParams := strings.Split(query, "&") - for _, param := range queryParams { - // Split each key-value pair by '=' - parts := strings.SplitN(param, "=", 2) - if len(parts) == 2 { - params[parts[0]] = parts[1] - } - } - return params -} - -func extractATSessionId(decodedStr string) (string, error) { - var data map[string]string - err := json.Unmarshal([]byte(decodedStr), &data) - - if err != nil { - logg.Errorf("Error unmarshalling JSON: %v", err) - return "", nil - } - decodedBody, err := url.QueryUnescape(data["body"]) - if err != nil { - logg.Errorf("Error URL-decoding body: %v", err) - return "", nil - } - params := parseQueryParams(decodedBody) - - sessionId := params["sessionId"] - return sessionId, nil - -} - -func (arp *atRequestParser) GetSessionId(rq any) (string, error) { - rqv, ok := rq.(*http.Request) - if !ok { - logg.Warnf("got an invalid request", "req", rq) - return "", handlers.ErrInvalidRequest - } - - // Capture body (if any) for logging - body, err := io.ReadAll(rqv.Body) - if err != nil { - logg.Warnf("failed to read request body", "err", err) - return "", fmt.Errorf("failed to read request body: %v", err) - } - // Reset the body for further reading - rqv.Body = io.NopCloser(bytes.NewReader(body)) - - // Log the body as JSON - bodyLog := map[string]string{"body": string(body)} - logBytes, err := json.Marshal(bodyLog) - if err != nil { - logg.Warnf("failed to marshal request body", "err", err) - } else { - decodedStr := string(logBytes) - sessionId, err := extractATSessionId(decodedStr) - if err != nil { - context.WithValue(arp.context, "at-session-id", sessionId) - } - logg.Debugf("Received request:", decodedStr) - } - - if err := rqv.ParseForm(); err != nil { - logg.Warnf("failed to parse form data", "err", err) - return "", fmt.Errorf("failed to parse form data: %v", err) - } - - phoneNumber := rqv.FormValue("phoneNumber") - if phoneNumber == "" { - return "", fmt.Errorf("no phone number found") - } - - formattedNumber, err := common.FormatPhoneNumber(phoneNumber) - if err != nil { - logg.Warnf("failed to format phone number", "err", err) - return "", fmt.Errorf("failed to format number") - } - - return formattedNumber, nil -} - -func (arp *atRequestParser) GetInput(rq any) ([]byte, error) { - rqv, ok := rq.(*http.Request) - if !ok { - return nil, handlers.ErrInvalidRequest - } - if err := rqv.ParseForm(); err != nil { - return nil, fmt.Errorf("failed to parse form data: %v", err) - } - - text := rqv.FormValue("text") - - parts := strings.Split(text, "*") - if len(parts) == 0 { - return nil, fmt.Errorf("no input found") - } - - return []byte(parts[len(parts)-1]), nil -} - func main() { config.LoadConfig() @@ -233,7 +126,7 @@ func main() { } defer stateStore.Close() - rp := &atRequestParser{ + rp := &at.ATRequestParser{ context: ctx, } bsh := handlers.NewBaseSessionHandler(cfg, rs, stateStore, userdataStore, rp, hl) diff --git a/internal/http/at/parse.go b/internal/http/at/parse.go new file mode 100644 index 0000000..a40cf0f --- /dev/null +++ b/internal/http/at/parse.go @@ -0,0 +1,121 @@ +package at + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + + "git.grassecon.net/urdt/ussd/common" + "git.grassecon.net/urdt/ussd/internal/handlers" +) + +type ATRequestParser struct { + context context.Context +} + +func (arp *ATRequestParser) GetSessionId(rq any) (string, error) { + rqv, ok := rq.(*http.Request) + if !ok { + logg.Warnf("got an invalid request", "req", rq) + return "", handlers.ErrInvalidRequest + } + + // Capture body (if any) for logging + body, err := io.ReadAll(rqv.Body) + if err != nil { + logg.Warnf("failed to read request body", "err", err) + return "", fmt.Errorf("failed to read request body: %v", err) + } + // Reset the body for further reading + rqv.Body = io.NopCloser(bytes.NewReader(body)) + + // Log the body as JSON + bodyLog := map[string]string{"body": string(body)} + logBytes, err := json.Marshal(bodyLog) + if err != nil { + logg.Warnf("failed to marshal request body", "err", err) + } else { + decodedStr := string(logBytes) + sessionId, err := extractATSessionId(decodedStr) + if err != nil { + context.WithValue(arp.context, "at-session-id", sessionId) + } + logg.Debugf("Received request:", decodedStr) + } + + if err := rqv.ParseForm(); err != nil { + logg.Warnf("failed to parse form data", "err", err) + return "", fmt.Errorf("failed to parse form data: %v", err) + } + + phoneNumber := rqv.FormValue("phoneNumber") + if phoneNumber == "" { + return "", fmt.Errorf("no phone number found") + } + + formattedNumber, err := common.FormatPhoneNumber(phoneNumber) + if err != nil { + logg.Warnf("failed to format phone number", "err", err) + return "", fmt.Errorf("failed to format number") + } + + return formattedNumber, nil +} + +func (arp *ATRequestParser) GetInput(rq any) ([]byte, error) { + rqv, ok := rq.(*http.Request) + if !ok { + return nil, handlers.ErrInvalidRequest + } + if err := rqv.ParseForm(); err != nil { + return nil, fmt.Errorf("failed to parse form data: %v", err) + } + + text := rqv.FormValue("text") + + parts := strings.Split(text, "*") + if len(parts) == 0 { + return nil, fmt.Errorf("no input found") + } + + return []byte(parts[len(parts)-1]), nil +} + +func parseQueryParams(query string) map[string]string { + params := make(map[string]string) + + queryParams := strings.Split(query, "&") + for _, param := range queryParams { + // Split each key-value pair by '=' + parts := strings.SplitN(param, "=", 2) + if len(parts) == 2 { + params[parts[0]] = parts[1] + } + } + return params +} + +func extractATSessionId(decodedStr string) (string, error) { + var data map[string]string + err := json.Unmarshal([]byte(decodedStr), &data) + + if err != nil { + logg.Errorf("Error unmarshalling JSON: %v", err) + return "", nil + } + decodedBody, err := url.QueryUnescape(data["body"]) + if err != nil { + logg.Errorf("Error URL-decoding body: %v", err) + return "", nil + } + params := parseQueryParams(decodedBody) + + sessionId := params["sessionId"] + return sessionId, nil + +} diff --git a/internal/http/at/server.go b/internal/http/at/server.go index 9cade3d..705ff76 100644 --- a/internal/http/at/server.go +++ b/internal/http/at/server.go @@ -1,4 +1,4 @@ -package http +package at import ( "io" diff --git a/internal/http/at/server_test.go b/internal/http/at/server_test.go index d49f9ce..dd45c25 100644 --- a/internal/http/at/server_test.go +++ b/internal/http/at/server_test.go @@ -1,4 +1,4 @@ -package http +package at import ( "context" diff --git a/internal/http/parse.go b/internal/http/parse.go new file mode 100644 index 0000000..ec8e00b --- /dev/null +++ b/internal/http/parse.go @@ -0,0 +1,38 @@ +package http + +import ( + "io/ioutil" + "net/http" + + "git.grassecon.net/urdt/ussd/internal/handlers" +) + +type DefaultRequestParser struct { +} + +func (rp *DefaultRequestParser) GetSessionId(rq any) (string, error) { + rqv, ok := rq.(*http.Request) + if !ok { + return "", handlers.ErrInvalidRequest + } + v := rqv.Header.Get("X-Vise-Session") + if v == "" { + return "", handlers.ErrSessionMissing + } + return v, nil +} + +func (rp *DefaultRequestParser) GetInput(rq any) ([]byte, error) { + rqv, ok := rq.(*http.Request) + if !ok { + return nil, handlers.ErrInvalidRequest + } + defer rqv.Body.Close() + v, err := ioutil.ReadAll(rqv.Body) + if err != nil { + return nil, err + } + return v, nil +} + + diff --git a/internal/http/server.go b/internal/http/server.go index df15407..9cadfa3 100644 --- a/internal/http/server.go +++ b/internal/http/server.go @@ -1,7 +1,6 @@ package http import ( - "io/ioutil" "net/http" "strconv" @@ -14,34 +13,6 @@ var ( logg = logging.NewVanilla().WithDomain("httpserver") ) -type DefaultRequestParser struct { -} - -func (rp *DefaultRequestParser) GetSessionId(rq any) (string, error) { - rqv, ok := rq.(*http.Request) - if !ok { - return "", handlers.ErrInvalidRequest - } - v := rqv.Header.Get("X-Vise-Session") - if v == "" { - return "", handlers.ErrSessionMissing - } - return v, nil -} - -func (rp *DefaultRequestParser) GetInput(rq any) ([]byte, error) { - rqv, ok := rq.(*http.Request) - if !ok { - return nil, handlers.ErrInvalidRequest - } - defer rqv.Body.Close() - v, err := ioutil.ReadAll(rqv.Body) - if err != nil { - return nil, err - } - return v, nil -} - type SessionHandler struct { handlers.RequestHandler }