Ensure db close on http signal shutdown, correct stores to provider

This commit is contained in:
lash 2024-09-10 20:44:10 +01:00
parent dd2468a4d7
commit 8e3ff27bb8
Signed by untrusted user who does not match committer: lash
GPG Key ID: 21D2E7BB88C2A746
5 changed files with 49 additions and 13 deletions

View File

@ -6,8 +6,10 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"os" "os"
"os/signal"
"path" "path"
"strconv" "strconv"
"syscall"
"git.defalsify.org/vise.git/asm" "git.defalsify.org/vise.git/asm"
"git.defalsify.org/vise.git/db" "git.defalsify.org/vise.git/db"
@ -78,11 +80,15 @@ func getHandler(appFlags *asm.FlagParser, rs *resource.DbResource, userdataStore
return ussdHandlers, nil return ussdHandlers, nil
} }
func getStateStore(dbDir string, ctx context.Context) (db.Db, error) { func ensureDbDir(dbDir string) error {
err := os.MkdirAll(dbDir, 0700) err := os.MkdirAll(dbDir, 0700)
if err != nil { if err != nil {
return nil, fmt.Errorf("state dir create exited with error: %v\n", err) return fmt.Errorf("state dir create exited with error: %v\n", err)
} }
return nil
}
func getStateStore(dbDir string, ctx context.Context) (db.Db, error) {
store := gdbmdb.NewGdbmDb() store := gdbmdb.NewGdbmDb()
storeFile := path.Join(dbDir, "state.gdbm") storeFile := path.Join(dbDir, "state.gdbm")
store.Connect(ctx, storeFile) store.Connect(ctx, storeFile)
@ -117,12 +123,10 @@ func main() {
var dbDir string var dbDir string
var resourceDir string var resourceDir string
var size uint var size uint
var sessionId string
var engineDebug bool var engineDebug bool
var stateDebug bool var stateDebug bool
var host string var host string
var port uint var port uint
flag.StringVar(&sessionId, "session-id", "075xx2123", "session id")
flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from") flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir") flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir")
flag.BoolVar(&engineDebug, "engine-debug", false, "use engine debug output") flag.BoolVar(&engineDebug, "engine-debug", false, "use engine debug output")
@ -135,7 +139,6 @@ func main() {
logg.Infof("start command", "dbdir", dbDir, "resourcedir", resourceDir, "outputsize", size) logg.Infof("start command", "dbdir", dbDir, "resourcedir", resourceDir, "outputsize", size)
ctx := context.Background() ctx := context.Background()
ctx = context.WithValue(ctx, "SessionId",sessionId)
pfp := path.Join(scriptDir, "pp.csv") pfp := path.Join(scriptDir, "pp.csv")
flagParser, err := getFlags(pfp, true) flagParser, err := getFlags(pfp, true)
@ -145,7 +148,6 @@ func main() {
cfg := engine.Config{ cfg := engine.Config{
Root: "root", Root: "root",
SessionId: sessionId,
OutputSize: uint32(size), OutputSize: uint32(size),
FlagCount: uint32(16), FlagCount: uint32(16),
} }
@ -162,11 +164,18 @@ func main() {
os.Exit(1) os.Exit(1)
} }
err = ensureDbDir(dbDir)
if err != nil {
fmt.Fprintf(os.Stderr, err.Error())
os.Exit(1)
}
userdataStore := getUserdataDb(dbDir, ctx) userdataStore := getUserdataDb(dbDir, ctx)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, err.Error()) fmt.Fprintf(os.Stderr, err.Error())
os.Exit(1) os.Exit(1)
} }
defer userdataStore.Close()
dbResource, ok := rs.(*resource.DbResource) dbResource, ok := rs.(*resource.DbResource)
if !ok { if !ok {
@ -184,16 +193,28 @@ func main() {
fmt.Fprintf(os.Stderr, err.Error()) fmt.Fprintf(os.Stderr, err.Error())
os.Exit(1) os.Exit(1)
} }
defer stateStore.Close()
sh := httpserver.NewSessionHandler(cfg, rs, userdataStore, stateStore, hl.Init) sh := httpserver.NewSessionHandler(cfg, rs, stateStore, userdataStore, hl.Init)
s := &http.Server{ s := &http.Server{
Addr: fmt.Sprintf("%s:%s", host, strconv.Itoa(int(port))), Addr: fmt.Sprintf("%s:%s", host, strconv.Itoa(int(port))),
Handler: sh, Handler: sh,
} }
s.RegisterOnShutdown(sh.Shutdown)
cint := make(chan os.Signal)
cterm := make(chan os.Signal)
signal.Notify(cint, os.Interrupt, syscall.SIGINT)
signal.Notify(cterm, os.Interrupt, syscall.SIGTERM)
go func() {
select {
case _ = <-cint:
case _ = <-cterm:
}
s.Shutdown(ctx)
}()
err = s.ListenAndServe() err = s.ListenAndServe()
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "Server error: %s", err) logg.Infof("Server closed with error", "err", err)
os.Exit(1)
} }
} }

View File

@ -75,10 +75,18 @@ func getHandler(appFlags *asm.FlagParser, rs *resource.DbResource, pe *persist.P
return ussdHandlers, nil return ussdHandlers, nil
} }
func getPersister(dbDir string, ctx context.Context) (*persist.Persister, error) { func ensureDbDir(dbDir string) error {
err := os.MkdirAll(dbDir, 0700) err := os.MkdirAll(dbDir, 0700)
if err != nil { if err != nil {
return nil, fmt.Errorf("state dir create exited with error: %v\n", err) return fmt.Errorf("state dir create exited with error: %v\n", err)
}
return nil
}
func getPersister(dbDir string, ctx context.Context) (*persist.Persister, error) {
err := ensureDbDir(dbDir)
if err != nil {
return nil, err
} }
store := gdbmdb.NewGdbmDb() store := gdbmdb.NewGdbmDb()
storeFile := path.Join(dbDir, "state.gdbm") storeFile := path.Join(dbDir, "state.gdbm")

View File

@ -66,6 +66,13 @@ func(f *SessionHandler) writeError(w http.ResponseWriter, code int, msg string,
return return
} }
func(f* SessionHandler) Shutdown() {
err := f.provider.Close()
if err != nil {
logg.Errorf("handler shutdown error", "err", err)
}
}
func(f *SessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { func(f *SessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
var r bool var r bool
sessionId, err := f.rp.GetSessionId(req) sessionId, err := f.rp.GetSessionId(req)
@ -89,6 +96,7 @@ func(f *SessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
f.writeError(w, 500, "Storage retrieval fail", err) f.writeError(w, 500, "Storage retrieval fail", err)
return return
} }
defer f.provider.Put(cfg.SessionId, storage)
en := getEngine(cfg, f.rs, storage.Persister) en := getEngine(cfg, f.rs, storage.Persister)
en = en.WithFirst(f.first) en = en.WithFirst(f.first)
if cfg.EngineDebug { if cfg.EngineDebug {

View File

@ -39,5 +39,5 @@ func (p *SimpleStorageProvider) Put(sessionId string, storage Storage) error {
} }
func (p *SimpleStorageProvider) Close() error { func (p *SimpleStorageProvider) Close() error {
return nil return p.Storage.UserdataDb.Close()
} }

View File

@ -39,7 +39,6 @@ func PackKey(typ DataTyp, data []byte) []byte {
} }
func ReadEntry(ctx context.Context, store db.Db, sessionId string, typ DataTyp) ([]byte, error) { func ReadEntry(ctx context.Context, store db.Db, sessionId string, typ DataTyp) ([]byte, error) {
store.SetPrefix(db.DATATYPE_USERDATA) store.SetPrefix(db.DATATYPE_USERDATA)
store.SetSession(sessionId) store.SetSession(sessionId)
k := PackKey(typ, []byte(sessionId)) k := PackKey(typ, []byte(sessionId))