log-session-id-at-sessionid #251

Merged
lash merged 4 commits from log-session-id-at-sessionid into master 2025-01-06 08:01:21 +01:00
10 changed files with 35 additions and 37 deletions

View File

@ -121,9 +121,7 @@ func main() {
}
defer stateStore.Close()
rp := &at.ATRequestParser{
Context: ctx,
}
rp := &at.ATRequestParser{}
bsh := handlers.NewBaseSessionHandler(cfg, rs, stateStore, userdataStore, rp, hl)
sh := httpserver.NewATSessionHandler(bsh)

View File

@ -21,8 +21,8 @@ import (
)
var (
logg = logging.NewVanilla()
scriptDir = path.Join("services", "registration")
logg = logging.NewVanilla()
scriptDir = path.Join("services", "registration")
menuSeparator = ": "
)
@ -35,7 +35,7 @@ type asyncRequestParser struct {
input []byte
}
func (p *asyncRequestParser) GetSessionId(r any) (string, error) {
func (p *asyncRequestParser) GetSessionId(ctx context.Context, r any) (string, error) {
return p.sessionId, nil
}

View File

@ -6,9 +6,9 @@ import (
"io"
"git.defalsify.org/vise.git/engine"
"git.defalsify.org/vise.git/resource"
"git.defalsify.org/vise.git/persist"
"git.defalsify.org/vise.git/logging"
"git.defalsify.org/vise.git/persist"
"git.defalsify.org/vise.git/resource"
"git.grassecon.net/urdt/ussd/internal/storage"
)
@ -20,26 +20,26 @@ var (
var (
ErrInvalidRequest = errors.New("invalid request for context")
ErrSessionMissing = errors.New("missing session")
ErrInvalidInput = errors.New("invalid input")
ErrStorage = errors.New("storage retrieval fail")
ErrEngineType = errors.New("incompatible engine")
ErrEngineInit = errors.New("engine init fail")
ErrEngineExec = errors.New("engine exec fail")
ErrInvalidInput = errors.New("invalid input")
ErrStorage = errors.New("storage retrieval fail")
ErrEngineType = errors.New("incompatible engine")
ErrEngineInit = errors.New("engine init fail")
ErrEngineExec = errors.New("engine exec fail")
)
type RequestSession struct {
Ctx context.Context
Config engine.Config
Engine engine.Engine
Input []byte
Storage *storage.Storage
Writer io.Writer
Ctx context.Context
Config engine.Config
Engine engine.Engine
Input []byte
Storage *storage.Storage
Writer io.Writer
Continue bool
}
// TODO: seems like can remove this.
type RequestParser interface {
GetSessionId(rq any) (string, error)
GetSessionId(context context.Context, rq any) (string, error)
GetInput(rq any) ([]byte, error)
}

View File

@ -28,7 +28,7 @@ import (
)
var (
logg = logging.NewVanilla().WithDomain("ussdmenuhandler").WithContextKey("session-id")
logg = logging.NewVanilla().WithDomain("ussdmenuhandler").WithContextKey("SessionId")
scriptDir = path.Join("services", "registration")
translationDir = path.Join(scriptDir, "locale")
)
@ -124,7 +124,7 @@ func (h *Handlers) Init(ctx context.Context, sym string, input []byte) (resource
sessionId, ok := ctx.Value("SessionId").(string)
if ok {
context.WithValue(ctx, "session-id", sessionId)
ctx = context.WithValue(ctx, "SessionId", sessionId)
}
flag_admin_privilege, _ := h.flagManager.GetFlag("flag_admin_privilege")

View File

@ -15,16 +15,14 @@ import (
)
type ATRequestParser struct {
Context context.Context
}
func (arp *ATRequestParser) GetSessionId(rq any) (string, error) {
func (arp *ATRequestParser) GetSessionId(ctx context.Context, rq any) (string, error) {
rqv, ok := rq.(*http.Request)
if !ok {
logg.Warnf("got an invalid request", "req", rq)
return "", handlers.ErrInvalidRequest
}
// Capture body (if any) for logging
body, err := io.ReadAll(rqv.Body)
if err != nil {
@ -43,9 +41,9 @@ func (arp *ATRequestParser) GetSessionId(rq any) (string, error) {
decodedStr := string(logBytes)
sessionId, err := extractATSessionId(decodedStr)
if err != nil {
context.WithValue(arp.Context, "at-session-id", sessionId)
ctx = context.WithValue(ctx, "AT-SessionId", sessionId)
}
logg.Debugf("Received request:", decodedStr)
logg.DebugCtxf(ctx, "Received request:", decodedStr)
}
if err := rqv.ParseForm(); err != nil {

View File

@ -10,7 +10,7 @@ import (
)
var (
logg = logging.NewVanilla().WithDomain("atserver")
logg = logging.NewVanilla().WithDomain("atserver").WithContextKey("SessionId").WithContextKey("AT-SessionId")
)
type ATSessionHandler struct {
@ -34,7 +34,7 @@ func (ash *ATSessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request)
rp := ash.GetRequestParser()
cfg := ash.GetConfig()
cfg.SessionId, err = rp.GetSessionId(req)
cfg.SessionId, err = rp.GetSessionId(req.Context(), req)
if err != nil {
logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err)
ash.WriteError(w, 400, err)

View File

@ -1,6 +1,7 @@
package http
import (
"context"
"io/ioutil"
"net/http"
@ -10,7 +11,7 @@ import (
type DefaultRequestParser struct {
}
func (rp *DefaultRequestParser) GetSessionId(rq any) (string, error) {
func (rp *DefaultRequestParser) GetSessionId(ctx context.Context, rq any) (string, error) {
rqv, ok := rq.(*http.Request)
if !ok {
return "", handlers.ErrInvalidRequest
@ -34,5 +35,3 @@ func (rp *DefaultRequestParser) GetInput(rq any) ([]byte, error) {
}
return v, nil
}

View File

@ -46,7 +46,7 @@ func (f *SessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
rp := f.GetRequestParser()
cfg := f.GetConfig()
cfg.SessionId, err = rp.GetSessionId(req)
cfg.SessionId, err = rp.GetSessionId(req.Context(), req)
if err != nil {
logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err)
f.WriteError(w, 400, err)

View File

@ -2,6 +2,7 @@ package http
import (
"bytes"
"context"
"errors"
"net/http"
"net/http/httptest"
@ -161,7 +162,7 @@ func TestDefaultRequestParser_GetSessionId(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
id, err := parser.GetSessionId(tt.request)
id, err := parser.GetSessionId(context.Background(),tt.request)
if id != tt.expectedID {
t.Errorf("Expected session ID %s, got %s", tt.expectedID, id)

View File

@ -1,12 +1,14 @@
package httpmocks
import "context"
// MockRequestParser implements the handlers.RequestParser interface for testing
type MockRequestParser struct {
GetSessionIdFunc func(any) (string, error)
GetInputFunc func(any) ([]byte, error)
}
func (m *MockRequestParser) GetSessionId(rq any) (string, error) {
func (m *MockRequestParser) GetSessionId(ctx context.Context, rq any) (string, error) {
return m.GetSessionIdFunc(rq)
}