Factor out db dump formatting #243
@ -1,35 +1,31 @@
 | 
				
			|||||||
package main
 | 
					package main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"bytes"
 | 
					 | 
				
			||||||
	"context"
 | 
						"context"
 | 
				
			||||||
	"encoding/json"
 | 
					 | 
				
			||||||
	"flag"
 | 
						"flag"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"io"
 | 
					 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"os"
 | 
						"os"
 | 
				
			||||||
	"os/signal"
 | 
						"os/signal"
 | 
				
			||||||
	"path"
 | 
						"path"
 | 
				
			||||||
	"strconv"
 | 
						"strconv"
 | 
				
			||||||
	"strings"
 | 
					 | 
				
			||||||
	"syscall"
 | 
						"syscall"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"git.defalsify.org/vise.git/engine"
 | 
						"git.defalsify.org/vise.git/engine"
 | 
				
			||||||
	"git.defalsify.org/vise.git/logging"
 | 
						"git.defalsify.org/vise.git/logging"
 | 
				
			||||||
	"git.defalsify.org/vise.git/resource"
 | 
						"git.defalsify.org/vise.git/resource"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"git.grassecon.net/urdt/ussd/common"
 | 
					 | 
				
			||||||
	"git.grassecon.net/urdt/ussd/config"
 | 
						"git.grassecon.net/urdt/ussd/config"
 | 
				
			||||||
	"git.grassecon.net/urdt/ussd/initializers"
 | 
						"git.grassecon.net/urdt/ussd/initializers"
 | 
				
			||||||
	"git.grassecon.net/urdt/ussd/internal/handlers"
 | 
						"git.grassecon.net/urdt/ussd/internal/handlers"
 | 
				
			||||||
	httpserver "git.grassecon.net/urdt/ussd/internal/http"
 | 
						"git.grassecon.net/urdt/ussd/internal/http/at"
 | 
				
			||||||
 | 
						httpserver "git.grassecon.net/urdt/ussd/internal/http/at"
 | 
				
			||||||
	"git.grassecon.net/urdt/ussd/internal/storage"
 | 
						"git.grassecon.net/urdt/ussd/internal/storage"
 | 
				
			||||||
	"git.grassecon.net/urdt/ussd/remote"
 | 
						"git.grassecon.net/urdt/ussd/remote"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var (
 | 
					var (
 | 
				
			||||||
	logg          = logging.NewVanilla()
 | 
						logg          = logging.NewVanilla().WithDomain("AfricasTalking").WithContextKey("at-session-id")
 | 
				
			||||||
	scriptDir     = path.Join("services", "registration")
 | 
						scriptDir     = path.Join("services", "registration")
 | 
				
			||||||
	build         = "dev"
 | 
						build         = "dev"
 | 
				
			||||||
	menuSeparator = ": "
 | 
						menuSeparator = ": "
 | 
				
			||||||
@ -38,72 +34,6 @@ var (
 | 
				
			|||||||
func init() {
 | 
					func init() {
 | 
				
			||||||
	initializers.LoadEnvVariables()
 | 
						initializers.LoadEnvVariables()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
type atRequestParser struct{}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (arp *atRequestParser) GetSessionId(rq any) (string, error) {
 | 
					 | 
				
			||||||
	rqv, ok := rq.(*http.Request)
 | 
					 | 
				
			||||||
	if !ok {
 | 
					 | 
				
			||||||
		logg.Warnf("got an invalid request", "req", rq)
 | 
					 | 
				
			||||||
		return "", handlers.ErrInvalidRequest
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Capture body (if any) for logging
 | 
					 | 
				
			||||||
	body, err := io.ReadAll(rqv.Body)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		logg.Warnf("failed to read request body", "err", err)
 | 
					 | 
				
			||||||
		return "", fmt.Errorf("failed to read request body: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	// Reset the body for further reading
 | 
					 | 
				
			||||||
	rqv.Body = io.NopCloser(bytes.NewReader(body))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Log the body as JSON
 | 
					 | 
				
			||||||
	bodyLog := map[string]string{"body": string(body)}
 | 
					 | 
				
			||||||
	logBytes, err := json.Marshal(bodyLog)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		logg.Warnf("failed to marshal request body", "err", err)
 | 
					 | 
				
			||||||
	} else {
 | 
					 | 
				
			||||||
		logg.Debugf("received request", "bytes", logBytes)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := rqv.ParseForm(); err != nil {
 | 
					 | 
				
			||||||
		logg.Warnf("failed to parse form data", "err", err)
 | 
					 | 
				
			||||||
		return "", fmt.Errorf("failed to parse form data: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	phoneNumber := rqv.FormValue("phoneNumber")
 | 
					 | 
				
			||||||
	if phoneNumber == "" {
 | 
					 | 
				
			||||||
		return "", fmt.Errorf("no phone number found")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	formattedNumber, err := common.FormatPhoneNumber(phoneNumber)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		logg.Warnf("failed to format phone number", "err", err)
 | 
					 | 
				
			||||||
		return "", fmt.Errorf("failed to format number")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return formattedNumber, nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (arp *atRequestParser) GetInput(rq any) ([]byte, error) {
 | 
					 | 
				
			||||||
	rqv, ok := rq.(*http.Request)
 | 
					 | 
				
			||||||
	if !ok {
 | 
					 | 
				
			||||||
		return nil, handlers.ErrInvalidRequest
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if err := rqv.ParseForm(); err != nil {
 | 
					 | 
				
			||||||
		return nil, fmt.Errorf("failed to parse form data: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	text := rqv.FormValue("text")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	parts := strings.Split(text, "*")
 | 
					 | 
				
			||||||
	if len(parts) == 0 {
 | 
					 | 
				
			||||||
		return nil, fmt.Errorf("no input found")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return []byte(parts[len(parts)-1]), nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func main() {
 | 
					func main() {
 | 
				
			||||||
	config.LoadConfig()
 | 
						config.LoadConfig()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -191,7 +121,9 @@ func main() {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	defer stateStore.Close()
 | 
						defer stateStore.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	rp := &atRequestParser{}
 | 
						rp := &at.ATRequestParser{
 | 
				
			||||||
 | 
							Context: ctx,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	bsh := handlers.NewBaseSessionHandler(cfg, rs, stateStore, userdataStore, rp, hl)
 | 
						bsh := handlers.NewBaseSessionHandler(cfg, rs, stateStore, userdataStore, rp, hl)
 | 
				
			||||||
	sh := httpserver.NewATSessionHandler(bsh)
 | 
						sh := httpserver.NewATSessionHandler(bsh)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -8,14 +8,15 @@ import (
 | 
				
			|||||||
	"git.defalsify.org/vise.git/resource"
 | 
						"git.defalsify.org/vise.git/resource"
 | 
				
			||||||
	"git.defalsify.org/vise.git/persist"
 | 
						"git.defalsify.org/vise.git/persist"
 | 
				
			||||||
	"git.grassecon.net/urdt/ussd/internal/storage"
 | 
						"git.grassecon.net/urdt/ussd/internal/storage"
 | 
				
			||||||
 | 
						dbstorage "git.grassecon.net/urdt/ussd/internal/storage/db"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func StoreToDb(store *UserDataStore) db.Db {
 | 
					func StoreToDb(store *UserDataStore) db.Db {
 | 
				
			||||||
	return store.Db
 | 
						return store.Db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func StoreToPrefixDb(store *UserDataStore, pfx []byte) storage.PrefixDb {
 | 
					func StoreToPrefixDb(store *UserDataStore, pfx []byte) dbstorage.PrefixDb {
 | 
				
			||||||
	return storage.NewSubPrefixDb(store.Db, pfx)	
 | 
						return dbstorage.NewSubPrefixDb(store.Db, pfx)	
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type StorageServices interface {
 | 
					type StorageServices interface {
 | 
				
			||||||
 | 
				
			|||||||
@ -6,7 +6,7 @@ import (
 | 
				
			|||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"git.grassecon.net/urdt/ussd/internal/storage"
 | 
						dbstorage "git.grassecon.net/urdt/ussd/internal/storage/db"
 | 
				
			||||||
	dataserviceapi "github.com/grassrootseconomics/ussd-data-service/pkg/api"
 | 
						dataserviceapi "github.com/grassrootseconomics/ussd-data-service/pkg/api"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -56,7 +56,7 @@ func ProcessTransfers(transfers []dataserviceapi.Last10TxResponse) TransferMetad
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// GetTransferData retrieves and matches transfer data
 | 
					// GetTransferData retrieves and matches transfer data
 | 
				
			||||||
// returns a formatted string of the full transaction/statement
 | 
					// returns a formatted string of the full transaction/statement
 | 
				
			||||||
func GetTransferData(ctx context.Context, db storage.PrefixDb, publicKey string, index int) (string, error) {
 | 
					func GetTransferData(ctx context.Context, db dbstorage.PrefixDb, publicKey string, index int) (string, error) {
 | 
				
			||||||
	keys := []DataTyp{DATA_TX_SENDERS, DATA_TX_RECIPIENTS, DATA_TX_VALUES, DATA_TX_ADDRESSES, DATA_TX_HASHES, DATA_TX_DATES, DATA_TX_SYMBOLS}
 | 
						keys := []DataTyp{DATA_TX_SENDERS, DATA_TX_RECIPIENTS, DATA_TX_VALUES, DATA_TX_ADDRESSES, DATA_TX_HASHES, DATA_TX_DATES, DATA_TX_SYMBOLS}
 | 
				
			||||||
	data := make(map[DataTyp]string)
 | 
						data := make(map[DataTyp]string)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -6,7 +6,7 @@ import (
 | 
				
			|||||||
	"math/big"
 | 
						"math/big"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"git.grassecon.net/urdt/ussd/internal/storage"
 | 
						dbstorage "git.grassecon.net/urdt/ussd/internal/storage/db"
 | 
				
			||||||
	dataserviceapi "github.com/grassrootseconomics/ussd-data-service/pkg/api"
 | 
						dataserviceapi "github.com/grassrootseconomics/ussd-data-service/pkg/api"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -63,7 +63,7 @@ func ScaleDownBalance(balance, decimals string) string {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// GetVoucherData retrieves and matches voucher data
 | 
					// GetVoucherData retrieves and matches voucher data
 | 
				
			||||||
func GetVoucherData(ctx context.Context, db storage.PrefixDb, input string) (*dataserviceapi.TokenHoldings, error) {
 | 
					func GetVoucherData(ctx context.Context, db dbstorage.PrefixDb, input string) (*dataserviceapi.TokenHoldings, error) {
 | 
				
			||||||
	keys := []DataTyp{DATA_VOUCHER_SYMBOLS, DATA_VOUCHER_BALANCES, DATA_VOUCHER_DECIMALS, DATA_VOUCHER_ADDRESSES}
 | 
						keys := []DataTyp{DATA_VOUCHER_SYMBOLS, DATA_VOUCHER_BALANCES, DATA_VOUCHER_DECIMALS, DATA_VOUCHER_ADDRESSES}
 | 
				
			||||||
	data := make(map[DataTyp]string)
 | 
						data := make(map[DataTyp]string)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -10,7 +10,7 @@ import (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	visedb "git.defalsify.org/vise.git/db"
 | 
						visedb "git.defalsify.org/vise.git/db"
 | 
				
			||||||
	memdb "git.defalsify.org/vise.git/db/mem"
 | 
						memdb "git.defalsify.org/vise.git/db/mem"
 | 
				
			||||||
	"git.grassecon.net/urdt/ussd/internal/storage"
 | 
						dbstorage "git.grassecon.net/urdt/ussd/internal/storage/db"
 | 
				
			||||||
	dataserviceapi "github.com/grassrootseconomics/ussd-data-service/pkg/api"
 | 
						dataserviceapi "github.com/grassrootseconomics/ussd-data-service/pkg/api"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -86,7 +86,7 @@ func TestGetVoucherData(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	prefix := ToBytes(visedb.DATATYPE_USERDATA)
 | 
						prefix := ToBytes(visedb.DATATYPE_USERDATA)
 | 
				
			||||||
	spdb := storage.NewSubPrefixDb(db, prefix)
 | 
						spdb := dbstorage.NewSubPrefixDb(db, prefix)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Test voucher data
 | 
						// Test voucher data
 | 
				
			||||||
	mockData := map[DataTyp][]byte{
 | 
						mockData := map[DataTyp][]byte{
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										4
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								go.mod
									
									
									
									
									
								
							@ -3,7 +3,7 @@ module git.grassecon.net/urdt/ussd
 | 
				
			|||||||
go 1.23.0
 | 
					go 1.23.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
require (
 | 
					require (
 | 
				
			||||||
	git.defalsify.org/vise.git v0.2.3-0.20241231085136-8582c7e157d9
 | 
						git.defalsify.org/vise.git v0.2.3-0.20250103172917-3e190a44568d
 | 
				
			||||||
	github.com/alecthomas/assert/v2 v2.2.2
 | 
						github.com/alecthomas/assert/v2 v2.2.2
 | 
				
			||||||
	github.com/gofrs/uuid v4.4.0+incompatible
 | 
						github.com/gofrs/uuid v4.4.0+incompatible
 | 
				
			||||||
	github.com/grassrootseconomics/eth-custodial v1.3.0-beta
 | 
						github.com/grassrootseconomics/eth-custodial v1.3.0-beta
 | 
				
			||||||
@ -11,6 +11,7 @@ require (
 | 
				
			|||||||
	github.com/joho/godotenv v1.5.1
 | 
						github.com/joho/godotenv v1.5.1
 | 
				
			||||||
	github.com/peteole/testdata-loader v0.3.0
 | 
						github.com/peteole/testdata-loader v0.3.0
 | 
				
			||||||
	github.com/stretchr/testify v1.9.0
 | 
						github.com/stretchr/testify v1.9.0
 | 
				
			||||||
 | 
						golang.org/x/crypto v0.27.0
 | 
				
			||||||
	gopkg.in/leonelquinteros/gotext.v1 v1.3.1
 | 
						gopkg.in/leonelquinteros/gotext.v1 v1.3.1
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -32,7 +33,6 @@ require (
 | 
				
			|||||||
	github.com/rogpeppe/go-internal v1.13.1 // indirect
 | 
						github.com/rogpeppe/go-internal v1.13.1 // indirect
 | 
				
			||||||
	github.com/stretchr/objx v0.5.2 // indirect
 | 
						github.com/stretchr/objx v0.5.2 // indirect
 | 
				
			||||||
	github.com/x448/float16 v0.8.4 // indirect
 | 
						github.com/x448/float16 v0.8.4 // indirect
 | 
				
			||||||
	golang.org/x/crypto v0.27.0 // indirect
 | 
					 | 
				
			||||||
	golang.org/x/sync v0.8.0 // indirect
 | 
						golang.org/x/sync v0.8.0 // indirect
 | 
				
			||||||
	golang.org/x/text v0.18.0 // indirect
 | 
						golang.org/x/text v0.18.0 // indirect
 | 
				
			||||||
	gopkg.in/yaml.v3 v3.0.1 // indirect
 | 
						gopkg.in/yaml.v3 v3.0.1 // indirect
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										4
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								go.sum
									
									
									
									
									
								
							@ -1,5 +1,5 @@
 | 
				
			|||||||
git.defalsify.org/vise.git v0.2.3-0.20241231085136-8582c7e157d9 h1:O3m+NgWDWtJm8OculT99c4bDMAO4xLe2c8hpCKpsd9g=
 | 
					git.defalsify.org/vise.git v0.2.3-0.20250103172917-3e190a44568d h1:bPAOVZOX4frSGhfOdcj7kc555f8dc9DmMd2YAyC2AMw=
 | 
				
			||||||
git.defalsify.org/vise.git v0.2.3-0.20241231085136-8582c7e157d9/go.mod h1:jyBMe1qTYUz3mmuoC9JQ/TvFeW0vTanCUcPu3H8p4Ck=
 | 
					git.defalsify.org/vise.git v0.2.3-0.20250103172917-3e190a44568d/go.mod h1:jyBMe1qTYUz3mmuoC9JQ/TvFeW0vTanCUcPu3H8p4Ck=
 | 
				
			||||||
github.com/alecthomas/assert/v2 v2.2.2 h1:Z/iVC0xZfWTaFNE6bA3z07T86hd45Xe2eLt6WVy2bbk=
 | 
					github.com/alecthomas/assert/v2 v2.2.2 h1:Z/iVC0xZfWTaFNE6bA3z07T86hd45Xe2eLt6WVy2bbk=
 | 
				
			||||||
github.com/alecthomas/assert/v2 v2.2.2/go.mod h1:pXcQ2Asjp247dahGEmsZ6ru0UVwnkhktn7S0bBDLxvQ=
 | 
					github.com/alecthomas/assert/v2 v2.2.2/go.mod h1:pXcQ2Asjp247dahGEmsZ6ru0UVwnkhktn7S0bBDLxvQ=
 | 
				
			||||||
github.com/alecthomas/participle/v2 v2.0.0 h1:Fgrq+MbuSsJwIkw3fEj9h75vDP0Er5JzepJ0/HNHv0g=
 | 
					github.com/alecthomas/participle/v2 v2.0.0 h1:Fgrq+MbuSsJwIkw3fEj9h75vDP0Er5JzepJ0/HNHv0g=
 | 
				
			||||||
 | 
				
			|||||||
@ -23,12 +23,12 @@ import (
 | 
				
			|||||||
	"git.grassecon.net/urdt/ussd/remote"
 | 
						"git.grassecon.net/urdt/ussd/remote"
 | 
				
			||||||
	"gopkg.in/leonelquinteros/gotext.v1"
 | 
						"gopkg.in/leonelquinteros/gotext.v1"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"git.grassecon.net/urdt/ussd/internal/storage"
 | 
						dbstorage "git.grassecon.net/urdt/ussd/internal/storage/db"
 | 
				
			||||||
	dataserviceapi "github.com/grassrootseconomics/ussd-data-service/pkg/api"
 | 
						dataserviceapi "github.com/grassrootseconomics/ussd-data-service/pkg/api"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var (
 | 
					var (
 | 
				
			||||||
	logg           = logging.NewVanilla().WithDomain("ussdmenuhandler")
 | 
						logg           = logging.NewVanilla().WithDomain("ussdmenuhandler").WithContextKey("session-id")
 | 
				
			||||||
	scriptDir      = path.Join("services", "registration")
 | 
						scriptDir      = path.Join("services", "registration")
 | 
				
			||||||
	translationDir = path.Join(scriptDir, "locale")
 | 
						translationDir = path.Join(scriptDir, "locale")
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@ -64,7 +64,7 @@ type Handlers struct {
 | 
				
			|||||||
	adminstore           *utils.AdminStore
 | 
						adminstore           *utils.AdminStore
 | 
				
			||||||
	flagManager          *asm.FlagParser
 | 
						flagManager          *asm.FlagParser
 | 
				
			||||||
	accountService       remote.AccountServiceInterface
 | 
						accountService       remote.AccountServiceInterface
 | 
				
			||||||
	prefixDb             storage.PrefixDb
 | 
						prefixDb             dbstorage.PrefixDb
 | 
				
			||||||
	profile              *models.Profile
 | 
						profile              *models.Profile
 | 
				
			||||||
	ReplaceSeparatorFunc func(string) string
 | 
						ReplaceSeparatorFunc func(string) string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -80,7 +80,7 @@ func NewHandlers(appFlags *asm.FlagParser, userdataStore db.Db, adminstore *util
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// Instantiate the SubPrefixDb with "DATATYPE_USERDATA" prefix
 | 
						// Instantiate the SubPrefixDb with "DATATYPE_USERDATA" prefix
 | 
				
			||||||
	prefix := common.ToBytes(db.DATATYPE_USERDATA)
 | 
						prefix := common.ToBytes(db.DATATYPE_USERDATA)
 | 
				
			||||||
	prefixDb := storage.NewSubPrefixDb(userdataStore, prefix)
 | 
						prefixDb := dbstorage.NewSubPrefixDb(userdataStore, prefix)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	h := &Handlers{
 | 
						h := &Handlers{
 | 
				
			||||||
		userdataStore:        userDb,
 | 
							userdataStore:        userDb,
 | 
				
			||||||
@ -122,9 +122,12 @@ func (h *Handlers) Init(ctx context.Context, sym string, input []byte) (resource
 | 
				
			|||||||
		h.st.Code = []byte{}
 | 
							h.st.Code = []byte{}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	sessionId, _ := ctx.Value("SessionId").(string)
 | 
						sessionId, ok := ctx.Value("SessionId").(string)
 | 
				
			||||||
	flag_admin_privilege, _ := h.flagManager.GetFlag("flag_admin_privilege")
 | 
						if ok {
 | 
				
			||||||
 | 
							context.WithValue(ctx, "session-id", sessionId)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						flag_admin_privilege, _ := h.flagManager.GetFlag("flag_admin_privilege")
 | 
				
			||||||
	isAdmin, _ := h.adminstore.IsAdmin(sessionId)
 | 
						isAdmin, _ := h.adminstore.IsAdmin(sessionId)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if isAdmin {
 | 
						if isAdmin {
 | 
				
			||||||
 | 
				
			|||||||
@ -13,7 +13,7 @@ import (
 | 
				
			|||||||
	"git.defalsify.org/vise.git/persist"
 | 
						"git.defalsify.org/vise.git/persist"
 | 
				
			||||||
	"git.defalsify.org/vise.git/resource"
 | 
						"git.defalsify.org/vise.git/resource"
 | 
				
			||||||
	"git.defalsify.org/vise.git/state"
 | 
						"git.defalsify.org/vise.git/state"
 | 
				
			||||||
	"git.grassecon.net/urdt/ussd/internal/storage"
 | 
						dbstorage "git.grassecon.net/urdt/ussd/internal/storage/db"
 | 
				
			||||||
	"git.grassecon.net/urdt/ussd/internal/testutil/mocks"
 | 
						"git.grassecon.net/urdt/ussd/internal/testutil/mocks"
 | 
				
			||||||
	"git.grassecon.net/urdt/ussd/internal/testutil/testservice"
 | 
						"git.grassecon.net/urdt/ussd/internal/testutil/testservice"
 | 
				
			||||||
	"git.grassecon.net/urdt/ussd/internal/utils"
 | 
						"git.grassecon.net/urdt/ussd/internal/utils"
 | 
				
			||||||
@ -59,14 +59,14 @@ func InitializeTestStore(t *testing.T) (context.Context, *common.UserDataStore)
 | 
				
			|||||||
	return ctx, store
 | 
						return ctx, store
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func InitializeTestSubPrefixDb(t *testing.T, ctx context.Context) *storage.SubPrefixDb {
 | 
					func InitializeTestSubPrefixDb(t *testing.T, ctx context.Context) *dbstorage.SubPrefixDb {
 | 
				
			||||||
	db := memdb.NewMemDb()
 | 
						db := memdb.NewMemDb()
 | 
				
			||||||
	err := db.Connect(ctx, "")
 | 
						err := db.Connect(ctx, "")
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	prefix := common.ToBytes(visedb.DATATYPE_USERDATA)
 | 
						prefix := common.ToBytes(visedb.DATATYPE_USERDATA)
 | 
				
			||||||
	spdb := storage.NewSubPrefixDb(db, prefix)
 | 
						spdb := dbstorage.NewSubPrefixDb(db, prefix)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return spdb
 | 
						return spdb
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										121
									
								
								internal/http/at/parse.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										121
									
								
								internal/http/at/parse.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,121 @@
 | 
				
			|||||||
 | 
					package at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
 | 
						"encoding/json"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"net/url"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"git.grassecon.net/urdt/ussd/common"
 | 
				
			||||||
 | 
						"git.grassecon.net/urdt/ussd/internal/handlers"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ATRequestParser struct {
 | 
				
			||||||
 | 
						Context context.Context
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (arp *ATRequestParser) GetSessionId(rq any) (string, error) {
 | 
				
			||||||
 | 
						rqv, ok := rq.(*http.Request)
 | 
				
			||||||
 | 
						if !ok {
 | 
				
			||||||
 | 
							logg.Warnf("got an invalid request", "req", rq)
 | 
				
			||||||
 | 
							return "", handlers.ErrInvalidRequest
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Capture body (if any) for logging
 | 
				
			||||||
 | 
						body, err := io.ReadAll(rqv.Body)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							logg.Warnf("failed to read request body", "err", err)
 | 
				
			||||||
 | 
							return "", fmt.Errorf("failed to read request body: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						// Reset the body for further reading
 | 
				
			||||||
 | 
						rqv.Body = io.NopCloser(bytes.NewReader(body))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Log the body as JSON
 | 
				
			||||||
 | 
						bodyLog := map[string]string{"body": string(body)}
 | 
				
			||||||
 | 
						logBytes, err := json.Marshal(bodyLog)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							logg.Warnf("failed to marshal request body", "err", err)
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							decodedStr := string(logBytes)
 | 
				
			||||||
 | 
							sessionId, err := extractATSessionId(decodedStr)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								context.WithValue(arp.Context, "at-session-id", sessionId)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							logg.Debugf("Received request:", decodedStr)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err := rqv.ParseForm(); err != nil {
 | 
				
			||||||
 | 
							logg.Warnf("failed to parse form data", "err", err)
 | 
				
			||||||
 | 
							return "", fmt.Errorf("failed to parse form data: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						phoneNumber := rqv.FormValue("phoneNumber")
 | 
				
			||||||
 | 
						if phoneNumber == "" {
 | 
				
			||||||
 | 
							return "", fmt.Errorf("no phone number found")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						formattedNumber, err := common.FormatPhoneNumber(phoneNumber)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							logg.Warnf("failed to format phone number", "err", err)
 | 
				
			||||||
 | 
							return "", fmt.Errorf("failed to format number")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return formattedNumber, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (arp *ATRequestParser) GetInput(rq any) ([]byte, error) {
 | 
				
			||||||
 | 
						rqv, ok := rq.(*http.Request)
 | 
				
			||||||
 | 
						if !ok {
 | 
				
			||||||
 | 
							return nil, handlers.ErrInvalidRequest
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if err := rqv.ParseForm(); err != nil {
 | 
				
			||||||
 | 
							return nil, fmt.Errorf("failed to parse form data: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						text := rqv.FormValue("text")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						parts := strings.Split(text, "*")
 | 
				
			||||||
 | 
						if len(parts) == 0 {
 | 
				
			||||||
 | 
							return nil, fmt.Errorf("no input found")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return []byte(parts[len(parts)-1]), nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func parseQueryParams(query string) map[string]string {
 | 
				
			||||||
 | 
						params := make(map[string]string)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						queryParams := strings.Split(query, "&")
 | 
				
			||||||
 | 
						for _, param := range queryParams {
 | 
				
			||||||
 | 
							// Split each key-value pair by '='
 | 
				
			||||||
 | 
							parts := strings.SplitN(param, "=", 2)
 | 
				
			||||||
 | 
							if len(parts) == 2 {
 | 
				
			||||||
 | 
								params[parts[0]] = parts[1]
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return params
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func extractATSessionId(decodedStr string) (string, error) {
 | 
				
			||||||
 | 
						var data map[string]string
 | 
				
			||||||
 | 
						err := json.Unmarshal([]byte(decodedStr), &data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							logg.Errorf("Error unmarshalling JSON: %v", err)
 | 
				
			||||||
 | 
							return "", nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						decodedBody, err := url.QueryUnescape(data["body"])
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							logg.Errorf("Error URL-decoding body: %v", err)
 | 
				
			||||||
 | 
							return "", nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						params := parseQueryParams(decodedBody)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						sessionId := params["sessionId"]
 | 
				
			||||||
 | 
						return sessionId, nil
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -1,19 +1,25 @@
 | 
				
			|||||||
package http
 | 
					package at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"io"
 | 
						"io"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"git.defalsify.org/vise.git/logging"
 | 
				
			||||||
	"git.grassecon.net/urdt/ussd/internal/handlers"
 | 
						"git.grassecon.net/urdt/ussd/internal/handlers"
 | 
				
			||||||
 | 
						httpserver "git.grassecon.net/urdt/ussd/internal/http"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var (
 | 
				
			||||||
 | 
						logg = logging.NewVanilla().WithDomain("atserver")
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type ATSessionHandler struct {
 | 
					type ATSessionHandler struct {
 | 
				
			||||||
	*SessionHandler
 | 
						*httpserver.SessionHandler
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewATSessionHandler(h handlers.RequestHandler) *ATSessionHandler {
 | 
					func NewATSessionHandler(h handlers.RequestHandler) *ATSessionHandler {
 | 
				
			||||||
	return &ATSessionHandler{
 | 
						return &ATSessionHandler{
 | 
				
			||||||
		SessionHandler: ToSessionHandler(h),
 | 
							SessionHandler: httpserver.ToSessionHandler(h),
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -31,14 +37,14 @@ func (ash *ATSessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request)
 | 
				
			|||||||
	cfg.SessionId, err = rp.GetSessionId(req)
 | 
						cfg.SessionId, err = rp.GetSessionId(req)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err)
 | 
							logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err)
 | 
				
			||||||
		ash.writeError(w, 400, err)
 | 
							ash.WriteError(w, 400, err)
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	rqs.Config = cfg
 | 
						rqs.Config = cfg
 | 
				
			||||||
	rqs.Input, err = rp.GetInput(req)
 | 
						rqs.Input, err = rp.GetInput(req)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err)
 | 
							logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err)
 | 
				
			||||||
		ash.writeError(w, 400, err)
 | 
							ash.WriteError(w, 400, err)
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -53,7 +59,7 @@ func (ash *ATSessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request)
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if code != 200 {
 | 
						if code != 200 {
 | 
				
			||||||
		ash.writeError(w, 500, err)
 | 
							ash.WriteError(w, 500, err)
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -61,13 +67,13 @@ func (ash *ATSessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request)
 | 
				
			|||||||
	w.Header().Set("Content-Type", "text/plain")
 | 
						w.Header().Set("Content-Type", "text/plain")
 | 
				
			||||||
	rqs, err = ash.Output(rqs)
 | 
						rqs, err = ash.Output(rqs)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		ash.writeError(w, 500, err)
 | 
							ash.WriteError(w, 500, err)
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	rqs, err = ash.Reset(rqs)
 | 
						rqs, err = ash.Reset(rqs)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		ash.writeError(w, 500, err)
 | 
							ash.WriteError(w, 500, err)
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -89,4 +95,4 @@ func (ash *ATSessionHandler) Output(rqs handlers.RequestSession) (handlers.Reque
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	_, err = rqs.Engine.Flush(rqs.Ctx, rqs.Writer)
 | 
						_, err = rqs.Engine.Flush(rqs.Ctx, rqs.Writer)
 | 
				
			||||||
	return rqs, err
 | 
						return rqs, err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -1,7 +1,6 @@
 | 
				
			|||||||
package http
 | 
					package at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"bytes"
 | 
					 | 
				
			||||||
	"context"
 | 
						"context"
 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
	"io"
 | 
						"io"
 | 
				
			||||||
@ -16,16 +15,6 @@ import (
 | 
				
			|||||||
	"git.grassecon.net/urdt/ussd/internal/testutil/mocks/httpmocks"
 | 
						"git.grassecon.net/urdt/ussd/internal/testutil/mocks/httpmocks"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// invalidRequestType is a custom type to test invalid request scenarios
 | 
					 | 
				
			||||||
type invalidRequestType struct{}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// errorReader is a helper type that always returns an error when Read is called
 | 
					 | 
				
			||||||
type errorReader struct{}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (e *errorReader) Read(p []byte) (n int, err error) {
 | 
					 | 
				
			||||||
	return 0, errors.New("read error")
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestNewATSessionHandler(t *testing.T) {
 | 
					func TestNewATSessionHandler(t *testing.T) {
 | 
				
			||||||
	mockHandler := &httpmocks.MockRequestHandler{}
 | 
						mockHandler := &httpmocks.MockRequestHandler{}
 | 
				
			||||||
	ash := NewATSessionHandler(mockHandler)
 | 
						ash := NewATSessionHandler(mockHandler)
 | 
				
			||||||
@ -242,208 +231,4 @@ func TestATSessionHandler_Output(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestSessionHandler_ServeHTTP(t *testing.T) {
 | 
					 | 
				
			||||||
	tests := []struct {
 | 
					 | 
				
			||||||
		name           string
 | 
					 | 
				
			||||||
		sessionID      string
 | 
					 | 
				
			||||||
		input          []byte
 | 
					 | 
				
			||||||
		parserErr      error
 | 
					 | 
				
			||||||
		processErr     error
 | 
					 | 
				
			||||||
		outputErr      error
 | 
					 | 
				
			||||||
		resetErr       error
 | 
					 | 
				
			||||||
		expectedStatus int
 | 
					 | 
				
			||||||
	}{
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			name:           "Success",
 | 
					 | 
				
			||||||
			sessionID:      "123",
 | 
					 | 
				
			||||||
			input:          []byte("test input"),
 | 
					 | 
				
			||||||
			expectedStatus: http.StatusOK,
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			name:           "Missing Session ID",
 | 
					 | 
				
			||||||
			sessionID:      "",
 | 
					 | 
				
			||||||
			parserErr:      handlers.ErrSessionMissing,
 | 
					 | 
				
			||||||
			expectedStatus: http.StatusBadRequest,
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			name:           "Process Error",
 | 
					 | 
				
			||||||
			sessionID:      "123",
 | 
					 | 
				
			||||||
			input:          []byte("test input"),
 | 
					 | 
				
			||||||
			processErr:     handlers.ErrStorage,
 | 
					 | 
				
			||||||
			expectedStatus: http.StatusInternalServerError,
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			name:           "Output Error",
 | 
					 | 
				
			||||||
			sessionID:      "123",
 | 
					 | 
				
			||||||
			input:          []byte("test input"),
 | 
					 | 
				
			||||||
			outputErr:      errors.New("output error"),
 | 
					 | 
				
			||||||
			expectedStatus: http.StatusOK,
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			name:           "Reset Error",
 | 
					 | 
				
			||||||
			sessionID:      "123",
 | 
					 | 
				
			||||||
			input:          []byte("test input"),
 | 
					 | 
				
			||||||
			resetErr:       errors.New("reset error"),
 | 
					 | 
				
			||||||
			expectedStatus: http.StatusOK,
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, tt := range tests {
 | 
					 | 
				
			||||||
		t.Run(tt.name, func(t *testing.T) {
 | 
					 | 
				
			||||||
			mockRequestParser := &httpmocks.MockRequestParser{
 | 
					 | 
				
			||||||
				GetSessionIdFunc: func(any) (string, error) {
 | 
					 | 
				
			||||||
					return tt.sessionID, tt.parserErr
 | 
					 | 
				
			||||||
				},
 | 
					 | 
				
			||||||
				GetInputFunc: func(any) ([]byte, error) {
 | 
					 | 
				
			||||||
					return tt.input, nil
 | 
					 | 
				
			||||||
				},
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			mockRequestHandler := &httpmocks.MockRequestHandler{
 | 
					 | 
				
			||||||
				ProcessFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) {
 | 
					 | 
				
			||||||
					return rs, tt.processErr
 | 
					 | 
				
			||||||
				},
 | 
					 | 
				
			||||||
				OutputFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) {
 | 
					 | 
				
			||||||
					return rs, tt.outputErr
 | 
					 | 
				
			||||||
				},
 | 
					 | 
				
			||||||
				ResetFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) {
 | 
					 | 
				
			||||||
					return rs, tt.resetErr
 | 
					 | 
				
			||||||
				},
 | 
					 | 
				
			||||||
				GetRequestParserFunc: func() handlers.RequestParser {
 | 
					 | 
				
			||||||
					return mockRequestParser
 | 
					 | 
				
			||||||
				},
 | 
					 | 
				
			||||||
				GetConfigFunc: func() engine.Config {
 | 
					 | 
				
			||||||
					return engine.Config{}
 | 
					 | 
				
			||||||
				},
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			sessionHandler := ToSessionHandler(mockRequestHandler)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBuffer(tt.input))
 | 
					 | 
				
			||||||
			req.Header.Set("X-Vise-Session", tt.sessionID)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			rr := httptest.NewRecorder()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			sessionHandler.ServeHTTP(rr, req)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			if status := rr.Code; status != tt.expectedStatus {
 | 
					 | 
				
			||||||
				t.Errorf("handler returned wrong status code: got %v want %v",
 | 
					 | 
				
			||||||
					status, tt.expectedStatus)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestSessionHandler_writeError(t *testing.T) {
 | 
					 | 
				
			||||||
	handler := &SessionHandler{}
 | 
					 | 
				
			||||||
	mockWriter := &httpmocks.MockWriter{}
 | 
					 | 
				
			||||||
	err := errors.New("test error")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	handler.writeError(mockWriter, http.StatusBadRequest, err)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if mockWriter.WrittenString != "" {
 | 
					 | 
				
			||||||
		t.Errorf("Expected empty body, got %s", mockWriter.WrittenString)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestDefaultRequestParser_GetSessionId(t *testing.T) {
 | 
					 | 
				
			||||||
	tests := []struct {
 | 
					 | 
				
			||||||
		name          string
 | 
					 | 
				
			||||||
		request       any
 | 
					 | 
				
			||||||
		expectedID    string
 | 
					 | 
				
			||||||
		expectedError error
 | 
					 | 
				
			||||||
	}{
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			name: "Valid Session ID",
 | 
					 | 
				
			||||||
			request: func() *http.Request {
 | 
					 | 
				
			||||||
				req := httptest.NewRequest(http.MethodPost, "/", nil)
 | 
					 | 
				
			||||||
				req.Header.Set("X-Vise-Session", "123456")
 | 
					 | 
				
			||||||
				return req
 | 
					 | 
				
			||||||
			}(),
 | 
					 | 
				
			||||||
			expectedID:    "123456",
 | 
					 | 
				
			||||||
			expectedError: nil,
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			name:          "Missing Session ID",
 | 
					 | 
				
			||||||
			request:       httptest.NewRequest(http.MethodPost, "/", nil),
 | 
					 | 
				
			||||||
			expectedID:    "",
 | 
					 | 
				
			||||||
			expectedError: handlers.ErrSessionMissing,
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			name:          "Invalid Request Type",
 | 
					 | 
				
			||||||
			request:       invalidRequestType{},
 | 
					 | 
				
			||||||
			expectedID:    "",
 | 
					 | 
				
			||||||
			expectedError: handlers.ErrInvalidRequest,
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	parser := &DefaultRequestParser{}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for _, tt := range tests {
 | 
					 | 
				
			||||||
		t.Run(tt.name, func(t *testing.T) {
 | 
					 | 
				
			||||||
			id, err := parser.GetSessionId(tt.request)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			if id != tt.expectedID {
 | 
					 | 
				
			||||||
				t.Errorf("Expected session ID %s, got %s", tt.expectedID, id)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			if err != tt.expectedError {
 | 
					 | 
				
			||||||
				t.Errorf("Expected error %v, got %v", tt.expectedError, err)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestDefaultRequestParser_GetInput(t *testing.T) {
 | 
					 | 
				
			||||||
	tests := []struct {
 | 
					 | 
				
			||||||
		name          string
 | 
					 | 
				
			||||||
		request       any
 | 
					 | 
				
			||||||
		expectedInput []byte
 | 
					 | 
				
			||||||
		expectedError error
 | 
					 | 
				
			||||||
	}{
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			name: "Valid Input",
 | 
					 | 
				
			||||||
			request: func() *http.Request {
 | 
					 | 
				
			||||||
				return httptest.NewRequest(http.MethodPost, "/", bytes.NewBufferString("test input"))
 | 
					 | 
				
			||||||
			}(),
 | 
					 | 
				
			||||||
			expectedInput: []byte("test input"),
 | 
					 | 
				
			||||||
			expectedError: nil,
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			name:          "Empty Input",
 | 
					 | 
				
			||||||
			request:       httptest.NewRequest(http.MethodPost, "/", nil),
 | 
					 | 
				
			||||||
			expectedInput: []byte{},
 | 
					 | 
				
			||||||
			expectedError: nil,
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			name:          "Invalid Request Type",
 | 
					 | 
				
			||||||
			request:       invalidRequestType{},
 | 
					 | 
				
			||||||
			expectedInput: nil,
 | 
					 | 
				
			||||||
			expectedError: handlers.ErrInvalidRequest,
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			name: "Read Error",
 | 
					 | 
				
			||||||
			request: func() *http.Request {
 | 
					 | 
				
			||||||
				return httptest.NewRequest(http.MethodPost, "/", &errorReader{})
 | 
					 | 
				
			||||||
			}(),
 | 
					 | 
				
			||||||
			expectedInput: nil,
 | 
					 | 
				
			||||||
			expectedError: errors.New("read error"),
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	parser := &DefaultRequestParser{}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for _, tt := range tests {
 | 
					 | 
				
			||||||
		t.Run(tt.name, func(t *testing.T) {
 | 
					 | 
				
			||||||
			input, err := parser.GetInput(tt.request)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			if !bytes.Equal(input, tt.expectedInput) {
 | 
					 | 
				
			||||||
				t.Errorf("Expected input %s, got %s", tt.expectedInput, input)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			if err != tt.expectedError && (err == nil || err.Error() != tt.expectedError.Error()) {
 | 
					 | 
				
			||||||
				t.Errorf("Expected error %v, got %v", tt.expectedError, err)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
							
								
								
									
										38
									
								
								internal/http/parse.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								internal/http/parse.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,38 @@
 | 
				
			|||||||
 | 
					package http
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"io/ioutil"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"git.grassecon.net/urdt/ussd/internal/handlers"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type DefaultRequestParser struct {
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (rp *DefaultRequestParser) GetSessionId(rq any) (string, error) {
 | 
				
			||||||
 | 
						rqv, ok := rq.(*http.Request)
 | 
				
			||||||
 | 
						if !ok {
 | 
				
			||||||
 | 
							return "", handlers.ErrInvalidRequest
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						v := rqv.Header.Get("X-Vise-Session")
 | 
				
			||||||
 | 
						if v == "" {
 | 
				
			||||||
 | 
							return "", handlers.ErrSessionMissing
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return v, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (rp *DefaultRequestParser) GetInput(rq any) ([]byte, error) {
 | 
				
			||||||
 | 
						rqv, ok := rq.(*http.Request)
 | 
				
			||||||
 | 
						if !ok {
 | 
				
			||||||
 | 
							return nil, handlers.ErrInvalidRequest
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer rqv.Body.Close()
 | 
				
			||||||
 | 
						v, err := ioutil.ReadAll(rqv.Body)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return v, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -1,7 +1,6 @@
 | 
				
			|||||||
package http
 | 
					package http
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"io/ioutil"
 | 
					 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"strconv"
 | 
						"strconv"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -14,34 +13,6 @@ var (
 | 
				
			|||||||
	logg = logging.NewVanilla().WithDomain("httpserver")
 | 
						logg = logging.NewVanilla().WithDomain("httpserver")
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type DefaultRequestParser struct {
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (rp *DefaultRequestParser) GetSessionId(rq any) (string, error) {
 | 
					 | 
				
			||||||
	rqv, ok := rq.(*http.Request)
 | 
					 | 
				
			||||||
	if !ok {
 | 
					 | 
				
			||||||
		return "", handlers.ErrInvalidRequest
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	v := rqv.Header.Get("X-Vise-Session")
 | 
					 | 
				
			||||||
	if v == "" {
 | 
					 | 
				
			||||||
		return "", handlers.ErrSessionMissing
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return v, nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (rp *DefaultRequestParser) GetInput(rq any) ([]byte, error) {
 | 
					 | 
				
			||||||
	rqv, ok := rq.(*http.Request)
 | 
					 | 
				
			||||||
	if !ok {
 | 
					 | 
				
			||||||
		return nil, handlers.ErrInvalidRequest
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	defer rqv.Body.Close()
 | 
					 | 
				
			||||||
	v, err := ioutil.ReadAll(rqv.Body)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return nil, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return v, nil
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type SessionHandler struct {
 | 
					type SessionHandler struct {
 | 
				
			||||||
	handlers.RequestHandler
 | 
						handlers.RequestHandler
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -52,7 +23,7 @@ func ToSessionHandler(h handlers.RequestHandler) *SessionHandler {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (f *SessionHandler) writeError(w http.ResponseWriter, code int, err error) {
 | 
					func (f *SessionHandler) WriteError(w http.ResponseWriter, code int, err error) {
 | 
				
			||||||
	s := err.Error()
 | 
						s := err.Error()
 | 
				
			||||||
	w.Header().Set("Content-Length", strconv.Itoa(len(s)))
 | 
						w.Header().Set("Content-Length", strconv.Itoa(len(s)))
 | 
				
			||||||
	w.WriteHeader(code)
 | 
						w.WriteHeader(code)
 | 
				
			||||||
@ -78,13 +49,13 @@ func (f *SessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
 | 
				
			|||||||
	cfg.SessionId, err = rp.GetSessionId(req)
 | 
						cfg.SessionId, err = rp.GetSessionId(req)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err)
 | 
							logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err)
 | 
				
			||||||
		f.writeError(w, 400, err)
 | 
							f.WriteError(w, 400, err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	rqs.Config = cfg
 | 
						rqs.Config = cfg
 | 
				
			||||||
	rqs.Input, err = rp.GetInput(req)
 | 
						rqs.Input, err = rp.GetInput(req)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err)
 | 
							logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err)
 | 
				
			||||||
		f.writeError(w, 400, err)
 | 
							f.WriteError(w, 400, err)
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -101,7 +72,7 @@ func (f *SessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if code != 200 {
 | 
						if code != 200 {
 | 
				
			||||||
		f.writeError(w, 500, err)
 | 
							f.WriteError(w, 500, err)
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -110,11 +81,11 @@ func (f *SessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
 | 
				
			|||||||
	rqs, err = f.Output(rqs)
 | 
						rqs, err = f.Output(rqs)
 | 
				
			||||||
	rqs, perr = f.Reset(rqs)
 | 
						rqs, perr = f.Reset(rqs)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		f.writeError(w, 500, err)
 | 
							f.WriteError(w, 500, err)
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if perr != nil {
 | 
						if perr != nil {
 | 
				
			||||||
		f.writeError(w, 500, perr)
 | 
							f.WriteError(w, 500, perr)
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										229
									
								
								internal/http/server_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										229
									
								
								internal/http/server_test.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,229 @@
 | 
				
			|||||||
 | 
					package http
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"net/http/httptest"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"git.defalsify.org/vise.git/engine"
 | 
				
			||||||
 | 
						"git.grassecon.net/urdt/ussd/internal/handlers"
 | 
				
			||||||
 | 
						"git.grassecon.net/urdt/ussd/internal/testutil/mocks/httpmocks"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// invalidRequestType is a custom type to test invalid request scenarios
 | 
				
			||||||
 | 
					type invalidRequestType struct{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// errorReader is a helper type that always returns an error when Read is called
 | 
				
			||||||
 | 
					type errorReader struct{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (e *errorReader) Read(p []byte) (n int, err error) {
 | 
				
			||||||
 | 
						return 0, errors.New("read error")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestSessionHandler_ServeHTTP(t *testing.T) {
 | 
				
			||||||
 | 
						tests := []struct {
 | 
				
			||||||
 | 
							name           string
 | 
				
			||||||
 | 
							sessionID      string
 | 
				
			||||||
 | 
							input          []byte
 | 
				
			||||||
 | 
							parserErr      error
 | 
				
			||||||
 | 
							processErr     error
 | 
				
			||||||
 | 
							outputErr      error
 | 
				
			||||||
 | 
							resetErr       error
 | 
				
			||||||
 | 
							expectedStatus int
 | 
				
			||||||
 | 
						}{
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:           "Success",
 | 
				
			||||||
 | 
								sessionID:      "123",
 | 
				
			||||||
 | 
								input:          []byte("test input"),
 | 
				
			||||||
 | 
								expectedStatus: http.StatusOK,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:           "Missing Session ID",
 | 
				
			||||||
 | 
								sessionID:      "",
 | 
				
			||||||
 | 
								parserErr:      handlers.ErrSessionMissing,
 | 
				
			||||||
 | 
								expectedStatus: http.StatusBadRequest,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:           "Process Error",
 | 
				
			||||||
 | 
								sessionID:      "123",
 | 
				
			||||||
 | 
								input:          []byte("test input"),
 | 
				
			||||||
 | 
								processErr:     handlers.ErrStorage,
 | 
				
			||||||
 | 
								expectedStatus: http.StatusInternalServerError,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:           "Output Error",
 | 
				
			||||||
 | 
								sessionID:      "123",
 | 
				
			||||||
 | 
								input:          []byte("test input"),
 | 
				
			||||||
 | 
								outputErr:      errors.New("output error"),
 | 
				
			||||||
 | 
								expectedStatus: http.StatusOK,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:           "Reset Error",
 | 
				
			||||||
 | 
								sessionID:      "123",
 | 
				
			||||||
 | 
								input:          []byte("test input"),
 | 
				
			||||||
 | 
								resetErr:       errors.New("reset error"),
 | 
				
			||||||
 | 
								expectedStatus: http.StatusOK,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, tt := range tests {
 | 
				
			||||||
 | 
							t.Run(tt.name, func(t *testing.T) {
 | 
				
			||||||
 | 
								mockRequestParser := &httpmocks.MockRequestParser{
 | 
				
			||||||
 | 
									GetSessionIdFunc: func(any) (string, error) {
 | 
				
			||||||
 | 
										return tt.sessionID, tt.parserErr
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
									GetInputFunc: func(any) ([]byte, error) {
 | 
				
			||||||
 | 
										return tt.input, nil
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								mockRequestHandler := &httpmocks.MockRequestHandler{
 | 
				
			||||||
 | 
									ProcessFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) {
 | 
				
			||||||
 | 
										return rs, tt.processErr
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
									OutputFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) {
 | 
				
			||||||
 | 
										return rs, tt.outputErr
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
									ResetFunc: func(rs handlers.RequestSession) (handlers.RequestSession, error) {
 | 
				
			||||||
 | 
										return rs, tt.resetErr
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
									GetRequestParserFunc: func() handlers.RequestParser {
 | 
				
			||||||
 | 
										return mockRequestParser
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
									GetConfigFunc: func() engine.Config {
 | 
				
			||||||
 | 
										return engine.Config{}
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								sessionHandler := ToSessionHandler(mockRequestHandler)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBuffer(tt.input))
 | 
				
			||||||
 | 
								req.Header.Set("X-Vise-Session", tt.sessionID)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								rr := httptest.NewRecorder()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								sessionHandler.ServeHTTP(rr, req)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if status := rr.Code; status != tt.expectedStatus {
 | 
				
			||||||
 | 
									t.Errorf("handler returned wrong status code: got %v want %v",
 | 
				
			||||||
 | 
										status, tt.expectedStatus)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestSessionHandler_WriteError(t *testing.T) {
 | 
				
			||||||
 | 
						handler := &SessionHandler{}
 | 
				
			||||||
 | 
						mockWriter := &httpmocks.MockWriter{}
 | 
				
			||||||
 | 
						err := errors.New("test error")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						handler.WriteError(mockWriter, http.StatusBadRequest, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if mockWriter.WrittenString != "" {
 | 
				
			||||||
 | 
							t.Errorf("Expected empty body, got %s", mockWriter.WrittenString)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestDefaultRequestParser_GetSessionId(t *testing.T) {
 | 
				
			||||||
 | 
						tests := []struct {
 | 
				
			||||||
 | 
							name          string
 | 
				
			||||||
 | 
							request       any
 | 
				
			||||||
 | 
							expectedID    string
 | 
				
			||||||
 | 
							expectedError error
 | 
				
			||||||
 | 
						}{
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "Valid Session ID",
 | 
				
			||||||
 | 
								request: func() *http.Request {
 | 
				
			||||||
 | 
									req := httptest.NewRequest(http.MethodPost, "/", nil)
 | 
				
			||||||
 | 
									req.Header.Set("X-Vise-Session", "123456")
 | 
				
			||||||
 | 
									return req
 | 
				
			||||||
 | 
								}(),
 | 
				
			||||||
 | 
								expectedID:    "123456",
 | 
				
			||||||
 | 
								expectedError: nil,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:          "Missing Session ID",
 | 
				
			||||||
 | 
								request:       httptest.NewRequest(http.MethodPost, "/", nil),
 | 
				
			||||||
 | 
								expectedID:    "",
 | 
				
			||||||
 | 
								expectedError: handlers.ErrSessionMissing,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:          "Invalid Request Type",
 | 
				
			||||||
 | 
								request:       invalidRequestType{},
 | 
				
			||||||
 | 
								expectedID:    "",
 | 
				
			||||||
 | 
								expectedError: handlers.ErrInvalidRequest,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						parser := &DefaultRequestParser{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, tt := range tests {
 | 
				
			||||||
 | 
							t.Run(tt.name, func(t *testing.T) {
 | 
				
			||||||
 | 
								id, err := parser.GetSessionId(tt.request)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if id != tt.expectedID {
 | 
				
			||||||
 | 
									t.Errorf("Expected session ID %s, got %s", tt.expectedID, id)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if err != tt.expectedError {
 | 
				
			||||||
 | 
									t.Errorf("Expected error %v, got %v", tt.expectedError, err)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestDefaultRequestParser_GetInput(t *testing.T) {
 | 
				
			||||||
 | 
						tests := []struct {
 | 
				
			||||||
 | 
							name          string
 | 
				
			||||||
 | 
							request       any
 | 
				
			||||||
 | 
							expectedInput []byte
 | 
				
			||||||
 | 
							expectedError error
 | 
				
			||||||
 | 
						}{
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "Valid Input",
 | 
				
			||||||
 | 
								request: func() *http.Request {
 | 
				
			||||||
 | 
									return httptest.NewRequest(http.MethodPost, "/", bytes.NewBufferString("test input"))
 | 
				
			||||||
 | 
								}(),
 | 
				
			||||||
 | 
								expectedInput: []byte("test input"),
 | 
				
			||||||
 | 
								expectedError: nil,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:          "Empty Input",
 | 
				
			||||||
 | 
								request:       httptest.NewRequest(http.MethodPost, "/", nil),
 | 
				
			||||||
 | 
								expectedInput: []byte{},
 | 
				
			||||||
 | 
								expectedError: nil,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:          "Invalid Request Type",
 | 
				
			||||||
 | 
								request:       invalidRequestType{},
 | 
				
			||||||
 | 
								expectedInput: nil,
 | 
				
			||||||
 | 
								expectedError: handlers.ErrInvalidRequest,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "Read Error",
 | 
				
			||||||
 | 
								request: func() *http.Request {
 | 
				
			||||||
 | 
									return httptest.NewRequest(http.MethodPost, "/", &errorReader{})
 | 
				
			||||||
 | 
								}(),
 | 
				
			||||||
 | 
								expectedInput: nil,
 | 
				
			||||||
 | 
								expectedError: errors.New("read error"),
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						parser := &DefaultRequestParser{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, tt := range tests {
 | 
				
			||||||
 | 
							t.Run(tt.name, func(t *testing.T) {
 | 
				
			||||||
 | 
								input, err := parser.GetInput(tt.request)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if !bytes.Equal(input, tt.expectedInput) {
 | 
				
			||||||
 | 
									t.Errorf("Expected input %s, got %s", tt.expectedInput, input)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if err != tt.expectedError && (err == nil || err.Error() != tt.expectedError.Error()) {
 | 
				
			||||||
 | 
									t.Errorf("Expected error %v, got %v", tt.expectedError, err)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -6,6 +6,11 @@ import (
 | 
				
			|||||||
	"git.defalsify.org/vise.git/db"
 | 
						"git.defalsify.org/vise.git/db"
 | 
				
			||||||
	gdbmdb "git.defalsify.org/vise.git/db/gdbm"
 | 
						gdbmdb "git.defalsify.org/vise.git/db/gdbm"
 | 
				
			||||||
	"git.defalsify.org/vise.git/lang"
 | 
						"git.defalsify.org/vise.git/lang"
 | 
				
			||||||
 | 
						"git.defalsify.org/vise.git/logging"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var (
 | 
				
			||||||
 | 
						logg = logging.NewVanilla().WithDomain("gdbmstorage")
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var (
 | 
					var (
 | 
				
			||||||
@ -13,6 +13,7 @@ import (
 | 
				
			|||||||
	"git.defalsify.org/vise.git/persist"
 | 
						"git.defalsify.org/vise.git/persist"
 | 
				
			||||||
	"git.defalsify.org/vise.git/resource"
 | 
						"git.defalsify.org/vise.git/resource"
 | 
				
			||||||
	"git.grassecon.net/urdt/ussd/initializers"
 | 
						"git.grassecon.net/urdt/ussd/initializers"
 | 
				
			||||||
 | 
						gdbmstorage "git.grassecon.net/urdt/ussd/internal/storage/db/gdbm"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var (
 | 
					var (
 | 
				
			||||||
@ -75,7 +76,7 @@ func (ms *MenuStorageService) getOrCreateDb(ctx context.Context, existingDb db.D
 | 
				
			|||||||
		connStr := buildConnStr()
 | 
							connStr := buildConnStr()
 | 
				
			||||||
		err = newDb.Connect(ctx, connStr)
 | 
							err = newDb.Connect(ctx, connStr)
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		newDb = NewThreadGdbmDb()
 | 
							newDb = gdbmstorage.NewThreadGdbmDb()
 | 
				
			||||||
		storeFile := path.Join(ms.dbDir, fileName)
 | 
							storeFile := path.Join(ms.dbDir, fileName)
 | 
				
			||||||
		err = newDb.Connect(ctx, storeFile)
 | 
							err = newDb.Connect(ctx, storeFile)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user