Merge branch 'master' into lash/dump-format
This commit is contained in:
		
						commit
						83857026d3
					
				@ -1,35 +1,31 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"flag"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"os"
 | 
			
		||||
	"os/signal"
 | 
			
		||||
	"path"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"syscall"
 | 
			
		||||
 | 
			
		||||
	"git.defalsify.org/vise.git/engine"
 | 
			
		||||
	"git.defalsify.org/vise.git/logging"
 | 
			
		||||
	"git.defalsify.org/vise.git/resource"
 | 
			
		||||
 | 
			
		||||
	"git.grassecon.net/urdt/ussd/common"
 | 
			
		||||
	"git.grassecon.net/urdt/ussd/config"
 | 
			
		||||
	"git.grassecon.net/urdt/ussd/initializers"
 | 
			
		||||
	"git.grassecon.net/urdt/ussd/internal/handlers"
 | 
			
		||||
	httpserver "git.grassecon.net/urdt/ussd/internal/http"
 | 
			
		||||
	"git.grassecon.net/urdt/ussd/internal/http/at"
 | 
			
		||||
	httpserver "git.grassecon.net/urdt/ussd/internal/http/at"
 | 
			
		||||
	"git.grassecon.net/urdt/ussd/internal/storage"
 | 
			
		||||
	"git.grassecon.net/urdt/ussd/remote"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	logg          = logging.NewVanilla()
 | 
			
		||||
	logg          = logging.NewVanilla().WithDomain("AfricasTalking").WithContextKey("at-session-id")
 | 
			
		||||
	scriptDir     = path.Join("services", "registration")
 | 
			
		||||
	build         = "dev"
 | 
			
		||||
	menuSeparator = ": "
 | 
			
		||||
@ -38,72 +34,6 @@ var (
 | 
			
		||||
func init() {
 | 
			
		||||
	initializers.LoadEnvVariables()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type atRequestParser struct{}
 | 
			
		||||
 | 
			
		||||
func (arp *atRequestParser) GetSessionId(rq any) (string, error) {
 | 
			
		||||
	rqv, ok := rq.(*http.Request)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		logg.Warnf("got an invalid request", "req", rq)
 | 
			
		||||
		return "", handlers.ErrInvalidRequest
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Capture body (if any) for logging
 | 
			
		||||
	body, err := io.ReadAll(rqv.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logg.Warnf("failed to read request body", "err", err)
 | 
			
		||||
		return "", fmt.Errorf("failed to read request body: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	// Reset the body for further reading
 | 
			
		||||
	rqv.Body = io.NopCloser(bytes.NewReader(body))
 | 
			
		||||
 | 
			
		||||
	// Log the body as JSON
 | 
			
		||||
	bodyLog := map[string]string{"body": string(body)}
 | 
			
		||||
	logBytes, err := json.Marshal(bodyLog)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logg.Warnf("failed to marshal request body", "err", err)
 | 
			
		||||
	} else {
 | 
			
		||||
		logg.Debugf("received request", "bytes", logBytes)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := rqv.ParseForm(); err != nil {
 | 
			
		||||
		logg.Warnf("failed to parse form data", "err", err)
 | 
			
		||||
		return "", fmt.Errorf("failed to parse form data: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	phoneNumber := rqv.FormValue("phoneNumber")
 | 
			
		||||
	if phoneNumber == "" {
 | 
			
		||||
		return "", fmt.Errorf("no phone number found")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	formattedNumber, err := common.FormatPhoneNumber(phoneNumber)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logg.Warnf("failed to format phone number", "err", err)
 | 
			
		||||
		return "", fmt.Errorf("failed to format number")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return formattedNumber, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (arp *atRequestParser) GetInput(rq any) ([]byte, error) {
 | 
			
		||||
	rqv, ok := rq.(*http.Request)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return nil, handlers.ErrInvalidRequest
 | 
			
		||||
	}
 | 
			
		||||
	if err := rqv.ParseForm(); err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("failed to parse form data: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	text := rqv.FormValue("text")
 | 
			
		||||
 | 
			
		||||
	parts := strings.Split(text, "*")
 | 
			
		||||
	if len(parts) == 0 {
 | 
			
		||||
		return nil, fmt.Errorf("no input found")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return []byte(parts[len(parts)-1]), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func main() {
 | 
			
		||||
	config.LoadConfig()
 | 
			
		||||
 | 
			
		||||
@ -191,7 +121,9 @@ func main() {
 | 
			
		||||
	}
 | 
			
		||||
	defer stateStore.Close()
 | 
			
		||||
 | 
			
		||||
	rp := &atRequestParser{}
 | 
			
		||||
	rp := &at.ATRequestParser{
 | 
			
		||||
		Context: ctx,
 | 
			
		||||
	}
 | 
			
		||||
	bsh := handlers.NewBaseSessionHandler(cfg, rs, stateStore, userdataStore, rp, hl)
 | 
			
		||||
	sh := httpserver.NewATSessionHandler(bsh)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -8,14 +8,15 @@ import (
 | 
			
		||||
	"git.defalsify.org/vise.git/resource"
 | 
			
		||||
	"git.defalsify.org/vise.git/persist"
 | 
			
		||||
	"git.grassecon.net/urdt/ussd/internal/storage"
 | 
			
		||||
	dbstorage "git.grassecon.net/urdt/ussd/internal/storage/db"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func StoreToDb(store *UserDataStore) db.Db {
 | 
			
		||||
	return store.Db
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func StoreToPrefixDb(store *UserDataStore, pfx []byte) storage.PrefixDb {
 | 
			
		||||
	return storage.NewSubPrefixDb(store.Db, pfx)	
 | 
			
		||||
func StoreToPrefixDb(store *UserDataStore, pfx []byte) dbstorage.PrefixDb {
 | 
			
		||||
	return dbstorage.NewSubPrefixDb(store.Db, pfx)	
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type StorageServices interface {
 | 
			
		||||
 | 
			
		||||
@ -6,7 +6,7 @@ import (
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.grassecon.net/urdt/ussd/internal/storage"
 | 
			
		||||
	dbstorage "git.grassecon.net/urdt/ussd/internal/storage/db"
 | 
			
		||||
	dataserviceapi "github.com/grassrootseconomics/ussd-data-service/pkg/api"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -56,7 +56,7 @@ func ProcessTransfers(transfers []dataserviceapi.Last10TxResponse) TransferMetad
 | 
			
		||||
 | 
			
		||||
// GetTransferData retrieves and matches transfer data
 | 
			
		||||
// returns a formatted string of the full transaction/statement
 | 
			
		||||
func GetTransferData(ctx context.Context, db storage.PrefixDb, publicKey string, index int) (string, error) {
 | 
			
		||||
func GetTransferData(ctx context.Context, db dbstorage.PrefixDb, publicKey string, index int) (string, error) {
 | 
			
		||||
	keys := []DataTyp{DATA_TX_SENDERS, DATA_TX_RECIPIENTS, DATA_TX_VALUES, DATA_TX_ADDRESSES, DATA_TX_HASHES, DATA_TX_DATES, DATA_TX_SYMBOLS}
 | 
			
		||||
	data := make(map[DataTyp]string)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -6,7 +6,7 @@ import (
 | 
			
		||||
	"math/big"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"git.grassecon.net/urdt/ussd/internal/storage"
 | 
			
		||||
	dbstorage "git.grassecon.net/urdt/ussd/internal/storage/db"
 | 
			
		||||
	dataserviceapi "github.com/grassrootseconomics/ussd-data-service/pkg/api"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -63,7 +63,7 @@ func ScaleDownBalance(balance, decimals string) string {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetVoucherData retrieves and matches voucher data
 | 
			
		||||
func GetVoucherData(ctx context.Context, db storage.PrefixDb, input string) (*dataserviceapi.TokenHoldings, error) {
 | 
			
		||||
func GetVoucherData(ctx context.Context, db dbstorage.PrefixDb, input string) (*dataserviceapi.TokenHoldings, error) {
 | 
			
		||||
	keys := []DataTyp{DATA_VOUCHER_SYMBOLS, DATA_VOUCHER_BALANCES, DATA_VOUCHER_DECIMALS, DATA_VOUCHER_ADDRESSES}
 | 
			
		||||
	data := make(map[DataTyp]string)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -10,7 +10,7 @@ import (
 | 
			
		||||
 | 
			
		||||
	visedb "git.defalsify.org/vise.git/db"
 | 
			
		||||
	memdb "git.defalsify.org/vise.git/db/mem"
 | 
			
		||||
	"git.grassecon.net/urdt/ussd/internal/storage"
 | 
			
		||||
	dbstorage "git.grassecon.net/urdt/ussd/internal/storage/db"
 | 
			
		||||
	dataserviceapi "github.com/grassrootseconomics/ussd-data-service/pkg/api"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -86,7 +86,7 @@ func TestGetVoucherData(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	prefix := ToBytes(visedb.DATATYPE_USERDATA)
 | 
			
		||||
	spdb := storage.NewSubPrefixDb(db, prefix)
 | 
			
		||||
	spdb := dbstorage.NewSubPrefixDb(db, prefix)
 | 
			
		||||
 | 
			
		||||
	// Test voucher data
 | 
			
		||||
	mockData := map[DataTyp][]byte{
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										4
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								go.mod
									
									
									
									
									
								
							@ -3,7 +3,7 @@ module git.grassecon.net/urdt/ussd
 | 
			
		||||
go 1.23.0
 | 
			
		||||
 | 
			
		||||
require (
 | 
			
		||||
	git.defalsify.org/vise.git v0.2.3-0.20241231085136-8582c7e157d9
 | 
			
		||||
	git.defalsify.org/vise.git v0.2.3-0.20250103172917-3e190a44568d
 | 
			
		||||
	github.com/alecthomas/assert/v2 v2.2.2
 | 
			
		||||
	github.com/gofrs/uuid v4.4.0+incompatible
 | 
			
		||||
	github.com/grassrootseconomics/eth-custodial v1.3.0-beta
 | 
			
		||||
@ -11,6 +11,7 @@ require (
 | 
			
		||||
	github.com/joho/godotenv v1.5.1
 | 
			
		||||
	github.com/peteole/testdata-loader v0.3.0
 | 
			
		||||
	github.com/stretchr/testify v1.9.0
 | 
			
		||||
	golang.org/x/crypto v0.27.0
 | 
			
		||||
	gopkg.in/leonelquinteros/gotext.v1 v1.3.1
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -32,7 +33,6 @@ require (
 | 
			
		||||
	github.com/rogpeppe/go-internal v1.13.1 // indirect
 | 
			
		||||
	github.com/stretchr/objx v0.5.2 // indirect
 | 
			
		||||
	github.com/x448/float16 v0.8.4 // indirect
 | 
			
		||||
	golang.org/x/crypto v0.27.0 // indirect
 | 
			
		||||
	golang.org/x/sync v0.8.0 // indirect
 | 
			
		||||
	golang.org/x/text v0.18.0 // indirect
 | 
			
		||||
	gopkg.in/yaml.v3 v3.0.1 // indirect
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										4
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								go.sum
									
									
									
									
									
								
							@ -1,5 +1,5 @@
 | 
			
		||||
git.defalsify.org/vise.git v0.2.3-0.20241231085136-8582c7e157d9 h1:O3m+NgWDWtJm8OculT99c4bDMAO4xLe2c8hpCKpsd9g=
 | 
			
		||||
git.defalsify.org/vise.git v0.2.3-0.20241231085136-8582c7e157d9/go.mod h1:jyBMe1qTYUz3mmuoC9JQ/TvFeW0vTanCUcPu3H8p4Ck=
 | 
			
		||||
git.defalsify.org/vise.git v0.2.3-0.20250103172917-3e190a44568d h1:bPAOVZOX4frSGhfOdcj7kc555f8dc9DmMd2YAyC2AMw=
 | 
			
		||||
git.defalsify.org/vise.git v0.2.3-0.20250103172917-3e190a44568d/go.mod h1:jyBMe1qTYUz3mmuoC9JQ/TvFeW0vTanCUcPu3H8p4Ck=
 | 
			
		||||
github.com/alecthomas/assert/v2 v2.2.2 h1:Z/iVC0xZfWTaFNE6bA3z07T86hd45Xe2eLt6WVy2bbk=
 | 
			
		||||
github.com/alecthomas/assert/v2 v2.2.2/go.mod h1:pXcQ2Asjp247dahGEmsZ6ru0UVwnkhktn7S0bBDLxvQ=
 | 
			
		||||
github.com/alecthomas/participle/v2 v2.0.0 h1:Fgrq+MbuSsJwIkw3fEj9h75vDP0Er5JzepJ0/HNHv0g=
 | 
			
		||||
 | 
			
		||||
@ -23,12 +23,12 @@ import (
 | 
			
		||||
	"git.grassecon.net/urdt/ussd/remote"
 | 
			
		||||
	"gopkg.in/leonelquinteros/gotext.v1"
 | 
			
		||||
 | 
			
		||||
	"git.grassecon.net/urdt/ussd/internal/storage"
 | 
			
		||||
	dbstorage "git.grassecon.net/urdt/ussd/internal/storage/db"
 | 
			
		||||
	dataserviceapi "github.com/grassrootseconomics/ussd-data-service/pkg/api"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	logg           = logging.NewVanilla().WithDomain("ussdmenuhandler")
 | 
			
		||||
	logg           = logging.NewVanilla().WithDomain("ussdmenuhandler").WithContextKey("session-id")
 | 
			
		||||
	scriptDir      = path.Join("services", "registration")
 | 
			
		||||
	translationDir = path.Join(scriptDir, "locale")
 | 
			
		||||
)
 | 
			
		||||
@ -64,7 +64,7 @@ type Handlers struct {
 | 
			
		||||
	adminstore           *utils.AdminStore
 | 
			
		||||
	flagManager          *asm.FlagParser
 | 
			
		||||
	accountService       remote.AccountServiceInterface
 | 
			
		||||
	prefixDb             storage.PrefixDb
 | 
			
		||||
	prefixDb             dbstorage.PrefixDb
 | 
			
		||||
	profile              *models.Profile
 | 
			
		||||
	ReplaceSeparatorFunc func(string) string
 | 
			
		||||
}
 | 
			
		||||
@ -80,7 +80,7 @@ func NewHandlers(appFlags *asm.FlagParser, userdataStore db.Db, adminstore *util
 | 
			
		||||
 | 
			
		||||
	// Instantiate the SubPrefixDb with "DATATYPE_USERDATA" prefix
 | 
			
		||||
	prefix := common.ToBytes(db.DATATYPE_USERDATA)
 | 
			
		||||
	prefixDb := storage.NewSubPrefixDb(userdataStore, prefix)
 | 
			
		||||
	prefixDb := dbstorage.NewSubPrefixDb(userdataStore, prefix)
 | 
			
		||||
 | 
			
		||||
	h := &Handlers{
 | 
			
		||||
		userdataStore:        userDb,
 | 
			
		||||
@ -122,9 +122,12 @@ func (h *Handlers) Init(ctx context.Context, sym string, input []byte) (resource
 | 
			
		||||
		h.st.Code = []byte{}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	sessionId, _ := ctx.Value("SessionId").(string)
 | 
			
		||||
	flag_admin_privilege, _ := h.flagManager.GetFlag("flag_admin_privilege")
 | 
			
		||||
	sessionId, ok := ctx.Value("SessionId").(string)
 | 
			
		||||
	if ok {
 | 
			
		||||
		context.WithValue(ctx, "session-id", sessionId)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	flag_admin_privilege, _ := h.flagManager.GetFlag("flag_admin_privilege")
 | 
			
		||||
	isAdmin, _ := h.adminstore.IsAdmin(sessionId)
 | 
			
		||||
 | 
			
		||||
	if isAdmin {
 | 
			
		||||
 | 
			
		||||
@ -13,7 +13,7 @@ import (
 | 
			
		||||
	"git.defalsify.org/vise.git/persist"
 | 
			
		||||
	"git.defalsify.org/vise.git/resource"
 | 
			
		||||
	"git.defalsify.org/vise.git/state"
 | 
			
		||||
	"git.grassecon.net/urdt/ussd/internal/storage"
 | 
			
		||||
	dbstorage "git.grassecon.net/urdt/ussd/internal/storage/db"
 | 
			
		||||
	"git.grassecon.net/urdt/ussd/internal/testutil/mocks"
 | 
			
		||||
	"git.grassecon.net/urdt/ussd/internal/testutil/testservice"
 | 
			
		||||
	"git.grassecon.net/urdt/ussd/internal/utils"
 | 
			
		||||
@ -59,14 +59,14 @@ func InitializeTestStore(t *testing.T) (context.Context, *common.UserDataStore)
 | 
			
		||||
	return ctx, store
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func InitializeTestSubPrefixDb(t *testing.T, ctx context.Context) *storage.SubPrefixDb {
 | 
			
		||||
func InitializeTestSubPrefixDb(t *testing.T, ctx context.Context) *dbstorage.SubPrefixDb {
 | 
			
		||||
	db := memdb.NewMemDb()
 | 
			
		||||
	err := db.Connect(ctx, "")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	prefix := common.ToBytes(visedb.DATATYPE_USERDATA)
 | 
			
		||||
	spdb := storage.NewSubPrefixDb(db, prefix)
 | 
			
		||||
	spdb := dbstorage.NewSubPrefixDb(db, prefix)
 | 
			
		||||
 | 
			
		||||
	return spdb
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										121
									
								
								internal/http/at/parse.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										121
									
								
								internal/http/at/parse.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,121 @@
 | 
			
		||||
package at
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"git.grassecon.net/urdt/ussd/common"
 | 
			
		||||
	"git.grassecon.net/urdt/ussd/internal/handlers"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type ATRequestParser struct {
 | 
			
		||||
	Context context.Context
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (arp *ATRequestParser) GetSessionId(rq any) (string, error) {
 | 
			
		||||
	rqv, ok := rq.(*http.Request)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		logg.Warnf("got an invalid request", "req", rq)
 | 
			
		||||
		return "", handlers.ErrInvalidRequest
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Capture body (if any) for logging
 | 
			
		||||
	body, err := io.ReadAll(rqv.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logg.Warnf("failed to read request body", "err", err)
 | 
			
		||||
		return "", fmt.Errorf("failed to read request body: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	// Reset the body for further reading
 | 
			
		||||
	rqv.Body = io.NopCloser(bytes.NewReader(body))
 | 
			
		||||
 | 
			
		||||
	// Log the body as JSON
 | 
			
		||||
	bodyLog := map[string]string{"body": string(body)}
 | 
			
		||||
	logBytes, err := json.Marshal(bodyLog)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logg.Warnf("failed to marshal request body", "err", err)
 | 
			
		||||
	} else {
 | 
			
		||||
		decodedStr := string(logBytes)
 | 
			
		||||
		sessionId, err := extractATSessionId(decodedStr)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			context.WithValue(arp.Context, "at-session-id", sessionId)
 | 
			
		||||
		}
 | 
			
		||||
		logg.Debugf("Received request:", decodedStr)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := rqv.ParseForm(); err != nil {
 | 
			
		||||
		logg.Warnf("failed to parse form data", "err", err)
 | 
			
		||||
		return "", fmt.Errorf("failed to parse form data: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	phoneNumber := rqv.FormValue("phoneNumber")
 | 
			
		||||
	if phoneNumber == "" {
 | 
			
		||||
		return "", fmt.Errorf("no phone number found")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	formattedNumber, err := common.FormatPhoneNumber(phoneNumber)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logg.Warnf("failed to format phone number", "err", err)
 | 
			
		||||
		return "", fmt.Errorf("failed to format number")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return formattedNumber, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (arp *ATRequestParser) GetInput(rq any) ([]byte, error) {
 | 
			
		||||
	rqv, ok := rq.(*http.Request)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return nil, handlers.ErrInvalidRequest
 | 
			
		||||
	}
 | 
			
		||||
	if err := rqv.ParseForm(); err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("failed to parse form data: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	text := rqv.FormValue("text")
 | 
			
		||||
 | 
			
		||||
	parts := strings.Split(text, "*")
 | 
			
		||||
	if len(parts) == 0 {
 | 
			
		||||
		return nil, fmt.Errorf("no input found")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return []byte(parts[len(parts)-1]), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func parseQueryParams(query string) map[string]string {
 | 
			
		||||
	params := make(map[string]string)
 | 
			
		||||
 | 
			
		||||
	queryParams := strings.Split(query, "&")
 | 
			
		||||
	for _, param := range queryParams {
 | 
			
		||||
		// Split each key-value pair by '='
 | 
			
		||||
		parts := strings.SplitN(param, "=", 2)
 | 
			
		||||
		if len(parts) == 2 {
 | 
			
		||||
			params[parts[0]] = parts[1]
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return params
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func extractATSessionId(decodedStr string) (string, error) {
 | 
			
		||||
	var data map[string]string
 | 
			
		||||
	err := json.Unmarshal([]byte(decodedStr), &data)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logg.Errorf("Error unmarshalling JSON: %v", err)
 | 
			
		||||
		return "", nil
 | 
			
		||||
	}
 | 
			
		||||
	decodedBody, err := url.QueryUnescape(data["body"])
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logg.Errorf("Error URL-decoding body: %v", err)
 | 
			
		||||
		return "", nil
 | 
			
		||||
	}
 | 
			
		||||
	params := parseQueryParams(decodedBody)
 | 
			
		||||
 | 
			
		||||
	sessionId := params["sessionId"]
 | 
			
		||||
	return sessionId, nil
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
@ -1,19 +1,25 @@
 | 
			
		||||
package http
 | 
			
		||||
package at
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
 | 
			
		||||
	"git.defalsify.org/vise.git/logging"
 | 
			
		||||
	"git.grassecon.net/urdt/ussd/internal/handlers"
 | 
			
		||||
	httpserver "git.grassecon.net/urdt/ussd/internal/http"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	logg = logging.NewVanilla().WithDomain("atserver")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type ATSessionHandler struct {
 | 
			
		||||
	*SessionHandler
 | 
			
		||||
	*httpserver.SessionHandler
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewATSessionHandler(h handlers.RequestHandler) *ATSessionHandler {
 | 
			
		||||
	return &ATSessionHandler{
 | 
			
		||||
		SessionHandler: ToSessionHandler(h),
 | 
			
		||||
		SessionHandler: httpserver.ToSessionHandler(h),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -31,14 +37,14 @@ func (ash *ATSessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request)
 | 
			
		||||
	cfg.SessionId, err = rp.GetSessionId(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err)
 | 
			
		||||
		ash.writeError(w, 400, err)
 | 
			
		||||
		ash.WriteError(w, 400, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	rqs.Config = cfg
 | 
			
		||||
	rqs.Input, err = rp.GetInput(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err)
 | 
			
		||||
		ash.writeError(w, 400, err)
 | 
			
		||||
		ash.WriteError(w, 400, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@ -53,7 +59,7 @@ func (ash *ATSessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if code != 200 {
 | 
			
		||||
		ash.writeError(w, 500, err)
 | 
			
		||||
		ash.WriteError(w, 500, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@ -61,13 +67,13 @@ func (ash *ATSessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request)
 | 
			
		||||
	w.Header().Set("Content-Type", "text/plain")
 | 
			
		||||
	rqs, err = ash.Output(rqs)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		ash.writeError(w, 500, err)
 | 
			
		||||
		ash.WriteError(w, 500, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rqs, err = ash.Reset(rqs)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		ash.writeError(w, 500, err)
 | 
			
		||||
		ash.WriteError(w, 500, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -89,4 +95,4 @@ func (ash *ATSessionHandler) Output(rqs handlers.RequestSession) (handlers.Reque
 | 
			
		||||
 | 
			
		||||
	_, err = rqs.Engine.Flush(rqs.Ctx, rqs.Writer)
 | 
			
		||||
	return rqs, err
 | 
			
		||||
}
 | 
			
		||||
}
 | 
			
		||||
@ -1,7 +1,6 @@
 | 
			
		||||
package http
 | 
			
		||||
package at
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"context"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"io"
 | 
			
		||||
@ -16,16 +15,6 @@ import (
 | 
			
		||||
	"git.grassecon.net/urdt/ussd/internal/testutil/mocks/httpmocks"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// invalidRequestType is a custom type to test invalid request scenarios
 | 
			
		||||
type invalidRequestType struct{}
 | 
			
		||||
 | 
			
		||||
// errorReader is a helper type that always returns an error when Read is called
 | 
			
		||||
type errorReader struct{}
 | 
			
		||||
 | 
			
		||||
func (e *errorReader) Read(p []byte) (n int, err error) {
 | 
			
		||||
	return 0, errors.New("read error")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestNewATSessionHandler(t *testing.T) {
 | 
			
		||||
	mockHandler := &httpmocks.MockRequestHandler{}
 | 
			
		||||
	ash := NewATSessionHandler(mockHandler)
 | 
			
		||||
@ -242,208 +231,4 @@ func TestATSessionHandler_Output(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSessionHandler_ServeHTTP(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name           string
 | 
			
		||||
		sessionID      string
 | 
			
		||||
		input          []byte
 | 
			
		||||
		parserErr      error
 | 
			
		||||
		processErr     error
 | 
			
		||||
		outputErr      error
 | 
			
		||||
		resetErr       error
 | 
			
		||||
		expectedStatus int
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			name:           "Success",
 | 
			
		||||
			sessionID:      "123",
 | 
			
		||||
			input:          []byte("test input"),
 | 
			
		||||
			expectedStatus: http.StatusOK,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:           "Missing Session ID",
 | 
			
		||||
			sessionID:      "",
 | 
			
		||||
			parserErr:      handlers.ErrSessionMissing,
 | 
			
		||||
			expectedStatus: http.StatusBadRequest,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:           "Process Error",
 | 
			
		||||
			sessionID:      "123",
 | 
			
		||||
			input:          []byte("test input"),
 | 
			
		||||
			processErr:     handlers.ErrStorage,
 | 
			
		||||
			expectedStatus: http.StatusInternalServerError,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:           "Output Error",
 | 
			
		||||
			sessionID:      "123",
 | 
			
		||||
			input:          []byte("test input"),
 | 
			
		||||
			outputErr:      errors.New("output error"),
 | 
			
		||||
			expectedStatus: http.StatusOK,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:           "Reset Error",
 | 
			
		||||
			sessionID:      "123",
 | 
			
		||||
			input:          []byte("test input"),
 | 
			
		||||
			resetErr:       errors.New("reset error"),
 | 
			
		||||
			expectedStatus: http.StatusOK,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.name, func(t *testing.T) {
 | 
			
		||||
			mockRequestParser := &httpmocks.MockRequestParser{
 | 
			
		||||
				GetSessionIdFunc: func(any) (string, error) {
 | 
			
		||||
					return tt.sessionID, tt.parserErr
 | 
			
		||||
				},
 | 
			
		||||
				GetInputFunc: func(any) ([]byte, error) {
 | 
			
		||||
					return tt.input, nil
 | 
			
		||||
				},
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			mockRequestHandler := &httpmocks.MockRequestHandler{
 | 
			
		||||
				ProcessFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) {
 | 
			
		||||
					return rs, tt.processErr
 | 
			
		||||
				},
 | 
			
		||||
				OutputFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) {
 | 
			
		||||
					return rs, tt.outputErr
 | 
			
		||||
				},
 | 
			
		||||
				ResetFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) {
 | 
			
		||||
					return rs, tt.resetErr
 | 
			
		||||
				},
 | 
			
		||||
				GetRequestParserFunc: func() handlers.RequestParser {
 | 
			
		||||
					return mockRequestParser
 | 
			
		||||
				},
 | 
			
		||||
				GetConfigFunc: func() engine.Config {
 | 
			
		||||
					return engine.Config{}
 | 
			
		||||
				},
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			sessionHandler := ToSessionHandler(mockRequestHandler)
 | 
			
		||||
 | 
			
		||||
			req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBuffer(tt.input))
 | 
			
		||||
			req.Header.Set("X-Vise-Session", tt.sessionID)
 | 
			
		||||
 | 
			
		||||
			rr := httptest.NewRecorder()
 | 
			
		||||
 | 
			
		||||
			sessionHandler.ServeHTTP(rr, req)
 | 
			
		||||
 | 
			
		||||
			if status := rr.Code; status != tt.expectedStatus {
 | 
			
		||||
				t.Errorf("handler returned wrong status code: got %v want %v",
 | 
			
		||||
					status, tt.expectedStatus)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSessionHandler_writeError(t *testing.T) {
 | 
			
		||||
	handler := &SessionHandler{}
 | 
			
		||||
	mockWriter := &httpmocks.MockWriter{}
 | 
			
		||||
	err := errors.New("test error")
 | 
			
		||||
 | 
			
		||||
	handler.writeError(mockWriter, http.StatusBadRequest, err)
 | 
			
		||||
 | 
			
		||||
	if mockWriter.WrittenString != "" {
 | 
			
		||||
		t.Errorf("Expected empty body, got %s", mockWriter.WrittenString)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestDefaultRequestParser_GetSessionId(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name          string
 | 
			
		||||
		request       any
 | 
			
		||||
		expectedID    string
 | 
			
		||||
		expectedError error
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			name: "Valid Session ID",
 | 
			
		||||
			request: func() *http.Request {
 | 
			
		||||
				req := httptest.NewRequest(http.MethodPost, "/", nil)
 | 
			
		||||
				req.Header.Set("X-Vise-Session", "123456")
 | 
			
		||||
				return req
 | 
			
		||||
			}(),
 | 
			
		||||
			expectedID:    "123456",
 | 
			
		||||
			expectedError: nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:          "Missing Session ID",
 | 
			
		||||
			request:       httptest.NewRequest(http.MethodPost, "/", nil),
 | 
			
		||||
			expectedID:    "",
 | 
			
		||||
			expectedError: handlers.ErrSessionMissing,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:          "Invalid Request Type",
 | 
			
		||||
			request:       invalidRequestType{},
 | 
			
		||||
			expectedID:    "",
 | 
			
		||||
			expectedError: handlers.ErrInvalidRequest,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	parser := &DefaultRequestParser{}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.name, func(t *testing.T) {
 | 
			
		||||
			id, err := parser.GetSessionId(tt.request)
 | 
			
		||||
 | 
			
		||||
			if id != tt.expectedID {
 | 
			
		||||
				t.Errorf("Expected session ID %s, got %s", tt.expectedID, id)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if err != tt.expectedError {
 | 
			
		||||
				t.Errorf("Expected error %v, got %v", tt.expectedError, err)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestDefaultRequestParser_GetInput(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name          string
 | 
			
		||||
		request       any
 | 
			
		||||
		expectedInput []byte
 | 
			
		||||
		expectedError error
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			name: "Valid Input",
 | 
			
		||||
			request: func() *http.Request {
 | 
			
		||||
				return httptest.NewRequest(http.MethodPost, "/", bytes.NewBufferString("test input"))
 | 
			
		||||
			}(),
 | 
			
		||||
			expectedInput: []byte("test input"),
 | 
			
		||||
			expectedError: nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:          "Empty Input",
 | 
			
		||||
			request:       httptest.NewRequest(http.MethodPost, "/", nil),
 | 
			
		||||
			expectedInput: []byte{},
 | 
			
		||||
			expectedError: nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:          "Invalid Request Type",
 | 
			
		||||
			request:       invalidRequestType{},
 | 
			
		||||
			expectedInput: nil,
 | 
			
		||||
			expectedError: handlers.ErrInvalidRequest,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "Read Error",
 | 
			
		||||
			request: func() *http.Request {
 | 
			
		||||
				return httptest.NewRequest(http.MethodPost, "/", &errorReader{})
 | 
			
		||||
			}(),
 | 
			
		||||
			expectedInput: nil,
 | 
			
		||||
			expectedError: errors.New("read error"),
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	parser := &DefaultRequestParser{}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.name, func(t *testing.T) {
 | 
			
		||||
			input, err := parser.GetInput(tt.request)
 | 
			
		||||
 | 
			
		||||
			if !bytes.Equal(input, tt.expectedInput) {
 | 
			
		||||
				t.Errorf("Expected input %s, got %s", tt.expectedInput, input)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if err != tt.expectedError && (err == nil || err.Error() != tt.expectedError.Error()) {
 | 
			
		||||
				t.Errorf("Expected error %v, got %v", tt.expectedError, err)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										38
									
								
								internal/http/parse.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								internal/http/parse.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,38 @@
 | 
			
		||||
package http
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"io/ioutil"
 | 
			
		||||
	"net/http"
 | 
			
		||||
 | 
			
		||||
	"git.grassecon.net/urdt/ussd/internal/handlers"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type DefaultRequestParser struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (rp *DefaultRequestParser) GetSessionId(rq any) (string, error) {
 | 
			
		||||
	rqv, ok := rq.(*http.Request)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return "", handlers.ErrInvalidRequest
 | 
			
		||||
	}
 | 
			
		||||
	v := rqv.Header.Get("X-Vise-Session")
 | 
			
		||||
	if v == "" {
 | 
			
		||||
		return "", handlers.ErrSessionMissing
 | 
			
		||||
	}
 | 
			
		||||
	return v, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (rp *DefaultRequestParser) GetInput(rq any) ([]byte, error) {
 | 
			
		||||
	rqv, ok := rq.(*http.Request)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return nil, handlers.ErrInvalidRequest
 | 
			
		||||
	}
 | 
			
		||||
	defer rqv.Body.Close()
 | 
			
		||||
	v, err := ioutil.ReadAll(rqv.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return v, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1,7 +1,6 @@
 | 
			
		||||
package http
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"io/ioutil"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
 | 
			
		||||
@ -14,34 +13,6 @@ var (
 | 
			
		||||
	logg = logging.NewVanilla().WithDomain("httpserver")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type DefaultRequestParser struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (rp *DefaultRequestParser) GetSessionId(rq any) (string, error) {
 | 
			
		||||
	rqv, ok := rq.(*http.Request)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return "", handlers.ErrInvalidRequest
 | 
			
		||||
	}
 | 
			
		||||
	v := rqv.Header.Get("X-Vise-Session")
 | 
			
		||||
	if v == "" {
 | 
			
		||||
		return "", handlers.ErrSessionMissing
 | 
			
		||||
	}
 | 
			
		||||
	return v, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (rp *DefaultRequestParser) GetInput(rq any) ([]byte, error) {
 | 
			
		||||
	rqv, ok := rq.(*http.Request)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return nil, handlers.ErrInvalidRequest
 | 
			
		||||
	}
 | 
			
		||||
	defer rqv.Body.Close()
 | 
			
		||||
	v, err := ioutil.ReadAll(rqv.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return v, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type SessionHandler struct {
 | 
			
		||||
	handlers.RequestHandler
 | 
			
		||||
}
 | 
			
		||||
@ -52,7 +23,7 @@ func ToSessionHandler(h handlers.RequestHandler) *SessionHandler {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f *SessionHandler) writeError(w http.ResponseWriter, code int, err error) {
 | 
			
		||||
func (f *SessionHandler) WriteError(w http.ResponseWriter, code int, err error) {
 | 
			
		||||
	s := err.Error()
 | 
			
		||||
	w.Header().Set("Content-Length", strconv.Itoa(len(s)))
 | 
			
		||||
	w.WriteHeader(code)
 | 
			
		||||
@ -78,13 +49,13 @@ func (f *SessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
 | 
			
		||||
	cfg.SessionId, err = rp.GetSessionId(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err)
 | 
			
		||||
		f.writeError(w, 400, err)
 | 
			
		||||
		f.WriteError(w, 400, err)
 | 
			
		||||
	}
 | 
			
		||||
	rqs.Config = cfg
 | 
			
		||||
	rqs.Input, err = rp.GetInput(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err)
 | 
			
		||||
		f.writeError(w, 400, err)
 | 
			
		||||
		f.WriteError(w, 400, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@ -101,7 +72,7 @@ func (f *SessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if code != 200 {
 | 
			
		||||
		f.writeError(w, 500, err)
 | 
			
		||||
		f.WriteError(w, 500, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@ -110,11 +81,11 @@ func (f *SessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
 | 
			
		||||
	rqs, err = f.Output(rqs)
 | 
			
		||||
	rqs, perr = f.Reset(rqs)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		f.writeError(w, 500, err)
 | 
			
		||||
		f.WriteError(w, 500, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if perr != nil {
 | 
			
		||||
		f.writeError(w, 500, perr)
 | 
			
		||||
		f.WriteError(w, 500, perr)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										229
									
								
								internal/http/server_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										229
									
								
								internal/http/server_test.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,229 @@
 | 
			
		||||
package http
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/http/httptest"
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	"git.defalsify.org/vise.git/engine"
 | 
			
		||||
	"git.grassecon.net/urdt/ussd/internal/handlers"
 | 
			
		||||
	"git.grassecon.net/urdt/ussd/internal/testutil/mocks/httpmocks"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// invalidRequestType is a custom type to test invalid request scenarios
 | 
			
		||||
type invalidRequestType struct{}
 | 
			
		||||
 | 
			
		||||
// errorReader is a helper type that always returns an error when Read is called
 | 
			
		||||
type errorReader struct{}
 | 
			
		||||
 | 
			
		||||
func (e *errorReader) Read(p []byte) (n int, err error) {
 | 
			
		||||
	return 0, errors.New("read error")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSessionHandler_ServeHTTP(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name           string
 | 
			
		||||
		sessionID      string
 | 
			
		||||
		input          []byte
 | 
			
		||||
		parserErr      error
 | 
			
		||||
		processErr     error
 | 
			
		||||
		outputErr      error
 | 
			
		||||
		resetErr       error
 | 
			
		||||
		expectedStatus int
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			name:           "Success",
 | 
			
		||||
			sessionID:      "123",
 | 
			
		||||
			input:          []byte("test input"),
 | 
			
		||||
			expectedStatus: http.StatusOK,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:           "Missing Session ID",
 | 
			
		||||
			sessionID:      "",
 | 
			
		||||
			parserErr:      handlers.ErrSessionMissing,
 | 
			
		||||
			expectedStatus: http.StatusBadRequest,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:           "Process Error",
 | 
			
		||||
			sessionID:      "123",
 | 
			
		||||
			input:          []byte("test input"),
 | 
			
		||||
			processErr:     handlers.ErrStorage,
 | 
			
		||||
			expectedStatus: http.StatusInternalServerError,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:           "Output Error",
 | 
			
		||||
			sessionID:      "123",
 | 
			
		||||
			input:          []byte("test input"),
 | 
			
		||||
			outputErr:      errors.New("output error"),
 | 
			
		||||
			expectedStatus: http.StatusOK,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:           "Reset Error",
 | 
			
		||||
			sessionID:      "123",
 | 
			
		||||
			input:          []byte("test input"),
 | 
			
		||||
			resetErr:       errors.New("reset error"),
 | 
			
		||||
			expectedStatus: http.StatusOK,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.name, func(t *testing.T) {
 | 
			
		||||
			mockRequestParser := &httpmocks.MockRequestParser{
 | 
			
		||||
				GetSessionIdFunc: func(any) (string, error) {
 | 
			
		||||
					return tt.sessionID, tt.parserErr
 | 
			
		||||
				},
 | 
			
		||||
				GetInputFunc: func(any) ([]byte, error) {
 | 
			
		||||
					return tt.input, nil
 | 
			
		||||
				},
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			mockRequestHandler := &httpmocks.MockRequestHandler{
 | 
			
		||||
				ProcessFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) {
 | 
			
		||||
					return rs, tt.processErr
 | 
			
		||||
				},
 | 
			
		||||
				OutputFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) {
 | 
			
		||||
					return rs, tt.outputErr
 | 
			
		||||
				},
 | 
			
		||||
				ResetFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) {
 | 
			
		||||
					return rs, tt.resetErr
 | 
			
		||||
				},
 | 
			
		||||
				GetRequestParserFunc: func() handlers.RequestParser {
 | 
			
		||||
					return mockRequestParser
 | 
			
		||||
				},
 | 
			
		||||
				GetConfigFunc: func() engine.Config {
 | 
			
		||||
					return engine.Config{}
 | 
			
		||||
				},
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			sessionHandler := ToSessionHandler(mockRequestHandler)
 | 
			
		||||
 | 
			
		||||
			req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBuffer(tt.input))
 | 
			
		||||
			req.Header.Set("X-Vise-Session", tt.sessionID)
 | 
			
		||||
 | 
			
		||||
			rr := httptest.NewRecorder()
 | 
			
		||||
 | 
			
		||||
			sessionHandler.ServeHTTP(rr, req)
 | 
			
		||||
 | 
			
		||||
			if status := rr.Code; status != tt.expectedStatus {
 | 
			
		||||
				t.Errorf("handler returned wrong status code: got %v want %v",
 | 
			
		||||
					status, tt.expectedStatus)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSessionHandler_WriteError(t *testing.T) {
 | 
			
		||||
	handler := &SessionHandler{}
 | 
			
		||||
	mockWriter := &httpmocks.MockWriter{}
 | 
			
		||||
	err := errors.New("test error")
 | 
			
		||||
 | 
			
		||||
	handler.WriteError(mockWriter, http.StatusBadRequest, err)
 | 
			
		||||
 | 
			
		||||
	if mockWriter.WrittenString != "" {
 | 
			
		||||
		t.Errorf("Expected empty body, got %s", mockWriter.WrittenString)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestDefaultRequestParser_GetSessionId(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name          string
 | 
			
		||||
		request       any
 | 
			
		||||
		expectedID    string
 | 
			
		||||
		expectedError error
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			name: "Valid Session ID",
 | 
			
		||||
			request: func() *http.Request {
 | 
			
		||||
				req := httptest.NewRequest(http.MethodPost, "/", nil)
 | 
			
		||||
				req.Header.Set("X-Vise-Session", "123456")
 | 
			
		||||
				return req
 | 
			
		||||
			}(),
 | 
			
		||||
			expectedID:    "123456",
 | 
			
		||||
			expectedError: nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:          "Missing Session ID",
 | 
			
		||||
			request:       httptest.NewRequest(http.MethodPost, "/", nil),
 | 
			
		||||
			expectedID:    "",
 | 
			
		||||
			expectedError: handlers.ErrSessionMissing,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:          "Invalid Request Type",
 | 
			
		||||
			request:       invalidRequestType{},
 | 
			
		||||
			expectedID:    "",
 | 
			
		||||
			expectedError: handlers.ErrInvalidRequest,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	parser := &DefaultRequestParser{}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.name, func(t *testing.T) {
 | 
			
		||||
			id, err := parser.GetSessionId(tt.request)
 | 
			
		||||
 | 
			
		||||
			if id != tt.expectedID {
 | 
			
		||||
				t.Errorf("Expected session ID %s, got %s", tt.expectedID, id)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if err != tt.expectedError {
 | 
			
		||||
				t.Errorf("Expected error %v, got %v", tt.expectedError, err)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestDefaultRequestParser_GetInput(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name          string
 | 
			
		||||
		request       any
 | 
			
		||||
		expectedInput []byte
 | 
			
		||||
		expectedError error
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			name: "Valid Input",
 | 
			
		||||
			request: func() *http.Request {
 | 
			
		||||
				return httptest.NewRequest(http.MethodPost, "/", bytes.NewBufferString("test input"))
 | 
			
		||||
			}(),
 | 
			
		||||
			expectedInput: []byte("test input"),
 | 
			
		||||
			expectedError: nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:          "Empty Input",
 | 
			
		||||
			request:       httptest.NewRequest(http.MethodPost, "/", nil),
 | 
			
		||||
			expectedInput: []byte{},
 | 
			
		||||
			expectedError: nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:          "Invalid Request Type",
 | 
			
		||||
			request:       invalidRequestType{},
 | 
			
		||||
			expectedInput: nil,
 | 
			
		||||
			expectedError: handlers.ErrInvalidRequest,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "Read Error",
 | 
			
		||||
			request: func() *http.Request {
 | 
			
		||||
				return httptest.NewRequest(http.MethodPost, "/", &errorReader{})
 | 
			
		||||
			}(),
 | 
			
		||||
			expectedInput: nil,
 | 
			
		||||
			expectedError: errors.New("read error"),
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	parser := &DefaultRequestParser{}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.name, func(t *testing.T) {
 | 
			
		||||
			input, err := parser.GetInput(tt.request)
 | 
			
		||||
 | 
			
		||||
			if !bytes.Equal(input, tt.expectedInput) {
 | 
			
		||||
				t.Errorf("Expected input %s, got %s", tt.expectedInput, input)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if err != tt.expectedError && (err == nil || err.Error() != tt.expectedError.Error()) {
 | 
			
		||||
				t.Errorf("Expected error %v, got %v", tt.expectedError, err)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -6,6 +6,11 @@ import (
 | 
			
		||||
	"git.defalsify.org/vise.git/db"
 | 
			
		||||
	gdbmdb "git.defalsify.org/vise.git/db/gdbm"
 | 
			
		||||
	"git.defalsify.org/vise.git/lang"
 | 
			
		||||
	"git.defalsify.org/vise.git/logging"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	logg = logging.NewVanilla().WithDomain("gdbmstorage")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
@ -13,6 +13,7 @@ import (
 | 
			
		||||
	"git.defalsify.org/vise.git/persist"
 | 
			
		||||
	"git.defalsify.org/vise.git/resource"
 | 
			
		||||
	"git.grassecon.net/urdt/ussd/initializers"
 | 
			
		||||
	gdbmstorage "git.grassecon.net/urdt/ussd/internal/storage/db/gdbm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
@ -75,7 +76,7 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D
 | 
			
		||||
		connStr := buildConnStr()
 | 
			
		||||
		err = newDb.Connect(ctx, connStr)
 | 
			
		||||
	} else {
 | 
			
		||||
		newDb = NewThreadGdbmDb()
 | 
			
		||||
		newDb = gdbmstorage.NewThreadGdbmDb()
 | 
			
		||||
		storeFile := path.Join(ms.dbDir, fileName)
 | 
			
		||||
		err = newDb.Connect(ctx, storeFile)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user