Merge pull request 'log-session-id-at-sessionid' (#251) from log-session-id-at-sessionid into master
Reviewed-on: urdt/ussd#251 Reviewed-by: lash <accounts-grassrootseconomics@holbrook.no>
This commit is contained in:
commit
c995143543
@ -121,9 +121,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
defer stateStore.Close()
|
defer stateStore.Close()
|
||||||
|
|
||||||
rp := &at.ATRequestParser{
|
rp := &at.ATRequestParser{}
|
||||||
Context: ctx,
|
|
||||||
}
|
|
||||||
bsh := handlers.NewBaseSessionHandler(cfg, rs, stateStore, userdataStore, rp, hl)
|
bsh := handlers.NewBaseSessionHandler(cfg, rs, stateStore, userdataStore, rp, hl)
|
||||||
sh := httpserver.NewATSessionHandler(bsh)
|
sh := httpserver.NewATSessionHandler(bsh)
|
||||||
|
|
||||||
|
@ -21,8 +21,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
logg = logging.NewVanilla()
|
logg = logging.NewVanilla()
|
||||||
scriptDir = path.Join("services", "registration")
|
scriptDir = path.Join("services", "registration")
|
||||||
menuSeparator = ": "
|
menuSeparator = ": "
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -35,7 +35,7 @@ type asyncRequestParser struct {
|
|||||||
input []byte
|
input []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *asyncRequestParser) GetSessionId(r any) (string, error) {
|
func (p *asyncRequestParser) GetSessionId(ctx context.Context, r any) (string, error) {
|
||||||
return p.sessionId, nil
|
return p.sessionId, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,9 +6,9 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
|
|
||||||
"git.defalsify.org/vise.git/engine"
|
"git.defalsify.org/vise.git/engine"
|
||||||
"git.defalsify.org/vise.git/resource"
|
|
||||||
"git.defalsify.org/vise.git/persist"
|
|
||||||
"git.defalsify.org/vise.git/logging"
|
"git.defalsify.org/vise.git/logging"
|
||||||
|
"git.defalsify.org/vise.git/persist"
|
||||||
|
"git.defalsify.org/vise.git/resource"
|
||||||
|
|
||||||
"git.grassecon.net/urdt/ussd/internal/storage"
|
"git.grassecon.net/urdt/ussd/internal/storage"
|
||||||
)
|
)
|
||||||
@ -20,33 +20,33 @@ var (
|
|||||||
var (
|
var (
|
||||||
ErrInvalidRequest = errors.New("invalid request for context")
|
ErrInvalidRequest = errors.New("invalid request for context")
|
||||||
ErrSessionMissing = errors.New("missing session")
|
ErrSessionMissing = errors.New("missing session")
|
||||||
ErrInvalidInput = errors.New("invalid input")
|
ErrInvalidInput = errors.New("invalid input")
|
||||||
ErrStorage = errors.New("storage retrieval fail")
|
ErrStorage = errors.New("storage retrieval fail")
|
||||||
ErrEngineType = errors.New("incompatible engine")
|
ErrEngineType = errors.New("incompatible engine")
|
||||||
ErrEngineInit = errors.New("engine init fail")
|
ErrEngineInit = errors.New("engine init fail")
|
||||||
ErrEngineExec = errors.New("engine exec fail")
|
ErrEngineExec = errors.New("engine exec fail")
|
||||||
)
|
)
|
||||||
|
|
||||||
type RequestSession struct {
|
type RequestSession struct {
|
||||||
Ctx context.Context
|
Ctx context.Context
|
||||||
Config engine.Config
|
Config engine.Config
|
||||||
Engine engine.Engine
|
Engine engine.Engine
|
||||||
Input []byte
|
Input []byte
|
||||||
Storage *storage.Storage
|
Storage *storage.Storage
|
||||||
Writer io.Writer
|
Writer io.Writer
|
||||||
Continue bool
|
Continue bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: seems like can remove this.
|
// TODO: seems like can remove this.
|
||||||
type RequestParser interface {
|
type RequestParser interface {
|
||||||
GetSessionId(rq any) (string, error)
|
GetSessionId(context context.Context, rq any) (string, error)
|
||||||
GetInput(rq any) ([]byte, error)
|
GetInput(rq any) ([]byte, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type RequestHandler interface {
|
type RequestHandler interface {
|
||||||
GetConfig() engine.Config
|
GetConfig() engine.Config
|
||||||
GetRequestParser() RequestParser
|
GetRequestParser() RequestParser
|
||||||
GetEngine(cfg engine.Config, rs resource.Resource, pe *persist.Persister) engine.Engine
|
GetEngine(cfg engine.Config, rs resource.Resource, pe *persist.Persister) engine.Engine
|
||||||
Process(rs RequestSession) (RequestSession, error)
|
Process(rs RequestSession) (RequestSession, error)
|
||||||
Output(rs RequestSession) (RequestSession, error)
|
Output(rs RequestSession) (RequestSession, error)
|
||||||
Reset(rs RequestSession) (RequestSession, error)
|
Reset(rs RequestSession) (RequestSession, error)
|
||||||
|
@ -28,7 +28,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
logg = logging.NewVanilla().WithDomain("ussdmenuhandler").WithContextKey("session-id")
|
logg = logging.NewVanilla().WithDomain("ussdmenuhandler").WithContextKey("SessionId")
|
||||||
scriptDir = path.Join("services", "registration")
|
scriptDir = path.Join("services", "registration")
|
||||||
translationDir = path.Join(scriptDir, "locale")
|
translationDir = path.Join(scriptDir, "locale")
|
||||||
)
|
)
|
||||||
@ -124,7 +124,7 @@ func (h *Handlers) Init(ctx context.Context, sym string, input []byte) (resource
|
|||||||
|
|
||||||
sessionId, ok := ctx.Value("SessionId").(string)
|
sessionId, ok := ctx.Value("SessionId").(string)
|
||||||
if ok {
|
if ok {
|
||||||
context.WithValue(ctx, "session-id", sessionId)
|
ctx = context.WithValue(ctx, "SessionId", sessionId)
|
||||||
}
|
}
|
||||||
|
|
||||||
flag_admin_privilege, _ := h.flagManager.GetFlag("flag_admin_privilege")
|
flag_admin_privilege, _ := h.flagManager.GetFlag("flag_admin_privilege")
|
||||||
|
@ -15,16 +15,14 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type ATRequestParser struct {
|
type ATRequestParser struct {
|
||||||
Context context.Context
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (arp *ATRequestParser) GetSessionId(rq any) (string, error) {
|
func (arp *ATRequestParser) GetSessionId(ctx context.Context, rq any) (string, error) {
|
||||||
rqv, ok := rq.(*http.Request)
|
rqv, ok := rq.(*http.Request)
|
||||||
if !ok {
|
if !ok {
|
||||||
logg.Warnf("got an invalid request", "req", rq)
|
logg.Warnf("got an invalid request", "req", rq)
|
||||||
return "", handlers.ErrInvalidRequest
|
return "", handlers.ErrInvalidRequest
|
||||||
}
|
}
|
||||||
|
|
||||||
// Capture body (if any) for logging
|
// Capture body (if any) for logging
|
||||||
body, err := io.ReadAll(rqv.Body)
|
body, err := io.ReadAll(rqv.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -43,9 +41,9 @@ func (arp *ATRequestParser) GetSessionId(rq any) (string, error) {
|
|||||||
decodedStr := string(logBytes)
|
decodedStr := string(logBytes)
|
||||||
sessionId, err := extractATSessionId(decodedStr)
|
sessionId, err := extractATSessionId(decodedStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
context.WithValue(arp.Context, "at-session-id", sessionId)
|
ctx = context.WithValue(ctx, "AT-SessionId", sessionId)
|
||||||
}
|
}
|
||||||
logg.Debugf("Received request:", decodedStr)
|
logg.DebugCtxf(ctx, "Received request:", decodedStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := rqv.ParseForm(); err != nil {
|
if err := rqv.ParseForm(); err != nil {
|
||||||
|
@ -10,7 +10,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
logg = logging.NewVanilla().WithDomain("atserver")
|
logg = logging.NewVanilla().WithDomain("atserver").WithContextKey("SessionId").WithContextKey("AT-SessionId")
|
||||||
)
|
)
|
||||||
|
|
||||||
type ATSessionHandler struct {
|
type ATSessionHandler struct {
|
||||||
@ -34,7 +34,7 @@ func (ash *ATSessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request)
|
|||||||
|
|
||||||
rp := ash.GetRequestParser()
|
rp := ash.GetRequestParser()
|
||||||
cfg := ash.GetConfig()
|
cfg := ash.GetConfig()
|
||||||
cfg.SessionId, err = rp.GetSessionId(req)
|
cfg.SessionId, err = rp.GetSessionId(req.Context(), req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err)
|
logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err)
|
||||||
ash.WriteError(w, 400, err)
|
ash.WriteError(w, 400, err)
|
||||||
@ -48,7 +48,7 @@ func (ash *ATSessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request)
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
rqs, err = ash.Process(rqs)
|
rqs, err = ash.Process(rqs)
|
||||||
switch err {
|
switch err {
|
||||||
case nil: // set code to 200 if no err
|
case nil: // set code to 200 if no err
|
||||||
code = 200
|
code = 200
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package http
|
package http
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
@ -10,7 +11,7 @@ import (
|
|||||||
type DefaultRequestParser struct {
|
type DefaultRequestParser struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rp *DefaultRequestParser) GetSessionId(rq any) (string, error) {
|
func (rp *DefaultRequestParser) GetSessionId(ctx context.Context, rq any) (string, error) {
|
||||||
rqv, ok := rq.(*http.Request)
|
rqv, ok := rq.(*http.Request)
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", handlers.ErrInvalidRequest
|
return "", handlers.ErrInvalidRequest
|
||||||
@ -34,5 +35,3 @@ func (rp *DefaultRequestParser) GetInput(rq any) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
return v, nil
|
return v, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -46,7 +46,7 @@ func (f *SessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
|||||||
|
|
||||||
rp := f.GetRequestParser()
|
rp := f.GetRequestParser()
|
||||||
cfg := f.GetConfig()
|
cfg := f.GetConfig()
|
||||||
cfg.SessionId, err = rp.GetSessionId(req)
|
cfg.SessionId, err = rp.GetSessionId(req.Context(), req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err)
|
logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err)
|
||||||
f.WriteError(w, 400, err)
|
f.WriteError(w, 400, err)
|
||||||
|
@ -2,6 +2,7 @@ package http
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@ -161,7 +162,7 @@ func TestDefaultRequestParser_GetSessionId(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
id, err := parser.GetSessionId(tt.request)
|
id, err := parser.GetSessionId(context.Background(),tt.request)
|
||||||
|
|
||||||
if id != tt.expectedID {
|
if id != tt.expectedID {
|
||||||
t.Errorf("Expected session ID %s, got %s", tt.expectedID, id)
|
t.Errorf("Expected session ID %s, got %s", tt.expectedID, id)
|
||||||
|
@ -1,12 +1,14 @@
|
|||||||
package httpmocks
|
package httpmocks
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
// MockRequestParser implements the handlers.RequestParser interface for testing
|
// MockRequestParser implements the handlers.RequestParser interface for testing
|
||||||
type MockRequestParser struct {
|
type MockRequestParser struct {
|
||||||
GetSessionIdFunc func(any) (string, error)
|
GetSessionIdFunc func(any) (string, error)
|
||||||
GetInputFunc func(any) ([]byte, error)
|
GetInputFunc func(any) ([]byte, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockRequestParser) GetSessionId(rq any) (string, error) {
|
func (m *MockRequestParser) GetSessionId(ctx context.Context, rq any) (string, error) {
|
||||||
return m.GetSessionIdFunc(rq)
|
return m.GetSessionIdFunc(rq)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user