diff --git a/cmd/africastalking/main.go b/cmd/africastalking/main.go index ca88978..1f142da 100644 --- a/cmd/africastalking/main.go +++ b/cmd/africastalking/main.go @@ -23,7 +23,7 @@ import ( "git.grassecon.net/urdt/ussd/config" "git.grassecon.net/urdt/ussd/initializers" "git.grassecon.net/urdt/ussd/internal/handlers" - httpserver "git.grassecon.net/urdt/ussd/internal/http" + httpserver "git.grassecon.net/urdt/ussd/internal/http/at" "git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/remote" ) diff --git a/internal/http/at_session_handler.go b/internal/http/at/server.go similarity index 80% rename from internal/http/at_session_handler.go rename to internal/http/at/server.go index 25da954..9cade3d 100644 --- a/internal/http/at_session_handler.go +++ b/internal/http/at/server.go @@ -4,16 +4,22 @@ import ( "io" "net/http" + "git.defalsify.org/vise.git/logging" "git.grassecon.net/urdt/ussd/internal/handlers" + httpserver "git.grassecon.net/urdt/ussd/internal/http" +) + +var ( + logg = logging.NewVanilla().WithDomain("atserver") ) type ATSessionHandler struct { - *SessionHandler + *httpserver.SessionHandler } func NewATSessionHandler(h handlers.RequestHandler) *ATSessionHandler { return &ATSessionHandler{ - SessionHandler: ToSessionHandler(h), + SessionHandler: httpserver.ToSessionHandler(h), } } @@ -31,14 +37,14 @@ func (ash *ATSessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) cfg.SessionId, err = rp.GetSessionId(req) if err != nil { logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err) - ash.writeError(w, 400, err) + ash.WriteError(w, 400, err) return } rqs.Config = cfg rqs.Input, err = rp.GetInput(req) if err != nil { logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err) - ash.writeError(w, 400, err) + ash.WriteError(w, 400, err) return } @@ -53,7 +59,7 @@ func (ash *ATSessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) } if code != 200 { - ash.writeError(w, 500, err) + ash.WriteError(w, 500, err) return } @@ -61,13 +67,13 @@ func (ash *ATSessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) w.Header().Set("Content-Type", "text/plain") rqs, err = ash.Output(rqs) if err != nil { - ash.writeError(w, 500, err) + ash.WriteError(w, 500, err) return } rqs, err = ash.Reset(rqs) if err != nil { - ash.writeError(w, 500, err) + ash.WriteError(w, 500, err) return } } @@ -89,4 +95,4 @@ func (ash *ATSessionHandler) Output(rqs handlers.RequestSession) (handlers.Reque _, err = rqs.Engine.Flush(rqs.Ctx, rqs.Writer) return rqs, err -} \ No newline at end of file +} diff --git a/internal/http/http_test.go b/internal/http/at/server_test.go similarity index 55% rename from internal/http/http_test.go rename to internal/http/at/server_test.go index 14bb90a..d49f9ce 100644 --- a/internal/http/http_test.go +++ b/internal/http/at/server_test.go @@ -1,7 +1,6 @@ package http import ( - "bytes" "context" "errors" "io" @@ -16,16 +15,6 @@ import ( "git.grassecon.net/urdt/ussd/internal/testutil/mocks/httpmocks" ) -// invalidRequestType is a custom type to test invalid request scenarios -type invalidRequestType struct{} - -// errorReader is a helper type that always returns an error when Read is called -type errorReader struct{} - -func (e *errorReader) Read(p []byte) (n int, err error) { - return 0, errors.New("read error") -} - func TestNewATSessionHandler(t *testing.T) { mockHandler := &httpmocks.MockRequestHandler{} ash := NewATSessionHandler(mockHandler) @@ -242,208 +231,4 @@ func TestATSessionHandler_Output(t *testing.T) { } } -func TestSessionHandler_ServeHTTP(t *testing.T) { - tests := []struct { - name string - sessionID string - input []byte - parserErr error - processErr error - outputErr error - resetErr error - expectedStatus int - }{ - { - name: "Success", - sessionID: "123", - input: []byte("test input"), - expectedStatus: http.StatusOK, - }, - { - name: "Missing Session ID", - sessionID: "", - parserErr: handlers.ErrSessionMissing, - expectedStatus: http.StatusBadRequest, - }, - { - name: "Process Error", - sessionID: "123", - input: []byte("test input"), - processErr: handlers.ErrStorage, - expectedStatus: http.StatusInternalServerError, - }, - { - name: "Output Error", - sessionID: "123", - input: []byte("test input"), - outputErr: errors.New("output error"), - expectedStatus: http.StatusOK, - }, - { - name: "Reset Error", - sessionID: "123", - input: []byte("test input"), - resetErr: errors.New("reset error"), - expectedStatus: http.StatusOK, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockRequestParser := &httpmocks.MockRequestParser{ - GetSessionIdFunc: func(any) (string, error) { - return tt.sessionID, tt.parserErr - }, - GetInputFunc: func(any) ([]byte, error) { - return tt.input, nil - }, - } - - mockRequestHandler := &httpmocks.MockRequestHandler{ - ProcessFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) { - return rs, tt.processErr - }, - OutputFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) { - return rs, tt.outputErr - }, - ResetFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) { - return rs, tt.resetErr - }, - GetRequestParserFunc: func() handlers.RequestParser { - return mockRequestParser - }, - GetConfigFunc: func() engine.Config { - return engine.Config{} - }, - } - - sessionHandler := ToSessionHandler(mockRequestHandler) - - req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBuffer(tt.input)) - req.Header.Set("X-Vise-Session", tt.sessionID) - - rr := httptest.NewRecorder() - - sessionHandler.ServeHTTP(rr, req) - - if status := rr.Code; status != tt.expectedStatus { - t.Errorf("handler returned wrong status code: got %v want %v", - status, tt.expectedStatus) - } - }) - } -} - -func TestSessionHandler_writeError(t *testing.T) { - handler := &SessionHandler{} - mockWriter := &httpmocks.MockWriter{} - err := errors.New("test error") - - handler.writeError(mockWriter, http.StatusBadRequest, err) - - if mockWriter.WrittenString != "" { - t.Errorf("Expected empty body, got %s", mockWriter.WrittenString) - } -} - -func TestDefaultRequestParser_GetSessionId(t *testing.T) { - tests := []struct { - name string - request any - expectedID string - expectedError error - }{ - { - name: "Valid Session ID", - request: func() *http.Request { - req := httptest.NewRequest(http.MethodPost, "/", nil) - req.Header.Set("X-Vise-Session", "123456") - return req - }(), - expectedID: "123456", - expectedError: nil, - }, - { - name: "Missing Session ID", - request: httptest.NewRequest(http.MethodPost, "/", nil), - expectedID: "", - expectedError: handlers.ErrSessionMissing, - }, - { - name: "Invalid Request Type", - request: invalidRequestType{}, - expectedID: "", - expectedError: handlers.ErrInvalidRequest, - }, - } - - parser := &DefaultRequestParser{} - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - id, err := parser.GetSessionId(tt.request) - - if id != tt.expectedID { - t.Errorf("Expected session ID %s, got %s", tt.expectedID, id) - } - - if err != tt.expectedError { - t.Errorf("Expected error %v, got %v", tt.expectedError, err) - } - }) - } -} - -func TestDefaultRequestParser_GetInput(t *testing.T) { - tests := []struct { - name string - request any - expectedInput []byte - expectedError error - }{ - { - name: "Valid Input", - request: func() *http.Request { - return httptest.NewRequest(http.MethodPost, "/", bytes.NewBufferString("test input")) - }(), - expectedInput: []byte("test input"), - expectedError: nil, - }, - { - name: "Empty Input", - request: httptest.NewRequest(http.MethodPost, "/", nil), - expectedInput: []byte{}, - expectedError: nil, - }, - { - name: "Invalid Request Type", - request: invalidRequestType{}, - expectedInput: nil, - expectedError: handlers.ErrInvalidRequest, - }, - { - name: "Read Error", - request: func() *http.Request { - return httptest.NewRequest(http.MethodPost, "/", &errorReader{}) - }(), - expectedInput: nil, - expectedError: errors.New("read error"), - }, - } - - parser := &DefaultRequestParser{} - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - input, err := parser.GetInput(tt.request) - - if !bytes.Equal(input, tt.expectedInput) { - t.Errorf("Expected input %s, got %s", tt.expectedInput, input) - } - - if err != tt.expectedError && (err == nil || err.Error() != tt.expectedError.Error()) { - t.Errorf("Expected error %v, got %v", tt.expectedError, err) - } - }) - } -} diff --git a/internal/http/server.go b/internal/http/server.go index a6239c4..df15407 100644 --- a/internal/http/server.go +++ b/internal/http/server.go @@ -52,7 +52,7 @@ func ToSessionHandler(h handlers.RequestHandler) *SessionHandler { } } -func (f *SessionHandler) writeError(w http.ResponseWriter, code int, err error) { +func (f *SessionHandler) WriteError(w http.ResponseWriter, code int, err error) { s := err.Error() w.Header().Set("Content-Length", strconv.Itoa(len(s))) w.WriteHeader(code) @@ -78,13 +78,13 @@ func (f *SessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { cfg.SessionId, err = rp.GetSessionId(req) if err != nil { logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err) - f.writeError(w, 400, err) + f.WriteError(w, 400, err) } rqs.Config = cfg rqs.Input, err = rp.GetInput(req) if err != nil { logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err) - f.writeError(w, 400, err) + f.WriteError(w, 400, err) return } @@ -101,7 +101,7 @@ func (f *SessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { } if code != 200 { - f.writeError(w, 500, err) + f.WriteError(w, 500, err) return } @@ -110,11 +110,11 @@ func (f *SessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { rqs, err = f.Output(rqs) rqs, perr = f.Reset(rqs) if err != nil { - f.writeError(w, 500, err) + f.WriteError(w, 500, err) return } if perr != nil { - f.writeError(w, 500, perr) + f.WriteError(w, 500, perr) return } } diff --git a/internal/http/server_test.go b/internal/http/server_test.go new file mode 100644 index 0000000..a46f98e --- /dev/null +++ b/internal/http/server_test.go @@ -0,0 +1,229 @@ +package http + +import ( + "bytes" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "git.defalsify.org/vise.git/engine" + "git.grassecon.net/urdt/ussd/internal/handlers" + "git.grassecon.net/urdt/ussd/internal/testutil/mocks/httpmocks" +) + +// invalidRequestType is a custom type to test invalid request scenarios +type invalidRequestType struct{} + +// errorReader is a helper type that always returns an error when Read is called +type errorReader struct{} + +func (e *errorReader) Read(p []byte) (n int, err error) { + return 0, errors.New("read error") +} + +func TestSessionHandler_ServeHTTP(t *testing.T) { + tests := []struct { + name string + sessionID string + input []byte + parserErr error + processErr error + outputErr error + resetErr error + expectedStatus int + }{ + { + name: "Success", + sessionID: "123", + input: []byte("test input"), + expectedStatus: http.StatusOK, + }, + { + name: "Missing Session ID", + sessionID: "", + parserErr: handlers.ErrSessionMissing, + expectedStatus: http.StatusBadRequest, + }, + { + name: "Process Error", + sessionID: "123", + input: []byte("test input"), + processErr: handlers.ErrStorage, + expectedStatus: http.StatusInternalServerError, + }, + { + name: "Output Error", + sessionID: "123", + input: []byte("test input"), + outputErr: errors.New("output error"), + expectedStatus: http.StatusOK, + }, + { + name: "Reset Error", + sessionID: "123", + input: []byte("test input"), + resetErr: errors.New("reset error"), + expectedStatus: http.StatusOK, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRequestParser := &httpmocks.MockRequestParser{ + GetSessionIdFunc: func(any) (string, error) { + return tt.sessionID, tt.parserErr + }, + GetInputFunc: func(any) ([]byte, error) { + return tt.input, nil + }, + } + + mockRequestHandler := &httpmocks.MockRequestHandler{ + ProcessFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) { + return rs, tt.processErr + }, + OutputFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) { + return rs, tt.outputErr + }, + ResetFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) { + return rs, tt.resetErr + }, + GetRequestParserFunc: func() handlers.RequestParser { + return mockRequestParser + }, + GetConfigFunc: func() engine.Config { + return engine.Config{} + }, + } + + sessionHandler := ToSessionHandler(mockRequestHandler) + + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBuffer(tt.input)) + req.Header.Set("X-Vise-Session", tt.sessionID) + + rr := httptest.NewRecorder() + + sessionHandler.ServeHTTP(rr, req) + + if status := rr.Code; status != tt.expectedStatus { + t.Errorf("handler returned wrong status code: got %v want %v", + status, tt.expectedStatus) + } + }) + } +} + +func TestSessionHandler_WriteError(t *testing.T) { + handler := &SessionHandler{} + mockWriter := &httpmocks.MockWriter{} + err := errors.New("test error") + + handler.WriteError(mockWriter, http.StatusBadRequest, err) + + if mockWriter.WrittenString != "" { + t.Errorf("Expected empty body, got %s", mockWriter.WrittenString) + } +} + +func TestDefaultRequestParser_GetSessionId(t *testing.T) { + tests := []struct { + name string + request any + expectedID string + expectedError error + }{ + { + name: "Valid Session ID", + request: func() *http.Request { + req := httptest.NewRequest(http.MethodPost, "/", nil) + req.Header.Set("X-Vise-Session", "123456") + return req + }(), + expectedID: "123456", + expectedError: nil, + }, + { + name: "Missing Session ID", + request: httptest.NewRequest(http.MethodPost, "/", nil), + expectedID: "", + expectedError: handlers.ErrSessionMissing, + }, + { + name: "Invalid Request Type", + request: invalidRequestType{}, + expectedID: "", + expectedError: handlers.ErrInvalidRequest, + }, + } + + parser := &DefaultRequestParser{} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + id, err := parser.GetSessionId(tt.request) + + if id != tt.expectedID { + t.Errorf("Expected session ID %s, got %s", tt.expectedID, id) + } + + if err != tt.expectedError { + t.Errorf("Expected error %v, got %v", tt.expectedError, err) + } + }) + } +} + +func TestDefaultRequestParser_GetInput(t *testing.T) { + tests := []struct { + name string + request any + expectedInput []byte + expectedError error + }{ + { + name: "Valid Input", + request: func() *http.Request { + return httptest.NewRequest(http.MethodPost, "/", bytes.NewBufferString("test input")) + }(), + expectedInput: []byte("test input"), + expectedError: nil, + }, + { + name: "Empty Input", + request: httptest.NewRequest(http.MethodPost, "/", nil), + expectedInput: []byte{}, + expectedError: nil, + }, + { + name: "Invalid Request Type", + request: invalidRequestType{}, + expectedInput: nil, + expectedError: handlers.ErrInvalidRequest, + }, + { + name: "Read Error", + request: func() *http.Request { + return httptest.NewRequest(http.MethodPost, "/", &errorReader{}) + }(), + expectedInput: nil, + expectedError: errors.New("read error"), + }, + } + + parser := &DefaultRequestParser{} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + input, err := parser.GetInput(tt.request) + + if !bytes.Equal(input, tt.expectedInput) { + t.Errorf("Expected input %s, got %s", tt.expectedInput, input) + } + + if err != tt.expectedError && (err == nil || err.Error() != tt.expectedError.Error()) { + t.Errorf("Expected error %v, got %v", tt.expectedError, err) + } + }) + } +}