From 8e3ff27bb87d63e81f51feb3baee99f21804e567 Mon Sep 17 00:00:00 2001 From: lash Date: Tue, 10 Sep 2024 20:44:10 +0100 Subject: [PATCH] Ensure db close on http signal shutdown, correct stores to provider --- cmd/http/main.go | 39 ++++++++++++++++++++++++++++++--------- cmd/main.go | 12 ++++++++++-- internal/http/server.go | 8 ++++++++ internal/http/storage.go | 2 +- internal/utils/db.go | 1 - 5 files changed, 49 insertions(+), 13 deletions(-) diff --git a/cmd/http/main.go b/cmd/http/main.go index d253c12..1edc1fe 100644 --- a/cmd/http/main.go +++ b/cmd/http/main.go @@ -6,8 +6,10 @@ import ( "fmt" "net/http" "os" + "os/signal" "path" "strconv" + "syscall" "git.defalsify.org/vise.git/asm" "git.defalsify.org/vise.git/db" @@ -78,11 +80,15 @@ func getHandler(appFlags *asm.FlagParser, rs *resource.DbResource, userdataStore return ussdHandlers, nil } -func getStateStore(dbDir string, ctx context.Context) (db.Db, error) { +func ensureDbDir(dbDir string) error { err := os.MkdirAll(dbDir, 0700) if err != nil { - return nil, fmt.Errorf("state dir create exited with error: %v\n", err) + return fmt.Errorf("state dir create exited with error: %v\n", err) } + return nil +} + +func getStateStore(dbDir string, ctx context.Context) (db.Db, error) { store := gdbmdb.NewGdbmDb() storeFile := path.Join(dbDir, "state.gdbm") store.Connect(ctx, storeFile) @@ -117,12 +123,10 @@ func main() { var dbDir string var resourceDir string var size uint - var sessionId string var engineDebug bool var stateDebug bool var host string var port uint - flag.StringVar(&sessionId, "session-id", "075xx2123", "session id") flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from") flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir") flag.BoolVar(&engineDebug, "engine-debug", false, "use engine debug output") @@ -135,7 +139,6 @@ func main() { logg.Infof("start command", "dbdir", dbDir, "resourcedir", resourceDir, "outputsize", size) ctx := context.Background() - ctx = context.WithValue(ctx, "SessionId",sessionId) pfp := path.Join(scriptDir, "pp.csv") flagParser, err := getFlags(pfp, true) @@ -145,7 +148,6 @@ func main() { cfg := engine.Config{ Root: "root", - SessionId: sessionId, OutputSize: uint32(size), FlagCount: uint32(16), } @@ -162,11 +164,18 @@ func main() { os.Exit(1) } + err = ensureDbDir(dbDir) + if err != nil { + fmt.Fprintf(os.Stderr, err.Error()) + os.Exit(1) + } + userdataStore := getUserdataDb(dbDir, ctx) if err != nil { fmt.Fprintf(os.Stderr, err.Error()) os.Exit(1) } + defer userdataStore.Close() dbResource, ok := rs.(*resource.DbResource) if !ok { @@ -184,16 +193,28 @@ func main() { fmt.Fprintf(os.Stderr, err.Error()) os.Exit(1) } + defer stateStore.Close() - sh := httpserver.NewSessionHandler(cfg, rs, userdataStore, stateStore, hl.Init) + sh := httpserver.NewSessionHandler(cfg, rs, stateStore, userdataStore, hl.Init) s := &http.Server{ Addr: fmt.Sprintf("%s:%s", host, strconv.Itoa(int(port))), Handler: sh, } + s.RegisterOnShutdown(sh.Shutdown) + cint := make(chan os.Signal) + cterm := make(chan os.Signal) + signal.Notify(cint, os.Interrupt, syscall.SIGINT) + signal.Notify(cterm, os.Interrupt, syscall.SIGTERM) + go func() { + select { + case _ = <-cint: + case _ = <-cterm: + } + s.Shutdown(ctx) + }() err = s.ListenAndServe() if err != nil { - fmt.Fprintf(os.Stderr, "Server error: %s", err) - os.Exit(1) + logg.Infof("Server closed with error", "err", err) } } diff --git a/cmd/main.go b/cmd/main.go index 9547dc4..9222c13 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -75,10 +75,18 @@ func getHandler(appFlags *asm.FlagParser, rs *resource.DbResource, pe *persist.P return ussdHandlers, nil } -func getPersister(dbDir string, ctx context.Context) (*persist.Persister, error) { +func ensureDbDir(dbDir string) error { err := os.MkdirAll(dbDir, 0700) if err != nil { - return nil, fmt.Errorf("state dir create exited with error: %v\n", err) + return fmt.Errorf("state dir create exited with error: %v\n", err) + } + return nil +} + +func getPersister(dbDir string, ctx context.Context) (*persist.Persister, error) { + err := ensureDbDir(dbDir) + if err != nil { + return nil, err } store := gdbmdb.NewGdbmDb() storeFile := path.Join(dbDir, "state.gdbm") diff --git a/internal/http/server.go b/internal/http/server.go index aa53448..4ca7f73 100644 --- a/internal/http/server.go +++ b/internal/http/server.go @@ -66,6 +66,13 @@ func(f *SessionHandler) writeError(w http.ResponseWriter, code int, msg string, return } +func(f* SessionHandler) Shutdown() { + err := f.provider.Close() + if err != nil { + logg.Errorf("handler shutdown error", "err", err) + } +} + func(f *SessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { var r bool sessionId, err := f.rp.GetSessionId(req) @@ -89,6 +96,7 @@ func(f *SessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { f.writeError(w, 500, "Storage retrieval fail", err) return } + defer f.provider.Put(cfg.SessionId, storage) en := getEngine(cfg, f.rs, storage.Persister) en = en.WithFirst(f.first) if cfg.EngineDebug { diff --git a/internal/http/storage.go b/internal/http/storage.go index 012c56c..f8243cc 100644 --- a/internal/http/storage.go +++ b/internal/http/storage.go @@ -39,5 +39,5 @@ func (p *SimpleStorageProvider) Put(sessionId string, storage Storage) error { } func (p *SimpleStorageProvider) Close() error { - return nil + return p.Storage.UserdataDb.Close() } diff --git a/internal/utils/db.go b/internal/utils/db.go index 5b128f6..94ce250 100644 --- a/internal/utils/db.go +++ b/internal/utils/db.go @@ -39,7 +39,6 @@ func PackKey(typ DataTyp, data []byte) []byte { } func ReadEntry(ctx context.Context, store db.Db, sessionId string, typ DataTyp) ([]byte, error) { - store.SetPrefix(db.DATATYPE_USERDATA) store.SetSession(sessionId) k := PackKey(typ, []byte(sessionId))