diff --git a/devtools/restart_state/main.go b/devtools/restart_state/main.go new file mode 100644 index 0000000..3068a38 --- /dev/null +++ b/devtools/restart_state/main.go @@ -0,0 +1,77 @@ +package main + +import ( + "context" + "flag" + "fmt" + "os" + "path" + + "git.defalsify.org/vise.git/logging" + "git.grassecon.net/urdt/ussd/config" + "git.grassecon.net/urdt/ussd/initializers" + "git.grassecon.net/urdt/ussd/internal/storage" +) + +var ( + logg = logging.NewVanilla() + scriptDir = path.Join("services", "registration") +) + +func init() { + initializers.LoadEnvVariables() +} + +func main() { + config.LoadConfig() + + var dbDir string + var sessionId string + var database string + + flag.StringVar(&sessionId, "session-id", "075xx2123", "session id") + flag.StringVar(&database, "db", "gdbm", "database to be used") + flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from") + flag.Parse() + + ctx := context.Background() + ctx = context.WithValue(ctx, "SessionId", sessionId) + ctx = context.WithValue(ctx, "Database", database) + + resourceDir := scriptDir + menuStorageService := storage.NewMenuStorageService(dbDir, resourceDir) + + err := menuStorageService.EnsureDbDir() + if err != nil { + fmt.Fprintf(os.Stderr, err.Error()) + os.Exit(1) + } + + pe, err := menuStorageService.GetPersister(ctx) + if err != nil { + fmt.Fprintf(os.Stderr, err.Error()) + os.Exit(1) + } + + // initialize the persister + + // get the state + + // restart the state + + // persist the state + + // exit + + st := pe.GetState() + + if st == nil { + logg.ErrorCtxf(ctx, "state fail in devtool", "state", st) + fmt.Errorf("cannot get state") + os.Exit(1) + } + + st.Restart() + + os.Exit(1) +} diff --git a/internal/handlers/ussd/menuhandler.go b/internal/handlers/ussd/menuhandler.go index 0b8ea64..f509ff4 100644 --- a/internal/handlers/ussd/menuhandler.go +++ b/internal/handlers/ussd/menuhandler.go @@ -124,6 +124,11 @@ func (h *Handlers) Init(ctx context.Context, sym string, input []byte) (resource h.st = h.pe.GetState() h.ca = h.pe.GetMemory() + if len(input) == 0 { + // move to the top node + h.st.Code = []byte{} + } + sessionId, _ := ctx.Value("SessionId").(string) flag_admin_privilege, _ := h.flagManager.GetFlag("flag_admin_privilege") diff --git a/internal/http/server.go b/internal/http/server.go index 3ea0159..a6239c4 100644 --- a/internal/http/server.go +++ b/internal/http/server.go @@ -17,8 +17,7 @@ var ( type DefaultRequestParser struct { } - -func(rp *DefaultRequestParser) GetSessionId(rq any) (string, error) { +func (rp *DefaultRequestParser) GetSessionId(rq any) (string, error) { rqv, ok := rq.(*http.Request) if !ok { return "", handlers.ErrInvalidRequest @@ -30,7 +29,7 @@ func(rp *DefaultRequestParser) GetSessionId(rq any) (string, error) { return v, nil } -func(rp *DefaultRequestParser) GetInput(rq any) ([]byte, error) { +func (rp *DefaultRequestParser) GetInput(rq any) ([]byte, error) { rqv, ok := rq.(*http.Request) if !ok { return nil, handlers.ErrInvalidRequest @@ -53,25 +52,24 @@ func ToSessionHandler(h handlers.RequestHandler) *SessionHandler { } } -func(f *SessionHandler) writeError(w http.ResponseWriter, code int, err error) { +func (f *SessionHandler) writeError(w http.ResponseWriter, code int, err error) { s := err.Error() w.Header().Set("Content-Length", strconv.Itoa(len(s))) w.WriteHeader(code) - _, err = w.Write([]byte{}) + _, err = w.Write([]byte(s)) if err != nil { logg.Errorf("error writing error!!", "err", err, "olderr", s) w.WriteHeader(500) } - return } -func(f *SessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { +func (f *SessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { var code int var err error var perr error rqs := handlers.RequestSession{ - Ctx: req.Context(), + Ctx: req.Context(), Writer: w, }