forked from urdt/ussd
450 lines
12 KiB
Go
450 lines
12 KiB
Go
package http
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strings"
|
|
"testing"
|
|
|
|
"git.defalsify.org/vise.git/engine"
|
|
"git.grassecon.net/urdt/ussd/internal/handlers"
|
|
"git.grassecon.net/urdt/ussd/internal/mocks/httpmocks"
|
|
)
|
|
|
|
// invalidRequestType is an empty struct handed to the request parsers in
// place of an *http.Request, so the invalid-request branches can be tested.
type invalidRequestType struct{}

// errorReader is an io.Reader stub used as a request body whose reads
// always fail.
type errorReader struct{}

// Read ignores p and reports zero bytes read together with a "read error"
// error on every call.
func (r *errorReader) Read(p []byte) (n int, err error) {
	err = errors.New("read error")
	return n, err
}
|
|
|
|
func TestNewATSessionHandler(t *testing.T) {
|
|
mockHandler := &httpmocks.MockRequestHandler{}
|
|
ash := NewATSessionHandler(mockHandler)
|
|
|
|
if ash == nil {
|
|
t.Fatal("NewATSessionHandler returned nil")
|
|
}
|
|
|
|
if ash.SessionHandler == nil {
|
|
t.Fatal("SessionHandler is nil")
|
|
}
|
|
}
|
|
|
|
func TestATSessionHandler_ServeHTTP(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
setupMocks func(*httpmocks.MockRequestHandler, *httpmocks.MockRequestParser, *httpmocks.MockEngine)
|
|
formData url.Values
|
|
expectedStatus int
|
|
expectedBody string
|
|
}{
|
|
{
|
|
name: "Successful request",
|
|
setupMocks: func(mh *httpmocks.MockRequestHandler, mrp *httpmocks.MockRequestParser, me *httpmocks.MockEngine) {
|
|
mrp.GetSessionIdFunc = func(rq any) (string, error) {
|
|
req := rq.(*http.Request)
|
|
return req.FormValue("phoneNumber"), nil
|
|
}
|
|
mrp.GetInputFunc = func(rq any) ([]byte, error) {
|
|
req := rq.(*http.Request)
|
|
text := req.FormValue("text")
|
|
parts := strings.Split(text, "*")
|
|
return []byte(parts[len(parts)-1]), nil
|
|
}
|
|
mh.ProcessFunc = func(rqs handlers.RequestSession) (handlers.RequestSession, error) {
|
|
rqs.Continue = true
|
|
rqs.Engine = me
|
|
return rqs, nil
|
|
}
|
|
mh.GetConfigFunc = func() engine.Config { return engine.Config{} }
|
|
mh.GetRequestParserFunc = func() handlers.RequestParser { return mrp }
|
|
mh.OutputFunc = func(rs handlers.RequestSession) (handlers.RequestSession, error) { return rs, nil }
|
|
mh.ResetFunc = func(rs handlers.RequestSession) (handlers.RequestSession, error) { return rs, nil }
|
|
me.FlushFunc = func(context.Context, io.Writer) (int, error) { return 0, nil }
|
|
},
|
|
formData: url.Values{
|
|
"phoneNumber": []string{"+1234567890"},
|
|
"text": []string{"1*2*3"},
|
|
},
|
|
expectedStatus: http.StatusOK,
|
|
expectedBody: "CON ",
|
|
},
|
|
{
|
|
name: "GetSessionId error",
|
|
setupMocks: func(mh *httpmocks.MockRequestHandler, mrp *httpmocks.MockRequestParser, me *httpmocks.MockEngine) {
|
|
mrp.GetSessionIdFunc = func(rq any) (string, error) {
|
|
return "", errors.New("no phone number found")
|
|
}
|
|
mh.GetConfigFunc = func() engine.Config { return engine.Config{} }
|
|
mh.GetRequestParserFunc = func() handlers.RequestParser { return mrp }
|
|
},
|
|
formData: url.Values{
|
|
"text": []string{"1*2*3"},
|
|
},
|
|
expectedStatus: http.StatusBadRequest,
|
|
expectedBody: "",
|
|
},
|
|
{
|
|
name: "GetInput error",
|
|
setupMocks: func(mh *httpmocks.MockRequestHandler, mrp *httpmocks.MockRequestParser, me *httpmocks.MockEngine) {
|
|
mrp.GetSessionIdFunc = func(rq any) (string, error) {
|
|
req := rq.(*http.Request)
|
|
return req.FormValue("phoneNumber"), nil
|
|
}
|
|
mrp.GetInputFunc = func(rq any) ([]byte, error) {
|
|
return nil, errors.New("no input found")
|
|
}
|
|
mh.GetConfigFunc = func() engine.Config { return engine.Config{} }
|
|
mh.GetRequestParserFunc = func() handlers.RequestParser { return mrp }
|
|
},
|
|
formData: url.Values{
|
|
"phoneNumber": []string{"+1234567890"},
|
|
},
|
|
expectedStatus: http.StatusBadRequest,
|
|
expectedBody: "",
|
|
},
|
|
{
|
|
name: "Process error",
|
|
setupMocks: func(mh *httpmocks.MockRequestHandler, mrp *httpmocks.MockRequestParser, me *httpmocks.MockEngine) {
|
|
mrp.GetSessionIdFunc = func(rq any) (string, error) {
|
|
req := rq.(*http.Request)
|
|
return req.FormValue("phoneNumber"), nil
|
|
}
|
|
mrp.GetInputFunc = func(rq any) ([]byte, error) {
|
|
req := rq.(*http.Request)
|
|
text := req.FormValue("text")
|
|
parts := strings.Split(text, "*")
|
|
return []byte(parts[len(parts)-1]), nil
|
|
}
|
|
mh.ProcessFunc = func(rqs handlers.RequestSession) (handlers.RequestSession, error) {
|
|
return rqs, handlers.ErrStorage
|
|
}
|
|
mh.GetConfigFunc = func() engine.Config { return engine.Config{} }
|
|
mh.GetRequestParserFunc = func() handlers.RequestParser { return mrp }
|
|
},
|
|
formData: url.Values{
|
|
"phoneNumber": []string{"+1234567890"},
|
|
"text": []string{"1*2*3"},
|
|
},
|
|
expectedStatus: http.StatusInternalServerError,
|
|
expectedBody: "",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
mockHandler := &httpmocks.MockRequestHandler{}
|
|
mockRequestParser := &httpmocks.MockRequestParser{}
|
|
mockEngine := &httpmocks.MockEngine{}
|
|
tt.setupMocks(mockHandler, mockRequestParser, mockEngine)
|
|
|
|
ash := NewATSessionHandler(mockHandler)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(tt.formData.Encode()))
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
w := httptest.NewRecorder()
|
|
|
|
ash.ServeHTTP(w, req)
|
|
|
|
if w.Code != tt.expectedStatus {
|
|
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
|
|
}
|
|
|
|
if tt.expectedBody != "" && w.Body.String() != tt.expectedBody {
|
|
t.Errorf("Expected body %q, got %q", tt.expectedBody, w.Body.String())
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestATSessionHandler_Output(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
input handlers.RequestSession
|
|
expectedPrefix string
|
|
expectedError bool
|
|
}{
|
|
{
|
|
name: "Continue true",
|
|
input: handlers.RequestSession{
|
|
Continue: true,
|
|
Engine: &httpmocks.MockEngine{
|
|
FlushFunc: func(context.Context, io.Writer) (int, error) {
|
|
return 0, nil
|
|
},
|
|
},
|
|
Writer: &httpmocks.MockWriter{},
|
|
},
|
|
expectedPrefix: "CON ",
|
|
expectedError: false,
|
|
},
|
|
{
|
|
name: "Continue false",
|
|
input: handlers.RequestSession{
|
|
Continue: false,
|
|
Engine: &httpmocks.MockEngine{
|
|
FlushFunc: func(context.Context, io.Writer) (int, error) {
|
|
return 0, nil
|
|
},
|
|
},
|
|
Writer: &httpmocks.MockWriter{},
|
|
},
|
|
expectedPrefix: "END ",
|
|
expectedError: false,
|
|
},
|
|
{
|
|
name: "Flush error",
|
|
input: handlers.RequestSession{
|
|
Continue: true,
|
|
Engine: &httpmocks.MockEngine{
|
|
FlushFunc: func(context.Context, io.Writer) (int, error) {
|
|
return 0, errors.New("write error")
|
|
},
|
|
},
|
|
Writer: &httpmocks.MockWriter{},
|
|
},
|
|
expectedPrefix: "CON ",
|
|
expectedError: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
ash := &ATSessionHandler{}
|
|
_, err := ash.Output(tt.input)
|
|
|
|
if tt.expectedError && err == nil {
|
|
t.Error("Expected an error, but got nil")
|
|
}
|
|
|
|
if !tt.expectedError && err != nil {
|
|
t.Errorf("Unexpected error: %v", err)
|
|
}
|
|
|
|
mw := tt.input.Writer.(*httpmocks.MockWriter)
|
|
if !mw.WriteStringCalled {
|
|
t.Error("WriteString was not called")
|
|
}
|
|
|
|
if mw.WrittenString != tt.expectedPrefix {
|
|
t.Errorf("Expected prefix %q, got %q", tt.expectedPrefix, mw.WrittenString)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSessionHandler_ServeHTTP(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
sessionID string
|
|
input []byte
|
|
parserErr error
|
|
processErr error
|
|
outputErr error
|
|
resetErr error
|
|
expectedStatus int
|
|
}{
|
|
{
|
|
name: "Success",
|
|
sessionID: "123",
|
|
input: []byte("test input"),
|
|
expectedStatus: http.StatusOK,
|
|
},
|
|
{
|
|
name: "Missing Session ID",
|
|
sessionID: "",
|
|
parserErr: handlers.ErrSessionMissing,
|
|
expectedStatus: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "Process Error",
|
|
sessionID: "123",
|
|
input: []byte("test input"),
|
|
processErr: handlers.ErrStorage,
|
|
expectedStatus: http.StatusInternalServerError,
|
|
},
|
|
{
|
|
name: "Output Error",
|
|
sessionID: "123",
|
|
input: []byte("test input"),
|
|
outputErr: errors.New("output error"),
|
|
expectedStatus: http.StatusOK,
|
|
},
|
|
{
|
|
name: "Reset Error",
|
|
sessionID: "123",
|
|
input: []byte("test input"),
|
|
resetErr: errors.New("reset error"),
|
|
expectedStatus: http.StatusOK,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
mockRequestParser := &httpmocks.MockRequestParser{
|
|
GetSessionIdFunc: func(any) (string, error) {
|
|
return tt.sessionID, tt.parserErr
|
|
},
|
|
GetInputFunc: func(any) ([]byte, error) {
|
|
return tt.input, nil
|
|
},
|
|
}
|
|
|
|
mockRequestHandler := &httpmocks.MockRequestHandler{
|
|
ProcessFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) {
|
|
return rs, tt.processErr
|
|
},
|
|
OutputFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) {
|
|
return rs, tt.outputErr
|
|
},
|
|
ResetFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) {
|
|
return rs, tt.resetErr
|
|
},
|
|
GetRequestParserFunc: func() handlers.RequestParser {
|
|
return mockRequestParser
|
|
},
|
|
GetConfigFunc: func() engine.Config {
|
|
return engine.Config{}
|
|
},
|
|
}
|
|
|
|
sessionHandler := ToSessionHandler(mockRequestHandler)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBuffer(tt.input))
|
|
req.Header.Set("X-Vise-Session", tt.sessionID)
|
|
|
|
rr := httptest.NewRecorder()
|
|
|
|
sessionHandler.ServeHTTP(rr, req)
|
|
|
|
if status := rr.Code; status != tt.expectedStatus {
|
|
t.Errorf("handler returned wrong status code: got %v want %v",
|
|
status, tt.expectedStatus)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSessionHandler_writeError(t *testing.T) {
|
|
handler := &SessionHandler{}
|
|
mockWriter := &httpmocks.MockWriter{}
|
|
err := errors.New("test error")
|
|
|
|
handler.writeError(mockWriter, http.StatusBadRequest, err)
|
|
|
|
if mockWriter.WrittenString != "" {
|
|
t.Errorf("Expected empty body, got %s", mockWriter.WrittenString)
|
|
}
|
|
}
|
|
|
|
func TestDefaultRequestParser_GetSessionId(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
request any
|
|
expectedID string
|
|
expectedError error
|
|
}{
|
|
{
|
|
name: "Valid Session ID",
|
|
request: func() *http.Request {
|
|
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
|
req.Header.Set("X-Vise-Session", "123456")
|
|
return req
|
|
}(),
|
|
expectedID: "123456",
|
|
expectedError: nil,
|
|
},
|
|
{
|
|
name: "Missing Session ID",
|
|
request: httptest.NewRequest(http.MethodPost, "/", nil),
|
|
expectedID: "",
|
|
expectedError: handlers.ErrSessionMissing,
|
|
},
|
|
{
|
|
name: "Invalid Request Type",
|
|
request: invalidRequestType{},
|
|
expectedID: "",
|
|
expectedError: handlers.ErrInvalidRequest,
|
|
},
|
|
}
|
|
|
|
parser := &DefaultRequestParser{}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
id, err := parser.GetSessionId(tt.request)
|
|
|
|
if id != tt.expectedID {
|
|
t.Errorf("Expected session ID %s, got %s", tt.expectedID, id)
|
|
}
|
|
|
|
if err != tt.expectedError {
|
|
t.Errorf("Expected error %v, got %v", tt.expectedError, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDefaultRequestParser_GetInput(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
request any
|
|
expectedInput []byte
|
|
expectedError error
|
|
}{
|
|
{
|
|
name: "Valid Input",
|
|
request: func() *http.Request {
|
|
return httptest.NewRequest(http.MethodPost, "/", bytes.NewBufferString("test input"))
|
|
}(),
|
|
expectedInput: []byte("test input"),
|
|
expectedError: nil,
|
|
},
|
|
{
|
|
name: "Empty Input",
|
|
request: httptest.NewRequest(http.MethodPost, "/", nil),
|
|
expectedInput: []byte{},
|
|
expectedError: nil,
|
|
},
|
|
{
|
|
name: "Invalid Request Type",
|
|
request: invalidRequestType{},
|
|
expectedInput: nil,
|
|
expectedError: handlers.ErrInvalidRequest,
|
|
},
|
|
{
|
|
name: "Read Error",
|
|
request: func() *http.Request {
|
|
return httptest.NewRequest(http.MethodPost, "/", &errorReader{})
|
|
}(),
|
|
expectedInput: nil,
|
|
expectedError: errors.New("read error"),
|
|
},
|
|
}
|
|
|
|
parser := &DefaultRequestParser{}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
input, err := parser.GetInput(tt.request)
|
|
|
|
if !bytes.Equal(input, tt.expectedInput) {
|
|
t.Errorf("Expected input %s, got %s", tt.expectedInput, input)
|
|
}
|
|
|
|
if err != tt.expectedError && (err == nil || err.Error() != tt.expectedError.Error()) {
|
|
t.Errorf("Expected error %v, got %v", tt.expectedError, err)
|
|
}
|
|
})
|
|
}
|
|
}
|