From 0813a619b4dddff14aeaf084cc4052a9e7bf369f Mon Sep 17 00:00:00 2001 From: lash Date: Sat, 21 Sep 2024 21:32:02 +0100 Subject: [PATCH] Add hacky db closer function in ssh --- cmd/ssh/main.go | 147 +++++++++++++++++++---------- internal/storage/storageservice.go | 71 ++++++++++---- 2 files changed, 146 insertions(+), 72 deletions(-) diff --git a/cmd/ssh/main.go b/cmd/ssh/main.go index 649d2ae..ddf2937 100644 --- a/cmd/ssh/main.go +++ b/cmd/ssh/main.go @@ -6,7 +6,6 @@ import ( "errors" "flag" "fmt" - "log" "net" "path" "os" @@ -79,8 +78,8 @@ func(a *auther) Get(k []byte) (string, error) { return v, nil } -// TODO: where should the session id be uniquely embedded -func serve(ctx context.Context, sessionId string, ch ssh.NewChannel, mss *storage.MenuStorageService, lhs *handlers.LocalHandlerService) error { +//func serve(ctx context.Context, sessionId string, ch ssh.NewChannel, mss *storage.MenuStorageService, lhs *handlers.LocalHandlerService) error { +func serve(ctx context.Context, sessionId string, ch ssh.NewChannel, en engine.Engine) error { if ch == nil { return errors.New("nil channel") } @@ -102,23 +101,6 @@ func serve(ctx context.Context, sessionId string, ch ssh.NewChannel, mss *storag _ = requests }(requests) - pe, err := mss.GetPersister(ctx) - if err != nil { - return fmt.Errorf("cannot get persister: %v", err) - } - lhs.SetPersister(pe) - lhs.Cfg.SessionId = sessionId - - hl, err := lhs.GetHandler() - if err != nil { - return fmt.Errorf("cannot get handler: %v", err) - } - - en := lhs.GetEngine() - en = en.WithFirst(hl.Init) - en = en.WithDebug(nil) - defer en.Finish() - cont, err := en.Exec(ctx, []byte{}) if err != nil { return fmt.Errorf("initial engine exec err: %v", err) @@ -154,8 +136,75 @@ func serve(ctx context.Context, sessionId string, ch ssh.NewChannel, mss *storag return nil } +type sshRunner struct { + Ctx context.Context + Cfg engine.Config + FlagFile string + DbDir string + ResourceDir string + Debug bool +} + +func(s *sshRunner) GetEngine(sessionId string) (engine.Engine, func(), error) { + ctx := s.Ctx + menuStorageService := storage.NewMenuStorageService(s.DbDir, s.ResourceDir) + + err := menuStorageService.EnsureDbDir() + if err != nil { + return nil, nil, err + } + + rs, err := menuStorageService.GetResource(ctx) + if err != nil { + return nil, nil, err + } + + pe, err := menuStorageService.GetPersister(ctx) + if err != nil { + return nil, nil, err + } + + userdatastore, err := menuStorageService.GetUserdataDb(ctx) + if err != nil { + return nil, nil, err + } + + dbResource, ok := rs.(*resource.DbResource) + if !ok { + return nil, nil, err + } + + lhs, err := handlers.NewLocalHandlerService(s.FlagFile, true, dbResource, s.Cfg, rs) + lhs.SetDataStore(&userdatastore) + lhs.SetPersister(pe) + lhs.Cfg.SessionId = sessionId + + if err != nil { + return nil, nil, err + } + + hl, err := lhs.GetHandler() + if err != nil { + return nil, nil, err + } + + en := lhs.GetEngine() + en = en.WithFirst(hl.Init) + if s.Debug { + en = en.WithDebug(nil) + } + // TODO: this is getting very hacky! + closer := func() { + err := menuStorageService.Close() + if err != nil { + logg.ErrorCtxf(ctx, "menu storage service cleanup fail", "err", err) + } + } + return en, closer, nil +} + // adapted example from crypto/ssh package, NewServerConn doc -func sshRun(ctx context.Context, mss *storage.MenuStorageService, lhs *handlers.LocalHandlerService) { +func(s *sshRunner) Run(ctx context.Context) {//, mss *storage.MenuStorageService, lhs *handlers.LocalHandlerService) { running := true defer wg.Wait() @@ -168,11 +217,11 @@ func sshRun(ctx context.Context, mss *storage.MenuStorageService, lhs *handlers. privateBytes, err := os.ReadFile("/home/lash/.ssh/id_rsa_tmp") if err != nil { - log.Fatal("Failed to load private key: ", err) + logg.ErrorCtxf(ctx, "Failed to load private key", "err", err) } private, err := ssh.ParsePrivateKey(privateBytes) if err != nil { - log.Fatal("Failed to parse private key: ", err) + logg.ErrorCtxf(ctx, "Failed to parse private key", "err", err) } cfg.AddHostKey(private) @@ -209,8 +258,20 @@ func sshRun(ctx context.Context, mss *storage.MenuStorageService, lhs *handlers. logg.ErrorCtxf(ctx, "Cannot find authentication") return } + en, closer, err := s.GetEngine(sessionId) + if err != nil { + logg.ErrorCtxf(ctx, "engine won't start", "err", err) + return + } + defer func() { + err := en.Finish() + if err != nil { + logg.ErrorCtxf(ctx, "engine won't stop", "err", err) + } + closer() + }() for ch := range nC { - err = serve(ctx, sessionId, ch, mss, lhs) + err = serve(ctx, sessionId, ch, en) logg.ErrorCtxf(ctx, "ssh server finish", "err", err) } } @@ -218,6 +279,7 @@ func sshRun(ctx context.Context, mss *storage.MenuStorageService, lhs *handlers. } } + func sshLoadKeys(ctx context.Context, dbDir string) error { keyStoreFile := path.Join(dbDir, "ssh_authorized_keys.gdbm") keyStore = gdbmdb.NewGdbmDb() @@ -264,38 +326,19 @@ func main() { if engineDebug { cfg.EngineDebug = true } - - mss := storage.NewMenuStorageService(dbDir, resourceDir) - rs, err := mss.GetResource(ctx) - if err != nil { - fmt.Fprintf(os.Stderr, err.Error()) - os.Exit(1) - } - - err = mss.EnsureDbDir() - if err != nil { - fmt.Fprintf(os.Stderr, err.Error()) - os.Exit(1) - } - - dbResource, ok := rs.(*resource.DbResource) - if !ok { - os.Exit(1) - } - userdataStore := mss.GetUserdataDb(ctx) - if err != nil { - fmt.Fprintf(os.Stderr, err.Error()) - os.Exit(1) - } - - lhs, err := handlers.NewLocalHandlerService(pfp, engineDebug, dbResource, cfg, rs) - lhs.SetDataStore(&userdataStore) - err = sshLoadKeys(ctx, dbDir) + err := sshLoadKeys(ctx, dbDir) if err != nil { fmt.Fprintf(os.Stderr, err.Error()) os.Exit(1) } - sshRun(ctx, mss, lhs) + runner := &sshRunner{ + Cfg: cfg, + Debug: engineDebug, + FlagFile: pfp, + DbDir: dbDir, + ResourceDir: resourceDir, + } + runner.Run(ctx) } diff --git a/internal/storage/storageservice.go b/internal/storage/storageservice.go index 3b622a1..7cceb4e 100644 --- a/internal/storage/storageservice.go +++ b/internal/storage/storageservice.go @@ -11,8 +11,13 @@ import ( gdbmdb "git.defalsify.org/vise.git/db/gdbm" "git.defalsify.org/vise.git/persist" "git.defalsify.org/vise.git/resource" + "git.defalsify.org/vise.git/logging" ) +var ( + logg = logging.NewVanilla().WithDomain("storage") +) + type StorageService interface { GetPersister(ctx context.Context) (*persist.Persister, error) GetUserdataDb(ctx context.Context) db.Db @@ -23,6 +28,9 @@ type StorageService interface { type MenuStorageService struct{ dbDir string resourceDir string + resourceStore db.Db + stateStore db.Db + userDataStore db.Db } func NewMenuStorageService(dbDir string, resourceDir string) *MenuStorageService { @@ -33,35 +41,48 @@ func NewMenuStorageService(dbDir string, resourceDir string) *MenuStorageService } func (ms *MenuStorageService) GetPersister(ctx context.Context) (*persist.Persister, error) { - store := gdbmdb.NewGdbmDb() + ms.stateStore = gdbmdb.NewGdbmDb() storeFile := path.Join(ms.dbDir, "state.gdbm") - store.Connect(ctx, storeFile) - pr := persist.NewPersister(store) - return pr, nil -} - -func (ms *MenuStorageService) GetUserdataDb(ctx context.Context) db.Db { - store := gdbmdb.NewGdbmDb() - storeFile := path.Join(ms.dbDir, "userdata.gdbm") - store.Connect(ctx, storeFile) - return store -} - -func (ms *MenuStorageService) GetResource(ctx context.Context) (resource.Resource, error) { - store := fsdb.NewFsDb() - err := store.Connect(ctx, ms.resourceDir) + err := ms.stateStore.Connect(ctx, storeFile) if err != nil { return nil, err } - rfs := resource.NewDbResource(store) + pr := persist.NewPersister(ms.stateStore) + logg.TraceCtxf(ctx, "menu storage service", "persist", pr, "store", ms.stateStore) + return pr, nil +} + +func (ms *MenuStorageService) GetUserdataDb(ctx context.Context) (db.Db, error) { + ms.userDataStore = gdbmdb.NewGdbmDb() + storeFile := path.Join(ms.dbDir, "userdata.gdbm") + err := ms.userDataStore.Connect(ctx, storeFile) + if err != nil { + return nil, err + } + return ms.userDataStore, nil +} + +func (ms *MenuStorageService) GetResource(ctx context.Context) (resource.Resource, error) { + ms.resourceStore = fsdb.NewFsDb() + err := ms.resourceStore.Connect(ctx, ms.resourceDir) + if err != nil { + return nil, err + } + rfs := resource.NewDbResource(ms.resourceStore) return rfs, nil } func (ms *MenuStorageService) GetStateStore(ctx context.Context) (db.Db, error) { - store := gdbmdb.NewGdbmDb() + if ms.stateStore != nil { + panic("set up store when already exists") + } + ms.stateStore = gdbmdb.NewGdbmDb() storeFile := path.Join(ms.dbDir, "state.gdbm") - store.Connect(ctx, storeFile) - return store, nil + err := ms.stateStore.Connect(ctx, storeFile) + if err != nil { + return nil, err + } + return ms.stateStore, nil } func (ms *MenuStorageService) EnsureDbDir() error { @@ -71,3 +92,13 @@ func (ms *MenuStorageService) EnsureDbDir() error { } return nil } + +func (ms *MenuStorageService) Close() error { + errA := ms.stateStore.Close() + errB := ms.userDataStore.Close() + errC := ms.resourceStore.Close() + if errA != nil || errB != nil || errC != nil { + return fmt.Errorf("%v %v %v", errA, errB, errC) + } + return nil +}