diff --git a/internal/handlers/ussd/menuhandler.go b/internal/handlers/ussd/menuhandler.go index 6b6407b..640517f 100644 --- a/internal/handlers/ussd/menuhandler.go +++ b/internal/handlers/ussd/menuhandler.go @@ -129,6 +129,11 @@ func (h *Handlers) Init(ctx context.Context, sym string, input []byte) (resource h.st = h.pe.GetState() h.ca = h.pe.GetMemory() + if len(input) == 0 { + // move to the top node + h.st.Code = []byte{} + } + sessionId, _ := ctx.Value("SessionId").(string) flag_admin_privilege, _ := h.flagManager.GetFlag("flag_admin_privilege") diff --git a/internal/handlers/ussd/menuhandler_test.go b/internal/handlers/ussd/menuhandler_test.go index c01678d..34c8e76 100644 --- a/internal/handlers/ussd/menuhandler_test.go +++ b/internal/handlers/ussd/menuhandler_test.go @@ -8,6 +8,7 @@ import ( "strings" "testing" + "git.defalsify.org/vise.git/cache" "git.defalsify.org/vise.git/lang" "git.defalsify.org/vise.git/persist" "git.defalsify.org/vise.git/resource" @@ -15,6 +16,7 @@ import ( "git.grassecon.net/urdt/ussd/internal/storage" "git.grassecon.net/urdt/ussd/internal/testutil/mocks" "git.grassecon.net/urdt/ussd/internal/testutil/testservice" + "git.grassecon.net/urdt/ussd/internal/utils" "git.grassecon.net/urdt/ussd/models" "git.grassecon.net/urdt/ussd/common" @@ -119,6 +121,102 @@ func TestNewHandlers(t *testing.T) { }) } +func TestInit(t *testing.T) { + sessionId := "session123" + ctx, store := InitializeTestStore(t) + ctx = context.WithValue(ctx, "SessionId", sessionId) + + fm, err := NewFlagManager(flagsPath) + if err != nil { + t.Fatal(err.Error()) + } + + adminstore, err := utils.NewAdminStore(ctx, "admin_numbers") + if err != nil { + t.Fatal(err.Error()) + } + + st := state.NewState(128) + ca := cache.NewCache() + + flag_admin_privilege, _ := fm.GetFlag("flag_admin_privilege") + + tests := []struct { + name string + setup func() (*Handlers, context.Context) + input []byte + expectedResult resource.Result + }{ + { + name: "Handler not ready", + setup: func() (*Handlers, context.Context) { + return &Handlers{}, ctx + }, + input: []byte("1"), + expectedResult: resource.Result{}, + }, + { + name: "State and memory initialization", + setup: func() (*Handlers, context.Context) { + pe := persist.NewPersister(store).WithSession(sessionId).WithContent(st, ca) + h := &Handlers{ + flagManager: fm.parser, + adminstore: adminstore, + pe: pe, + } + return h, context.WithValue(ctx, "SessionId", sessionId) + }, + input: []byte("1"), + expectedResult: resource.Result{ + FlagReset: []uint32{flag_admin_privilege}, + }, + }, + { + name: "Non-admin session initialization", + setup: func() (*Handlers, context.Context) { + pe := persist.NewPersister(store).WithSession("0712345678").WithContent(st, ca) + h := &Handlers{ + flagManager: fm.parser, + adminstore: adminstore, + pe: pe, + } + return h, context.WithValue(context.Background(), "SessionId", "0712345678") + }, + input: []byte("1"), + expectedResult: resource.Result{ + FlagReset: []uint32{flag_admin_privilege}, + }, + }, + { + name: "Move to top node on empty input", + setup: func() (*Handlers, context.Context) { + pe := persist.NewPersister(store).WithSession(sessionId).WithContent(st, ca) + h := &Handlers{ + flagManager: fm.parser, + adminstore: adminstore, + pe: pe, + } + st.Code = []byte("some pending bytecode") + return h, context.WithValue(ctx, "SessionId", sessionId) + }, + input: []byte(""), + expectedResult: resource.Result{ + FlagReset: []uint32{flag_admin_privilege}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h, testCtx := tt.setup() + res, err := h.Init(testCtx, "", tt.input) + + assert.NoError(t, err, "Unexpected error occurred") + assert.Equal(t, res, tt.expectedResult, "Expected result should match actual result") + }) + } +} + func TestCreateAccount(t *testing.T) { sessionId := "session123" ctx, store := InitializeTestStore(t) diff --git a/internal/http/server.go b/internal/http/server.go index 3ea0159..a6239c4 100644 --- a/internal/http/server.go +++ b/internal/http/server.go @@ -17,8 +17,7 @@ var ( type DefaultRequestParser struct { } - -func(rp *DefaultRequestParser) GetSessionId(rq any) (string, error) { +func (rp *DefaultRequestParser) GetSessionId(rq any) (string, error) { rqv, ok := rq.(*http.Request) if !ok { return "", handlers.ErrInvalidRequest @@ -30,7 +29,7 @@ func(rp *DefaultRequestParser) GetSessionId(rq any) (string, error) { return v, nil } -func(rp *DefaultRequestParser) GetInput(rq any) ([]byte, error) { +func (rp *DefaultRequestParser) GetInput(rq any) ([]byte, error) { rqv, ok := rq.(*http.Request) if !ok { return nil, handlers.ErrInvalidRequest @@ -53,25 +52,24 @@ func ToSessionHandler(h handlers.RequestHandler) *SessionHandler { } } -func(f *SessionHandler) writeError(w http.ResponseWriter, code int, err error) { +func (f *SessionHandler) writeError(w http.ResponseWriter, code int, err error) { s := err.Error() w.Header().Set("Content-Length", strconv.Itoa(len(s))) w.WriteHeader(code) - _, err = w.Write([]byte{}) + _, err = w.Write([]byte(s)) if err != nil { logg.Errorf("error writing error!!", "err", err, "olderr", s) w.WriteHeader(500) } - return } -func(f *SessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { +func (f *SessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { var code int var err error var perr error rqs := handlers.RequestSession{ - Ctx: req.Context(), + Ctx: req.Context(), Writer: w, } diff --git a/menutraversal_test/menu_traversal_test.go b/menutraversal_test/menu_traversal_test.go index 28d88db..6b6b3da 100644 --- a/menutraversal_test/menu_traversal_test.go +++ b/menutraversal_test/menu_traversal_test.go @@ -298,9 +298,10 @@ func TestMainMenuSend(t *testing.T) { ctx := context.Background() sessions := testData for _, session := range sessions { - groups := driver.FilterGroupsByName(session.Groups, "send_with_invalid_inputs") + groups := driver.FilterGroupsByName(session.Groups, "send_with_invite") for _, group := range groups { - for _, step := range group.Steps { + for index, step := range group.Steps { + t.Logf("step %v with input %v", index, step.Input) cont, err := en.Exec(ctx, []byte(step.Input)) if err != nil { t.Fatalf("Test case '%s' failed at input '%s': %v", group.Name, step.Input, err) diff --git a/menutraversal_test/test_setup.json b/menutraversal_test/test_setup.json index c5860b4..5115de9 100644 --- a/menutraversal_test/test_setup.json +++ b/menutraversal_test/test_setup.json @@ -64,8 +64,8 @@ "expectedContent": "Enter recipient's phone number/address/alias:\n0:Back" }, { - "input": "000", - "expectedContent": "000 is invalid, please try again:\n1:Retry\n9:Quit" + "input": "0@0", + "expectedContent": "0@0 is invalid, please try again:\n1:Retry\n9:Quit" }, { "input": "1",