Compare commits
18 Commits
v0.8.0-bet
...
lash/stale
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
df095f0873 | ||
| 47b5ff0435 | |||
|
|
25867cf05e
|
||
|
|
d5a2680500
|
||
|
|
d950b10b50
|
||
|
|
bcb3ab905e
|
||
|
|
3ed9caf16d
|
||
|
|
86464c31d2 | ||
|
|
67007fcd48
|
||
|
|
f1b258fa6d
|
||
|
|
473a7fc480 | ||
|
|
94551ba37f
|
||
|
|
973a69455e | ||
|
|
0af7379ae4 | ||
|
|
ce30cb740e
|
||
|
|
659fd00c53
|
||
|
|
9b3ed0d6ae
|
||
|
|
fbcde2f322
|
@@ -1,30 +1,25 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"git.defalsify.org/vise.git/engine"
|
||||
"git.defalsify.org/vise.git/logging"
|
||||
"git.defalsify.org/vise.git/resource"
|
||||
|
||||
"git.grassecon.net/urdt/ussd/common"
|
||||
"git.grassecon.net/urdt/ussd/config"
|
||||
"git.grassecon.net/urdt/ussd/initializers"
|
||||
"git.grassecon.net/urdt/ussd/internal/handlers"
|
||||
httpserver "git.grassecon.net/urdt/ussd/internal/http"
|
||||
"git.grassecon.net/urdt/ussd/internal/http/at"
|
||||
httpserver "git.grassecon.net/urdt/ussd/internal/http/at"
|
||||
"git.grassecon.net/urdt/ussd/internal/storage"
|
||||
"git.grassecon.net/urdt/ussd/remote"
|
||||
)
|
||||
@@ -39,113 +34,6 @@ var (
|
||||
func init() {
|
||||
initializers.LoadEnvVariables()
|
||||
}
|
||||
|
||||
type atRequestParser struct {
|
||||
context context.Context
|
||||
}
|
||||
|
||||
func parseQueryParams(query string) map[string]string {
|
||||
params := make(map[string]string)
|
||||
|
||||
queryParams := strings.Split(query, "&")
|
||||
for _, param := range queryParams {
|
||||
// Split each key-value pair by '='
|
||||
parts := strings.SplitN(param, "=", 2)
|
||||
if len(parts) == 2 {
|
||||
params[parts[0]] = parts[1]
|
||||
}
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
func extractATSessionId(decodedStr string) (string, error) {
|
||||
var data map[string]string
|
||||
err := json.Unmarshal([]byte(decodedStr), &data)
|
||||
|
||||
if err != nil {
|
||||
logg.Errorf("Error unmarshalling JSON: %v", err)
|
||||
return "", nil
|
||||
}
|
||||
decodedBody, err := url.QueryUnescape(data["body"])
|
||||
if err != nil {
|
||||
logg.Errorf("Error URL-decoding body: %v", err)
|
||||
return "", nil
|
||||
}
|
||||
params := parseQueryParams(decodedBody)
|
||||
|
||||
sessionId := params["sessionId"]
|
||||
return sessionId, nil
|
||||
|
||||
}
|
||||
|
||||
func (arp *atRequestParser) GetSessionId(rq any) (string, error) {
|
||||
rqv, ok := rq.(*http.Request)
|
||||
if !ok {
|
||||
logg.Warnf("got an invalid request", "req", rq)
|
||||
return "", handlers.ErrInvalidRequest
|
||||
}
|
||||
|
||||
// Capture body (if any) for logging
|
||||
body, err := io.ReadAll(rqv.Body)
|
||||
if err != nil {
|
||||
logg.Warnf("failed to read request body", "err", err)
|
||||
return "", fmt.Errorf("failed to read request body: %v", err)
|
||||
}
|
||||
// Reset the body for further reading
|
||||
rqv.Body = io.NopCloser(bytes.NewReader(body))
|
||||
|
||||
// Log the body as JSON
|
||||
bodyLog := map[string]string{"body": string(body)}
|
||||
logBytes, err := json.Marshal(bodyLog)
|
||||
if err != nil {
|
||||
logg.Warnf("failed to marshal request body", "err", err)
|
||||
} else {
|
||||
decodedStr := string(logBytes)
|
||||
sessionId, err := extractATSessionId(decodedStr)
|
||||
if err != nil {
|
||||
context.WithValue(arp.context, "at-session-id", sessionId)
|
||||
}
|
||||
logg.Debugf("Received request:", decodedStr)
|
||||
}
|
||||
|
||||
if err := rqv.ParseForm(); err != nil {
|
||||
logg.Warnf("failed to parse form data", "err", err)
|
||||
return "", fmt.Errorf("failed to parse form data: %v", err)
|
||||
}
|
||||
|
||||
phoneNumber := rqv.FormValue("phoneNumber")
|
||||
if phoneNumber == "" {
|
||||
return "", fmt.Errorf("no phone number found")
|
||||
}
|
||||
|
||||
formattedNumber, err := common.FormatPhoneNumber(phoneNumber)
|
||||
if err != nil {
|
||||
logg.Warnf("failed to format phone number", "err", err)
|
||||
return "", fmt.Errorf("failed to format number")
|
||||
}
|
||||
|
||||
return formattedNumber, nil
|
||||
}
|
||||
|
||||
func (arp *atRequestParser) GetInput(rq any) ([]byte, error) {
|
||||
rqv, ok := rq.(*http.Request)
|
||||
if !ok {
|
||||
return nil, handlers.ErrInvalidRequest
|
||||
}
|
||||
if err := rqv.ParseForm(); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse form data: %v", err)
|
||||
}
|
||||
|
||||
text := rqv.FormValue("text")
|
||||
|
||||
parts := strings.Split(text, "*")
|
||||
if len(parts) == 0 {
|
||||
return nil, fmt.Errorf("no input found")
|
||||
}
|
||||
|
||||
return []byte(parts[len(parts)-1]), nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
config.LoadConfig()
|
||||
|
||||
@@ -233,8 +121,8 @@ func main() {
|
||||
}
|
||||
defer stateStore.Close()
|
||||
|
||||
rp := &atRequestParser{
|
||||
context: ctx,
|
||||
rp := &at.ATRequestParser{
|
||||
Context: ctx,
|
||||
}
|
||||
bsh := handlers.NewBaseSessionHandler(cfg, rs, stateStore, userdataStore, rp, hl)
|
||||
sh := httpserver.NewATSessionHandler(bsh)
|
||||
|
||||
@@ -8,14 +8,15 @@ import (
|
||||
"git.defalsify.org/vise.git/resource"
|
||||
"git.defalsify.org/vise.git/persist"
|
||||
"git.grassecon.net/urdt/ussd/internal/storage"
|
||||
dbstorage "git.grassecon.net/urdt/ussd/internal/storage/db"
|
||||
)
|
||||
|
||||
func StoreToDb(store *UserDataStore) db.Db {
|
||||
return store.Db
|
||||
}
|
||||
|
||||
func StoreToPrefixDb(store *UserDataStore, pfx []byte) storage.PrefixDb {
|
||||
return storage.NewSubPrefixDb(store.Db, pfx)
|
||||
func StoreToPrefixDb(store *UserDataStore, pfx []byte) dbstorage.PrefixDb {
|
||||
return dbstorage.NewSubPrefixDb(store.Db, pfx)
|
||||
}
|
||||
|
||||
type StorageServices interface {
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.grassecon.net/urdt/ussd/internal/storage"
|
||||
dbstorage "git.grassecon.net/urdt/ussd/internal/storage/db"
|
||||
dataserviceapi "github.com/grassrootseconomics/ussd-data-service/pkg/api"
|
||||
)
|
||||
|
||||
@@ -56,7 +56,7 @@ func ProcessTransfers(transfers []dataserviceapi.Last10TxResponse) TransferMetad
|
||||
|
||||
// GetTransferData retrieves and matches transfer data
|
||||
// returns a formatted string of the full transaction/statement
|
||||
func GetTransferData(ctx context.Context, db storage.PrefixDb, publicKey string, index int) (string, error) {
|
||||
func GetTransferData(ctx context.Context, db dbstorage.PrefixDb, publicKey string, index int) (string, error) {
|
||||
keys := []DataTyp{DATA_TX_SENDERS, DATA_TX_RECIPIENTS, DATA_TX_VALUES, DATA_TX_ADDRESSES, DATA_TX_HASHES, DATA_TX_DATES, DATA_TX_SYMBOLS}
|
||||
data := make(map[DataTyp]string)
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"math/big"
|
||||
"strings"
|
||||
|
||||
"git.grassecon.net/urdt/ussd/internal/storage"
|
||||
dbstorage "git.grassecon.net/urdt/ussd/internal/storage/db"
|
||||
dataserviceapi "github.com/grassrootseconomics/ussd-data-service/pkg/api"
|
||||
)
|
||||
|
||||
@@ -63,7 +63,7 @@ func ScaleDownBalance(balance, decimals string) string {
|
||||
}
|
||||
|
||||
// GetVoucherData retrieves and matches voucher data
|
||||
func GetVoucherData(ctx context.Context, db storage.PrefixDb, input string) (*dataserviceapi.TokenHoldings, error) {
|
||||
func GetVoucherData(ctx context.Context, db dbstorage.PrefixDb, input string) (*dataserviceapi.TokenHoldings, error) {
|
||||
keys := []DataTyp{DATA_VOUCHER_SYMBOLS, DATA_VOUCHER_BALANCES, DATA_VOUCHER_DECIMALS, DATA_VOUCHER_ADDRESSES}
|
||||
data := make(map[DataTyp]string)
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
|
||||
visedb "git.defalsify.org/vise.git/db"
|
||||
memdb "git.defalsify.org/vise.git/db/mem"
|
||||
"git.grassecon.net/urdt/ussd/internal/storage"
|
||||
dbstorage "git.grassecon.net/urdt/ussd/internal/storage/db"
|
||||
dataserviceapi "github.com/grassrootseconomics/ussd-data-service/pkg/api"
|
||||
)
|
||||
|
||||
@@ -86,7 +86,7 @@ func TestGetVoucherData(t *testing.T) {
|
||||
}
|
||||
|
||||
prefix := ToBytes(visedb.DATATYPE_USERDATA)
|
||||
spdb := storage.NewSubPrefixDb(db, prefix)
|
||||
spdb := dbstorage.NewSubPrefixDb(db, prefix)
|
||||
|
||||
// Test voucher data
|
||||
mockData := map[DataTyp][]byte{
|
||||
|
||||
@@ -23,7 +23,7 @@ import (
|
||||
"git.grassecon.net/urdt/ussd/remote"
|
||||
"gopkg.in/leonelquinteros/gotext.v1"
|
||||
|
||||
"git.grassecon.net/urdt/ussd/internal/storage"
|
||||
dbstorage "git.grassecon.net/urdt/ussd/internal/storage/db"
|
||||
dataserviceapi "github.com/grassrootseconomics/ussd-data-service/pkg/api"
|
||||
)
|
||||
|
||||
@@ -64,7 +64,7 @@ type Handlers struct {
|
||||
adminstore *utils.AdminStore
|
||||
flagManager *asm.FlagParser
|
||||
accountService remote.AccountServiceInterface
|
||||
prefixDb storage.PrefixDb
|
||||
prefixDb dbstorage.PrefixDb
|
||||
profile *models.Profile
|
||||
ReplaceSeparatorFunc func(string) string
|
||||
}
|
||||
@@ -80,7 +80,7 @@ func NewHandlers(appFlags *asm.FlagParser, userdataStore db.Db, adminstore *util
|
||||
|
||||
// Instantiate the SubPrefixDb with "DATATYPE_USERDATA" prefix
|
||||
prefix := common.ToBytes(db.DATATYPE_USERDATA)
|
||||
prefixDb := storage.NewSubPrefixDb(userdataStore, prefix)
|
||||
prefixDb := dbstorage.NewSubPrefixDb(userdataStore, prefix)
|
||||
|
||||
h := &Handlers{
|
||||
userdataStore: userDb,
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
"git.defalsify.org/vise.git/persist"
|
||||
"git.defalsify.org/vise.git/resource"
|
||||
"git.defalsify.org/vise.git/state"
|
||||
"git.grassecon.net/urdt/ussd/internal/storage"
|
||||
dbstorage "git.grassecon.net/urdt/ussd/internal/storage/db"
|
||||
"git.grassecon.net/urdt/ussd/internal/testutil/mocks"
|
||||
"git.grassecon.net/urdt/ussd/internal/testutil/testservice"
|
||||
"git.grassecon.net/urdt/ussd/internal/utils"
|
||||
@@ -59,14 +59,14 @@ func InitializeTestStore(t *testing.T) (context.Context, *common.UserDataStore)
|
||||
return ctx, store
|
||||
}
|
||||
|
||||
func InitializeTestSubPrefixDb(t *testing.T, ctx context.Context) *storage.SubPrefixDb {
|
||||
func InitializeTestSubPrefixDb(t *testing.T, ctx context.Context) *dbstorage.SubPrefixDb {
|
||||
db := memdb.NewMemDb()
|
||||
err := db.Connect(ctx, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
prefix := common.ToBytes(visedb.DATATYPE_USERDATA)
|
||||
spdb := storage.NewSubPrefixDb(db, prefix)
|
||||
spdb := dbstorage.NewSubPrefixDb(db, prefix)
|
||||
|
||||
return spdb
|
||||
}
|
||||
|
||||
121
internal/http/at/parse.go
Normal file
121
internal/http/at/parse.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package at
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"git.grassecon.net/urdt/ussd/common"
|
||||
"git.grassecon.net/urdt/ussd/internal/handlers"
|
||||
)
|
||||
|
||||
type ATRequestParser struct {
|
||||
Context context.Context
|
||||
}
|
||||
|
||||
func (arp *ATRequestParser) GetSessionId(rq any) (string, error) {
|
||||
rqv, ok := rq.(*http.Request)
|
||||
if !ok {
|
||||
logg.Warnf("got an invalid request", "req", rq)
|
||||
return "", handlers.ErrInvalidRequest
|
||||
}
|
||||
|
||||
// Capture body (if any) for logging
|
||||
body, err := io.ReadAll(rqv.Body)
|
||||
if err != nil {
|
||||
logg.Warnf("failed to read request body", "err", err)
|
||||
return "", fmt.Errorf("failed to read request body: %v", err)
|
||||
}
|
||||
// Reset the body for further reading
|
||||
rqv.Body = io.NopCloser(bytes.NewReader(body))
|
||||
|
||||
// Log the body as JSON
|
||||
bodyLog := map[string]string{"body": string(body)}
|
||||
logBytes, err := json.Marshal(bodyLog)
|
||||
if err != nil {
|
||||
logg.Warnf("failed to marshal request body", "err", err)
|
||||
} else {
|
||||
decodedStr := string(logBytes)
|
||||
sessionId, err := extractATSessionId(decodedStr)
|
||||
if err != nil {
|
||||
context.WithValue(arp.Context, "at-session-id", sessionId)
|
||||
}
|
||||
logg.Debugf("Received request:", decodedStr)
|
||||
}
|
||||
|
||||
if err := rqv.ParseForm(); err != nil {
|
||||
logg.Warnf("failed to parse form data", "err", err)
|
||||
return "", fmt.Errorf("failed to parse form data: %v", err)
|
||||
}
|
||||
|
||||
phoneNumber := rqv.FormValue("phoneNumber")
|
||||
if phoneNumber == "" {
|
||||
return "", fmt.Errorf("no phone number found")
|
||||
}
|
||||
|
||||
formattedNumber, err := common.FormatPhoneNumber(phoneNumber)
|
||||
if err != nil {
|
||||
logg.Warnf("failed to format phone number", "err", err)
|
||||
return "", fmt.Errorf("failed to format number")
|
||||
}
|
||||
|
||||
return formattedNumber, nil
|
||||
}
|
||||
|
||||
func (arp *ATRequestParser) GetInput(rq any) ([]byte, error) {
|
||||
rqv, ok := rq.(*http.Request)
|
||||
if !ok {
|
||||
return nil, handlers.ErrInvalidRequest
|
||||
}
|
||||
if err := rqv.ParseForm(); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse form data: %v", err)
|
||||
}
|
||||
|
||||
text := rqv.FormValue("text")
|
||||
|
||||
parts := strings.Split(text, "*")
|
||||
if len(parts) == 0 {
|
||||
return nil, fmt.Errorf("no input found")
|
||||
}
|
||||
|
||||
return []byte(parts[len(parts)-1]), nil
|
||||
}
|
||||
|
||||
func parseQueryParams(query string) map[string]string {
|
||||
params := make(map[string]string)
|
||||
|
||||
queryParams := strings.Split(query, "&")
|
||||
for _, param := range queryParams {
|
||||
// Split each key-value pair by '='
|
||||
parts := strings.SplitN(param, "=", 2)
|
||||
if len(parts) == 2 {
|
||||
params[parts[0]] = parts[1]
|
||||
}
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
func extractATSessionId(decodedStr string) (string, error) {
|
||||
var data map[string]string
|
||||
err := json.Unmarshal([]byte(decodedStr), &data)
|
||||
|
||||
if err != nil {
|
||||
logg.Errorf("Error unmarshalling JSON: %v", err)
|
||||
return "", nil
|
||||
}
|
||||
decodedBody, err := url.QueryUnescape(data["body"])
|
||||
if err != nil {
|
||||
logg.Errorf("Error URL-decoding body: %v", err)
|
||||
return "", nil
|
||||
}
|
||||
params := parseQueryParams(decodedBody)
|
||||
|
||||
sessionId := params["sessionId"]
|
||||
return sessionId, nil
|
||||
|
||||
}
|
||||
@@ -1,19 +1,25 @@
|
||||
package http
|
||||
package at
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"git.defalsify.org/vise.git/logging"
|
||||
"git.grassecon.net/urdt/ussd/internal/handlers"
|
||||
httpserver "git.grassecon.net/urdt/ussd/internal/http"
|
||||
)
|
||||
|
||||
var (
|
||||
logg = logging.NewVanilla().WithDomain("atserver")
|
||||
)
|
||||
|
||||
type ATSessionHandler struct {
|
||||
*SessionHandler
|
||||
*httpserver.SessionHandler
|
||||
}
|
||||
|
||||
func NewATSessionHandler(h handlers.RequestHandler) *ATSessionHandler {
|
||||
return &ATSessionHandler{
|
||||
SessionHandler: ToSessionHandler(h),
|
||||
SessionHandler: httpserver.ToSessionHandler(h),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,14 +37,14 @@ func (ash *ATSessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request)
|
||||
cfg.SessionId, err = rp.GetSessionId(req)
|
||||
if err != nil {
|
||||
logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err)
|
||||
ash.writeError(w, 400, err)
|
||||
ash.WriteError(w, 400, err)
|
||||
return
|
||||
}
|
||||
rqs.Config = cfg
|
||||
rqs.Input, err = rp.GetInput(req)
|
||||
if err != nil {
|
||||
logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err)
|
||||
ash.writeError(w, 400, err)
|
||||
ash.WriteError(w, 400, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -53,7 +59,7 @@ func (ash *ATSessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request)
|
||||
}
|
||||
|
||||
if code != 200 {
|
||||
ash.writeError(w, 500, err)
|
||||
ash.WriteError(w, 500, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -61,13 +67,13 @@ func (ash *ATSessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request)
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
rqs, err = ash.Output(rqs)
|
||||
if err != nil {
|
||||
ash.writeError(w, 500, err)
|
||||
ash.WriteError(w, 500, err)
|
||||
return
|
||||
}
|
||||
|
||||
rqs, err = ash.Reset(rqs)
|
||||
if err != nil {
|
||||
ash.writeError(w, 500, err)
|
||||
ash.WriteError(w, 500, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -89,4 +95,4 @@ func (ash *ATSessionHandler) Output(rqs handlers.RequestSession) (handlers.Reque
|
||||
|
||||
_, err = rqs.Engine.Flush(rqs.Ctx, rqs.Writer)
|
||||
return rqs, err
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
package http
|
||||
package at
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
@@ -16,16 +15,6 @@ import (
|
||||
"git.grassecon.net/urdt/ussd/internal/testutil/mocks/httpmocks"
|
||||
)
|
||||
|
||||
// invalidRequestType is a custom type to test invalid request scenarios
|
||||
type invalidRequestType struct{}
|
||||
|
||||
// errorReader is a helper type that always returns an error when Read is called
|
||||
type errorReader struct{}
|
||||
|
||||
func (e *errorReader) Read(p []byte) (n int, err error) {
|
||||
return 0, errors.New("read error")
|
||||
}
|
||||
|
||||
func TestNewATSessionHandler(t *testing.T) {
|
||||
mockHandler := &httpmocks.MockRequestHandler{}
|
||||
ash := NewATSessionHandler(mockHandler)
|
||||
@@ -242,208 +231,4 @@ func TestATSessionHandler_Output(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionHandler_ServeHTTP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sessionID string
|
||||
input []byte
|
||||
parserErr error
|
||||
processErr error
|
||||
outputErr error
|
||||
resetErr error
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
sessionID: "123",
|
||||
input: []byte("test input"),
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Missing Session ID",
|
||||
sessionID: "",
|
||||
parserErr: handlers.ErrSessionMissing,
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "Process Error",
|
||||
sessionID: "123",
|
||||
input: []byte("test input"),
|
||||
processErr: handlers.ErrStorage,
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
name: "Output Error",
|
||||
sessionID: "123",
|
||||
input: []byte("test input"),
|
||||
outputErr: errors.New("output error"),
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Reset Error",
|
||||
sessionID: "123",
|
||||
input: []byte("test input"),
|
||||
resetErr: errors.New("reset error"),
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockRequestParser := &httpmocks.MockRequestParser{
|
||||
GetSessionIdFunc: func(any) (string, error) {
|
||||
return tt.sessionID, tt.parserErr
|
||||
},
|
||||
GetInputFunc: func(any) ([]byte, error) {
|
||||
return tt.input, nil
|
||||
},
|
||||
}
|
||||
|
||||
mockRequestHandler := &httpmocks.MockRequestHandler{
|
||||
ProcessFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) {
|
||||
return rs, tt.processErr
|
||||
},
|
||||
OutputFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) {
|
||||
return rs, tt.outputErr
|
||||
},
|
||||
ResetFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) {
|
||||
return rs, tt.resetErr
|
||||
},
|
||||
GetRequestParserFunc: func() handlers.RequestParser {
|
||||
return mockRequestParser
|
||||
},
|
||||
GetConfigFunc: func() engine.Config {
|
||||
return engine.Config{}
|
||||
},
|
||||
}
|
||||
|
||||
sessionHandler := ToSessionHandler(mockRequestHandler)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBuffer(tt.input))
|
||||
req.Header.Set("X-Vise-Session", tt.sessionID)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
sessionHandler.ServeHTTP(rr, req)
|
||||
|
||||
if status := rr.Code; status != tt.expectedStatus {
|
||||
t.Errorf("handler returned wrong status code: got %v want %v",
|
||||
status, tt.expectedStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionHandler_writeError(t *testing.T) {
|
||||
handler := &SessionHandler{}
|
||||
mockWriter := &httpmocks.MockWriter{}
|
||||
err := errors.New("test error")
|
||||
|
||||
handler.writeError(mockWriter, http.StatusBadRequest, err)
|
||||
|
||||
if mockWriter.WrittenString != "" {
|
||||
t.Errorf("Expected empty body, got %s", mockWriter.WrittenString)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultRequestParser_GetSessionId(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
request any
|
||||
expectedID string
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
name: "Valid Session ID",
|
||||
request: func() *http.Request {
|
||||
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
req.Header.Set("X-Vise-Session", "123456")
|
||||
return req
|
||||
}(),
|
||||
expectedID: "123456",
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "Missing Session ID",
|
||||
request: httptest.NewRequest(http.MethodPost, "/", nil),
|
||||
expectedID: "",
|
||||
expectedError: handlers.ErrSessionMissing,
|
||||
},
|
||||
{
|
||||
name: "Invalid Request Type",
|
||||
request: invalidRequestType{},
|
||||
expectedID: "",
|
||||
expectedError: handlers.ErrInvalidRequest,
|
||||
},
|
||||
}
|
||||
|
||||
parser := &DefaultRequestParser{}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
id, err := parser.GetSessionId(tt.request)
|
||||
|
||||
if id != tt.expectedID {
|
||||
t.Errorf("Expected session ID %s, got %s", tt.expectedID, id)
|
||||
}
|
||||
|
||||
if err != tt.expectedError {
|
||||
t.Errorf("Expected error %v, got %v", tt.expectedError, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultRequestParser_GetInput(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
request any
|
||||
expectedInput []byte
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
name: "Valid Input",
|
||||
request: func() *http.Request {
|
||||
return httptest.NewRequest(http.MethodPost, "/", bytes.NewBufferString("test input"))
|
||||
}(),
|
||||
expectedInput: []byte("test input"),
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "Empty Input",
|
||||
request: httptest.NewRequest(http.MethodPost, "/", nil),
|
||||
expectedInput: []byte{},
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "Invalid Request Type",
|
||||
request: invalidRequestType{},
|
||||
expectedInput: nil,
|
||||
expectedError: handlers.ErrInvalidRequest,
|
||||
},
|
||||
{
|
||||
name: "Read Error",
|
||||
request: func() *http.Request {
|
||||
return httptest.NewRequest(http.MethodPost, "/", &errorReader{})
|
||||
}(),
|
||||
expectedInput: nil,
|
||||
expectedError: errors.New("read error"),
|
||||
},
|
||||
}
|
||||
|
||||
parser := &DefaultRequestParser{}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
input, err := parser.GetInput(tt.request)
|
||||
|
||||
if !bytes.Equal(input, tt.expectedInput) {
|
||||
t.Errorf("Expected input %s, got %s", tt.expectedInput, input)
|
||||
}
|
||||
|
||||
if err != tt.expectedError && (err == nil || err.Error() != tt.expectedError.Error()) {
|
||||
t.Errorf("Expected error %v, got %v", tt.expectedError, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
38
internal/http/parse.go
Normal file
38
internal/http/parse.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
||||
"git.grassecon.net/urdt/ussd/internal/handlers"
|
||||
)
|
||||
|
||||
type DefaultRequestParser struct {
|
||||
}
|
||||
|
||||
func (rp *DefaultRequestParser) GetSessionId(rq any) (string, error) {
|
||||
rqv, ok := rq.(*http.Request)
|
||||
if !ok {
|
||||
return "", handlers.ErrInvalidRequest
|
||||
}
|
||||
v := rqv.Header.Get("X-Vise-Session")
|
||||
if v == "" {
|
||||
return "", handlers.ErrSessionMissing
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (rp *DefaultRequestParser) GetInput(rq any) ([]byte, error) {
|
||||
rqv, ok := rq.(*http.Request)
|
||||
if !ok {
|
||||
return nil, handlers.ErrInvalidRequest
|
||||
}
|
||||
defer rqv.Body.Close()
|
||||
v, err := ioutil.ReadAll(rqv.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
@@ -14,34 +13,6 @@ var (
|
||||
logg = logging.NewVanilla().WithDomain("httpserver")
|
||||
)
|
||||
|
||||
type DefaultRequestParser struct {
|
||||
}
|
||||
|
||||
func (rp *DefaultRequestParser) GetSessionId(rq any) (string, error) {
|
||||
rqv, ok := rq.(*http.Request)
|
||||
if !ok {
|
||||
return "", handlers.ErrInvalidRequest
|
||||
}
|
||||
v := rqv.Header.Get("X-Vise-Session")
|
||||
if v == "" {
|
||||
return "", handlers.ErrSessionMissing
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (rp *DefaultRequestParser) GetInput(rq any) ([]byte, error) {
|
||||
rqv, ok := rq.(*http.Request)
|
||||
if !ok {
|
||||
return nil, handlers.ErrInvalidRequest
|
||||
}
|
||||
defer rqv.Body.Close()
|
||||
v, err := ioutil.ReadAll(rqv.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
type SessionHandler struct {
|
||||
handlers.RequestHandler
|
||||
}
|
||||
@@ -52,7 +23,7 @@ func ToSessionHandler(h handlers.RequestHandler) *SessionHandler {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *SessionHandler) writeError(w http.ResponseWriter, code int, err error) {
|
||||
func (f *SessionHandler) WriteError(w http.ResponseWriter, code int, err error) {
|
||||
s := err.Error()
|
||||
w.Header().Set("Content-Length", strconv.Itoa(len(s)))
|
||||
w.WriteHeader(code)
|
||||
@@ -78,13 +49,13 @@ func (f *SessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
cfg.SessionId, err = rp.GetSessionId(req)
|
||||
if err != nil {
|
||||
logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err)
|
||||
f.writeError(w, 400, err)
|
||||
f.WriteError(w, 400, err)
|
||||
}
|
||||
rqs.Config = cfg
|
||||
rqs.Input, err = rp.GetInput(req)
|
||||
if err != nil {
|
||||
logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err)
|
||||
f.writeError(w, 400, err)
|
||||
f.WriteError(w, 400, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -101,7 +72,7 @@ func (f *SessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
}
|
||||
|
||||
if code != 200 {
|
||||
f.writeError(w, 500, err)
|
||||
f.WriteError(w, 500, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -110,11 +81,11 @@ func (f *SessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
rqs, err = f.Output(rqs)
|
||||
rqs, perr = f.Reset(rqs)
|
||||
if err != nil {
|
||||
f.writeError(w, 500, err)
|
||||
f.WriteError(w, 500, err)
|
||||
return
|
||||
}
|
||||
if perr != nil {
|
||||
f.writeError(w, 500, perr)
|
||||
f.WriteError(w, 500, perr)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
229
internal/http/server_test.go
Normal file
229
internal/http/server_test.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"git.defalsify.org/vise.git/engine"
|
||||
"git.grassecon.net/urdt/ussd/internal/handlers"
|
||||
"git.grassecon.net/urdt/ussd/internal/testutil/mocks/httpmocks"
|
||||
)
|
||||
|
||||
// invalidRequestType is a custom type to test invalid request scenarios
|
||||
type invalidRequestType struct{}
|
||||
|
||||
// errorReader is a helper type that always returns an error when Read is called
|
||||
type errorReader struct{}
|
||||
|
||||
func (e *errorReader) Read(p []byte) (n int, err error) {
|
||||
return 0, errors.New("read error")
|
||||
}
|
||||
|
||||
func TestSessionHandler_ServeHTTP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sessionID string
|
||||
input []byte
|
||||
parserErr error
|
||||
processErr error
|
||||
outputErr error
|
||||
resetErr error
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
sessionID: "123",
|
||||
input: []byte("test input"),
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Missing Session ID",
|
||||
sessionID: "",
|
||||
parserErr: handlers.ErrSessionMissing,
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "Process Error",
|
||||
sessionID: "123",
|
||||
input: []byte("test input"),
|
||||
processErr: handlers.ErrStorage,
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
name: "Output Error",
|
||||
sessionID: "123",
|
||||
input: []byte("test input"),
|
||||
outputErr: errors.New("output error"),
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Reset Error",
|
||||
sessionID: "123",
|
||||
input: []byte("test input"),
|
||||
resetErr: errors.New("reset error"),
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockRequestParser := &httpmocks.MockRequestParser{
|
||||
GetSessionIdFunc: func(any) (string, error) {
|
||||
return tt.sessionID, tt.parserErr
|
||||
},
|
||||
GetInputFunc: func(any) ([]byte, error) {
|
||||
return tt.input, nil
|
||||
},
|
||||
}
|
||||
|
||||
mockRequestHandler := &httpmocks.MockRequestHandler{
|
||||
ProcessFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) {
|
||||
return rs, tt.processErr
|
||||
},
|
||||
OutputFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) {
|
||||
return rs, tt.outputErr
|
||||
},
|
||||
ResetFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) {
|
||||
return rs, tt.resetErr
|
||||
},
|
||||
GetRequestParserFunc: func() handlers.RequestParser {
|
||||
return mockRequestParser
|
||||
},
|
||||
GetConfigFunc: func() engine.Config {
|
||||
return engine.Config{}
|
||||
},
|
||||
}
|
||||
|
||||
sessionHandler := ToSessionHandler(mockRequestHandler)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBuffer(tt.input))
|
||||
req.Header.Set("X-Vise-Session", tt.sessionID)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
sessionHandler.ServeHTTP(rr, req)
|
||||
|
||||
if status := rr.Code; status != tt.expectedStatus {
|
||||
t.Errorf("handler returned wrong status code: got %v want %v",
|
||||
status, tt.expectedStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionHandler_WriteError(t *testing.T) {
|
||||
handler := &SessionHandler{}
|
||||
mockWriter := &httpmocks.MockWriter{}
|
||||
err := errors.New("test error")
|
||||
|
||||
handler.WriteError(mockWriter, http.StatusBadRequest, err)
|
||||
|
||||
if mockWriter.WrittenString != "" {
|
||||
t.Errorf("Expected empty body, got %s", mockWriter.WrittenString)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultRequestParser_GetSessionId(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
request any
|
||||
expectedID string
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
name: "Valid Session ID",
|
||||
request: func() *http.Request {
|
||||
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
req.Header.Set("X-Vise-Session", "123456")
|
||||
return req
|
||||
}(),
|
||||
expectedID: "123456",
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "Missing Session ID",
|
||||
request: httptest.NewRequest(http.MethodPost, "/", nil),
|
||||
expectedID: "",
|
||||
expectedError: handlers.ErrSessionMissing,
|
||||
},
|
||||
{
|
||||
name: "Invalid Request Type",
|
||||
request: invalidRequestType{},
|
||||
expectedID: "",
|
||||
expectedError: handlers.ErrInvalidRequest,
|
||||
},
|
||||
}
|
||||
|
||||
parser := &DefaultRequestParser{}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
id, err := parser.GetSessionId(tt.request)
|
||||
|
||||
if id != tt.expectedID {
|
||||
t.Errorf("Expected session ID %s, got %s", tt.expectedID, id)
|
||||
}
|
||||
|
||||
if err != tt.expectedError {
|
||||
t.Errorf("Expected error %v, got %v", tt.expectedError, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultRequestParser_GetInput(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
request any
|
||||
expectedInput []byte
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
name: "Valid Input",
|
||||
request: func() *http.Request {
|
||||
return httptest.NewRequest(http.MethodPost, "/", bytes.NewBufferString("test input"))
|
||||
}(),
|
||||
expectedInput: []byte("test input"),
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "Empty Input",
|
||||
request: httptest.NewRequest(http.MethodPost, "/", nil),
|
||||
expectedInput: []byte{},
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "Invalid Request Type",
|
||||
request: invalidRequestType{},
|
||||
expectedInput: nil,
|
||||
expectedError: handlers.ErrInvalidRequest,
|
||||
},
|
||||
{
|
||||
name: "Read Error",
|
||||
request: func() *http.Request {
|
||||
return httptest.NewRequest(http.MethodPost, "/", &errorReader{})
|
||||
}(),
|
||||
expectedInput: nil,
|
||||
expectedError: errors.New("read error"),
|
||||
},
|
||||
}
|
||||
|
||||
parser := &DefaultRequestParser{}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
input, err := parser.GetInput(tt.request)
|
||||
|
||||
if !bytes.Equal(input, tt.expectedInput) {
|
||||
t.Errorf("Expected input %s, got %s", tt.expectedInput, input)
|
||||
}
|
||||
|
||||
if err != tt.expectedError && (err == nil || err.Error() != tt.expectedError.Error()) {
|
||||
t.Errorf("Expected error %v, got %v", tt.expectedError, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -6,12 +6,18 @@ import (
|
||||
"git.defalsify.org/vise.git/db"
|
||||
gdbmdb "git.defalsify.org/vise.git/db/gdbm"
|
||||
"git.defalsify.org/vise.git/lang"
|
||||
"git.defalsify.org/vise.git/logging"
|
||||
)
|
||||
|
||||
var (
|
||||
logg = logging.NewVanilla().WithDomain("gdbmstorage")
|
||||
)
|
||||
|
||||
var (
|
||||
dbC map[string]chan db.Db
|
||||
)
|
||||
|
||||
|
||||
type ThreadGdbmDb struct {
|
||||
db db.Db
|
||||
connStr string
|
||||
@@ -6,6 +6,11 @@ import (
|
||||
"git.defalsify.org/vise.git/db"
|
||||
)
|
||||
|
||||
const (
|
||||
DATATYPE_USERSUB = 64
|
||||
SUBPREFIX_TIME = uint16(1)
|
||||
)
|
||||
|
||||
// PrefixDb interface abstracts the database operations.
|
||||
type PrefixDb interface {
|
||||
Get(ctx context.Context, key []byte) ([]byte, error)
|
||||
@@ -26,8 +31,12 @@ func NewSubPrefixDb(store db.Db, pfx []byte) *SubPrefixDb {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SubPrefixDb) toKey(k []byte) []byte {
|
||||
return append(s.pfx, k...)
|
||||
func(s *SubPrefixDb) SetSession(sessionId string) {
|
||||
s.store.SetSession(sessionId)
|
||||
}
|
||||
|
||||
func(s *SubPrefixDb) toKey(k []byte) []byte {
|
||||
return append(s.pfx, k...)
|
||||
}
|
||||
|
||||
func (s *SubPrefixDb) Get(ctx context.Context, key []byte) ([]byte, error) {
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"git.defalsify.org/vise.git/persist"
|
||||
"git.defalsify.org/vise.git/resource"
|
||||
"git.grassecon.net/urdt/ussd/initializers"
|
||||
gdbmstorage "git.grassecon.net/urdt/ussd/internal/storage/db/gdbm"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -75,7 +76,7 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D
|
||||
connStr := buildConnStr()
|
||||
err = newDb.Connect(ctx, connStr)
|
||||
} else {
|
||||
newDb = NewThreadGdbmDb()
|
||||
newDb = gdbmstorage.NewThreadGdbmDb()
|
||||
storeFile := path.Join(ms.dbDir, fileName)
|
||||
err = newDb.Connect(ctx, storeFile)
|
||||
}
|
||||
|
||||
109
internal/storage/timed.go
Normal file
109
internal/storage/timed.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"time"
|
||||
"encoding/binary"
|
||||
|
||||
"git.defalsify.org/vise.git/db"
|
||||
)
|
||||
|
||||
type TimedDb struct {
|
||||
db.Db
|
||||
tdb *SubPrefixDb
|
||||
ttl time.Duration
|
||||
parentPfx uint8
|
||||
parentSession []byte
|
||||
matchPfx map[uint8][][]byte
|
||||
}
|
||||
|
||||
func NewTimedDb(db db.Db, ttl time.Duration) *TimedDb {
|
||||
var b [2]byte
|
||||
binary.BigEndian.PutUint16(b[:], SUBPREFIX_TIME)
|
||||
sdb := NewSubPrefixDb(db, b[:])
|
||||
return &TimedDb{
|
||||
Db: db,
|
||||
tdb: sdb,
|
||||
ttl: ttl,
|
||||
}
|
||||
}
|
||||
|
||||
func(tib *TimedDb) WithMatch(pfx uint8, keyPart []byte) *TimedDb {
|
||||
if tib.matchPfx == nil {
|
||||
tib.matchPfx = make(map[uint8][][]byte)
|
||||
}
|
||||
tib.matchPfx[pfx] = append(tib.matchPfx[pfx], keyPart)
|
||||
return tib
|
||||
}
|
||||
|
||||
func(tib *TimedDb) checkPrefix(pfx uint8, key []byte) bool {
|
||||
var v []byte
|
||||
if tib.matchPfx == nil {
|
||||
return true
|
||||
}
|
||||
for _, v = range(tib.matchPfx[pfx]) {
|
||||
l := len(v)
|
||||
if l > len(key) {
|
||||
continue
|
||||
}
|
||||
if bytes.Equal(v, key[:l]) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func(tib *TimedDb) SetPrefix(pfx uint8) {
|
||||
tib.Db.SetPrefix(pfx)
|
||||
tib.parentPfx = pfx
|
||||
}
|
||||
|
||||
func(tib *TimedDb) SetSession(session string) {
|
||||
tib.Db.SetSession(session)
|
||||
tib.parentSession = []byte(session)
|
||||
}
|
||||
|
||||
func(tib *TimedDb) Put(ctx context.Context, key []byte, val []byte) error {
|
||||
t := time.Now()
|
||||
b, err := t.MarshalBinary()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = tib.Db.Put(ctx, key, val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
tib.parentPfx = 0
|
||||
tib.parentSession = nil
|
||||
}()
|
||||
if tib.checkPrefix(tib.parentPfx, key) {
|
||||
tib.tdb.SetSession("")
|
||||
k := db.ToSessionKey(tib.parentPfx, []byte(tib.parentSession), key)
|
||||
k = append([]byte{tib.parentPfx}, k...)
|
||||
err = tib.tdb.Put(ctx, k, b)
|
||||
if err != nil {
|
||||
logg.ErrorCtxf(ctx, "failed to update timestamp of record", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func(tib *TimedDb) Stale(ctx context.Context, pfx uint8, sessionId string, key []byte) bool {
|
||||
tib.tdb.SetSession("")
|
||||
b := db.ToSessionKey(pfx, []byte(sessionId), key)
|
||||
b = append([]byte{pfx}, b...)
|
||||
v, err := tib.tdb.Get(ctx, b)
|
||||
if err != nil {
|
||||
logg.WarnCtxf(ctx, "no time entry", "key", key, "b", b)
|
||||
return false
|
||||
}
|
||||
t_now := time.Now()
|
||||
t_then := time.Time{}
|
||||
err = t_then.UnmarshalBinary(v)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return t_now.After(t_then.Add(tib.ttl))
|
||||
}
|
||||
125
internal/storage/timed_test.go
Normal file
125
internal/storage/timed_test.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.defalsify.org/vise.git/db"
|
||||
memdb "git.defalsify.org/vise.git/db/mem"
|
||||
)
|
||||
|
||||
func TestStaleDb(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mdb := memdb.NewMemDb()
|
||||
err := mdb.Connect(ctx, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tdb := NewTimedDb(mdb, time.Duration(time.Millisecond))
|
||||
tdb.SetPrefix(db.DATATYPE_USERDATA)
|
||||
k := []byte("foo")
|
||||
err = tdb.Put(ctx, k, []byte("bar"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if tdb.Stale(ctx, db.DATATYPE_USERDATA, "", k) {
|
||||
t.Fatal("expected not stale")
|
||||
}
|
||||
time.Sleep(time.Millisecond)
|
||||
if !tdb.Stale(ctx, db.DATATYPE_USERDATA, "", k) {
|
||||
t.Fatal("expected stale")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilteredStaleDb(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mdb := memdb.NewMemDb()
|
||||
err := mdb.Connect(ctx, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
k := []byte("foo")
|
||||
tdb := NewTimedDb(mdb, time.Duration(time.Millisecond))
|
||||
tdb = tdb.WithMatch(db.DATATYPE_STATE, []byte("fo"))
|
||||
tdb.SetPrefix(db.DATATYPE_USERDATA)
|
||||
tdb.SetSession("inky")
|
||||
err = tdb.Put(ctx, k, []byte("bar"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tdb.SetPrefix(db.DATATYPE_STATE)
|
||||
tdb.SetSession("inky")
|
||||
err = tdb.Put(ctx, k, []byte("pinky"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tdb.SetSession("blinky")
|
||||
err = tdb.Put(ctx, k, []byte("clyde"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if tdb.Stale(ctx, db.DATATYPE_USERDATA, "inky", k) {
|
||||
t.Fatal("expected not stale")
|
||||
}
|
||||
if tdb.Stale(ctx, db.DATATYPE_STATE, "inky", k) {
|
||||
t.Fatal("expected not stale")
|
||||
}
|
||||
if tdb.Stale(ctx, db.DATATYPE_STATE, "blinky", k) {
|
||||
t.Fatal("expected not stale")
|
||||
}
|
||||
time.Sleep(time.Millisecond)
|
||||
if tdb.Stale(ctx, db.DATATYPE_USERDATA, "inky", k) {
|
||||
t.Fatal("expected not stale")
|
||||
}
|
||||
if !tdb.Stale(ctx, db.DATATYPE_STATE, "inky", k) {
|
||||
t.Fatal("expected stale")
|
||||
}
|
||||
if tdb.Stale(ctx, db.DATATYPE_STATE, "blinky", k) {
|
||||
t.Fatal("expected not stale")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilteredSameKeypartStaleDb(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mdb := memdb.NewMemDb()
|
||||
err := mdb.Connect(ctx, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tdb := NewTimedDb(mdb, time.Duration(time.Millisecond))
|
||||
tdb = tdb.WithMatch(db.DATATYPE_USERDATA, []byte("ba"))
|
||||
tdb.SetPrefix(db.DATATYPE_USERDATA)
|
||||
tdb.SetSession("xyzzy")
|
||||
err = tdb.Put(ctx, []byte("bar"), []byte("inky"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tdb.SetPrefix(db.DATATYPE_USERDATA)
|
||||
tdb.SetSession("xyzzy")
|
||||
err = tdb.Put(ctx, []byte("baz"), []byte("pinky"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tdb.SetPrefix(db.DATATYPE_USERDATA)
|
||||
tdb.SetSession("xyzzy")
|
||||
err = tdb.Put(ctx, []byte("foo"), []byte("blinky"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
time.Sleep(time.Millisecond)
|
||||
if !tdb.Stale(ctx, db.DATATYPE_USERDATA, "xyzzy", []byte("bar")) {
|
||||
t.Fatal("expected stale")
|
||||
}
|
||||
if !tdb.Stale(ctx, db.DATATYPE_USERDATA, "xyzzy", []byte("baz")) {
|
||||
t.Fatal("expected stale")
|
||||
}
|
||||
if tdb.Stale(ctx, db.DATATYPE_USERDATA, "xyzzy", []byte("foo")) {
|
||||
t.Fatal("expected not stale")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user