From f267aa2b4115138ee461ee0f480c331f12d9d9ac Mon Sep 17 00:00:00 2001 From: lash Date: Tue, 1 Oct 2024 00:18:54 +0100 Subject: [PATCH] Delete connstr in threadgdbm global channel map on close --- engine/engine.go | 18 ++++++++++++------ internal/storage/gdbm.go | 1 + menu_traversal_test.go | 33 ++++++++++++++++----------------- 3 files changed, 29 insertions(+), 23 deletions(-) diff --git a/engine/engine.go b/engine/engine.go index 5ac62ec..5fe0977 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -8,7 +8,6 @@ import ( "git.defalsify.org/vise.git/engine" "git.defalsify.org/vise.git/logging" - "git.defalsify.org/vise.git/persist" "git.defalsify.org/vise.git/resource" "git.grassecon.net/urdt/ussd/internal/handlers" "git.grassecon.net/urdt/ussd/internal/storage" @@ -19,7 +18,7 @@ var ( scriptDir = path.Join("services", "registration") ) -func TestEngine(sessionId string) (engine.Engine, *persist.Persister) { +func TestEngine(sessionId string) (engine.Engine, func()) { ctx := context.Background() ctx = context.WithValue(ctx, "SessionId", sessionId) pfp := path.Join(scriptDir, "pp.csv") @@ -53,7 +52,7 @@ func TestEngine(sessionId string) (engine.Engine, *persist.Persister) { os.Exit(1) } - userdatastore, err := menuStorageService.GetUserdataDb(ctx) + userDataStore, err := menuStorageService.GetUserdataDb(ctx) if err != nil { fmt.Fprintf(os.Stderr, err.Error()) os.Exit(1) @@ -66,7 +65,7 @@ func TestEngine(sessionId string) (engine.Engine, *persist.Persister) { } lhs, err := handlers.NewLocalHandlerService(pfp, true, dbResource, cfg, rs) - lhs.SetDataStore(&userdatastore) + lhs.SetDataStore(&userDataStore) lhs.SetPersister(pe) if err != nil { @@ -83,7 +82,14 @@ func TestEngine(sessionId string) (engine.Engine, *persist.Persister) { en := lhs.GetEngine() en = en.WithFirst(hl.Init) - //en = en.WithDebug(nil) - return en, pe + cleanFn := func() { + err := menuStorageService.Close() + if err != nil { + logg.Errorf(err.Error()) + } + logg.Infof("testengine storage closed") + } + //en = en.WithDebug(nil) + return en, cleanFn } diff --git a/internal/storage/gdbm.go b/internal/storage/gdbm.go index eb959cf..49de570 100644 --- a/internal/storage/gdbm.go +++ b/internal/storage/gdbm.go @@ -109,6 +109,7 @@ func(tdb *ThreadGdbmDb) Get(ctx context.Context, key []byte) ([]byte, error) { func(tdb *ThreadGdbmDb) Close() error { tdb.reserve() close(dbC[tdb.connStr]) + delete(dbC, tdb.connStr) err := tdb.db.Close() tdb.db = nil return err diff --git a/menu_traversal_test.go b/menu_traversal_test.go index 728d8c3..9c1e839 100644 --- a/menu_traversal_test.go +++ b/menu_traversal_test.go @@ -14,8 +14,8 @@ var ( ) func TestUserRegistration(t *testing.T) { - en, _ := enginetest.TestEngine("session1234112") - defer en.Finish() + en, fn := enginetest.TestEngine("session1234112") + defer fn() ctx := context.Background() sessions := testData for _, session := range sessions { @@ -26,8 +26,7 @@ func TestUserRegistration(t *testing.T) { cont, err := en.Exec(ctx, []byte(step.Input)) if err != nil { - t.Errorf("Test case '%s' failed at input '%s': %v", group.Name, step.Input, err) - return + t.Fatalf("Test case '%s' failed at input '%s': %v", group.Name, step.Input, err) } if !cont { break @@ -35,7 +34,7 @@ func TestUserRegistration(t *testing.T) { w := bytes.NewBuffer(nil) _, err = en.Flush(ctx, w) if err != nil { - t.Errorf("Test case '%s' failed during Flush: %v", group.Name, err) + t.Fatalf("Test case '%s' failed during Flush: %v", group.Name, err) } b := w.Bytes() if !bytes.Equal(b, []byte(step.ExpectedContent)) { @@ -48,8 +47,8 @@ func TestUserRegistration(t *testing.T) { } func TestTerms(t *testing.T) { - en, _ := enginetest.TestEngine("session1234112") - defer en.Finish() + en, fn := enginetest.TestEngine("session1234112_a") + defer fn() ctx := context.Background() sessions := testData @@ -59,13 +58,13 @@ func TestTerms(t *testing.T) { for _, step := range group.Steps { _, err := en.Exec(ctx, []byte(step.Input)) if err != nil { - t.Fail() + t.Fatalf("Test case '%s' failed during Exec: %v", group.Name, err) } w := bytes.NewBuffer(nil) _, err = en.Flush(ctx, w) if err != nil { - t.Errorf("Test case '%s' failed during Flush: %v", group.Name, err) + t.Fatalf("Test case '%s' failed during Flush: %v", group.Name, err) } b := w.Bytes() if !bytes.Equal(b, []byte(step.ExpectedContent)) { @@ -78,8 +77,8 @@ func TestTerms(t *testing.T) { } func TestAccountRegistrationRejectTerms(t *testing.T) { - en, _ := enginetest.TestEngine("session1234112") - defer en.Finish() + en, fn := enginetest.TestEngine("session1234112_b") + defer fn() ctx := context.Background() sessions := testData for _, session := range sessions { @@ -88,7 +87,7 @@ func TestAccountRegistrationRejectTerms(t *testing.T) { for _, step := range group.Steps { cont, err := en.Exec(ctx, []byte(step.Input)) if err != nil { - t.Errorf("Test case '%s' failed at input '%s': %v", group.Name, step.Input, err) + t.Fatalf("Test case '%s' failed at input '%s': %v", group.Name, step.Input, err) return } if !cont { @@ -96,7 +95,7 @@ func TestAccountRegistrationRejectTerms(t *testing.T) { } w := bytes.NewBuffer(nil) if _, err := en.Flush(ctx, w); err != nil { - t.Errorf("Test case '%s' failed during Flush: %v", group.Name, err) + t.Fatalf("Test case '%s' failed during Flush: %v", group.Name, err) } b := w.Bytes() @@ -109,8 +108,8 @@ func TestAccountRegistrationRejectTerms(t *testing.T) { } func TestAccountRegistrationInvalidPin(t *testing.T) { - en, _ := enginetest.TestEngine("session1234112") - defer en.Finish() + en, fn := enginetest.TestEngine("session1234112") + defer fn() ctx := context.Background() sessions := testData for _, session := range sessions { @@ -119,7 +118,7 @@ func TestAccountRegistrationInvalidPin(t *testing.T) { for _, step := range group.Steps { cont, err := en.Exec(ctx, []byte(step.Input)) if err != nil { - t.Errorf("Test case '%s' failed at input '%s': %v", group.Name, step.Input, err) + t.Fatalf("Test case '%s' failed at input '%s': %v", group.Name, step.Input, err) return } if !cont { @@ -127,7 +126,7 @@ func TestAccountRegistrationInvalidPin(t *testing.T) { } w := bytes.NewBuffer(nil) if _, err := en.Flush(ctx, w); err != nil { - t.Errorf("Test case '%s' failed during Flush: %v", group.Name, err) + t.Fatalf("Test case '%s' failed during Flush: %v", group.Name, err) } b := w.Bytes()