diff --git a/internal/storage/storageservice.go b/internal/storage/storageservice.go index 01a122c..c616019 100644 --- a/internal/storage/storageservice.go +++ b/internal/storage/storageservice.go @@ -53,40 +53,58 @@ func NewMenuStorageService(dbDir string, resourceDir string) *MenuStorageService } } -func (ms *MenuStorageService) GetPersister(ctx context.Context) (*persist.Persister, error) { - ms.stateStore = NewThreadGdbmDb() - storeFile := path.Join(ms.dbDir, "state.gdbm") - err := ms.stateStore.Connect(ctx, storeFile) - if err != nil { - return nil, err - } - pr := persist.NewPersister(ms.stateStore) - logg.TraceCtxf(ctx, "menu storage service", "persist", pr, "store", ms.stateStore) - return pr, nil -} - -func (ms *MenuStorageService) GetUserdataDb(ctx context.Context) (db.Db, error) { +func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.Db, fileName string) (db.Db, error) { database, ok := ctx.Value("Database").(string) if !ok { return nil, fmt.Errorf("failed to select the database") } - if database == "postgres" { - ms.userDataStore = NewPostgresDb() - connStr := buildConnStr() - err := ms.userDataStore.Connect(ctx, connStr) - if err != nil { - return nil, err - } - } else { - ms.userDataStore = NewThreadGdbmDb() - storeFile := path.Join(ms.dbDir, "userdata.gdbm") - err := ms.userDataStore.Connect(ctx, storeFile) - if err != nil { - return nil, err - } + if existingDb != nil { + return existingDb, nil } + var newDb db.Db + var err error + + if database == "postgres" { + newDb = NewPostgresDb() + connStr := buildConnStr() + err = newDb.Connect(ctx, connStr) + } else { + newDb = NewThreadGdbmDb() + storeFile := path.Join(ms.dbDir, fileName) + err = newDb.Connect(ctx, storeFile) + } + + if err != nil { + return nil, err + } + + return newDb, nil +} + +func (ms *MenuStorageService) GetPersister(ctx context.Context) (*persist.Persister, error) { + stateStore, err := ms.GetStateStore(ctx) + if err != nil { + return nil, err + } + + pr := persist.NewPersister(stateStore) + logg.TraceCtxf(ctx, "menu storage service", "persist", pr, "store", stateStore) + return pr, nil +} + +func (ms *MenuStorageService) GetUserdataDb(ctx context.Context) (db.Db, error) { + if ms.userDataStore != nil { + return ms.userDataStore, nil + } + + userDataStore, err := ms.getOrCreateDb(ctx, ms.userDataStore, "userdata.gdbm") + if err != nil { + return nil, err + } + + ms.userDataStore = userDataStore return ms.userDataStore, nil } @@ -102,14 +120,15 @@ func (ms *MenuStorageService) GetResource(ctx context.Context) (resource.Resourc func (ms *MenuStorageService) GetStateStore(ctx context.Context) (db.Db, error) { if ms.stateStore != nil { - panic("set up store when already exists") + return ms.stateStore, nil } - ms.stateStore = NewThreadGdbmDb() - storeFile := path.Join(ms.dbDir, "state.gdbm") - err := ms.stateStore.Connect(ctx, storeFile) + + stateStore, err := ms.getOrCreateDb(ctx, ms.stateStore, "state.gdbm") if err != nil { return nil, err } + + ms.stateStore = stateStore return ms.stateStore, nil }