Compare commits
	
		
			14 Commits
		
	
	
		
			51b6fc0dde
			...
			cc2f7b41df
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 
						 | 
					cc2f7b41df | ||
| 
						 | 
					2024cc96e2 | ||
| 
						 | 
					d2d878d5d7 | ||
| c995143543 | |||
| 44570e20ef | |||
| 362eb209ef | |||
| c69d3896f1 | |||
| 974af6b2a7 | |||
| 47b5ff0435 | |||
| 
						 | 
					bb1a846cb3 | ||
| 
						 | 
					967e53d83b | ||
| 
						 | 
					d246cdee51 | ||
| 
						 | 
					d518a76536 | ||
| 
						 | 
					6f65c33be4 | 
@ -130,9 +130,7 @@ func main() {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	defer stateStore.Close()
 | 
						defer stateStore.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	rp := &at.ATRequestParser{
 | 
						rp := &at.ATRequestParser{}
 | 
				
			||||||
		Context: ctx,
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	bsh := handlers.NewBaseSessionHandler(cfg, rs, stateStore, userdataStore, rp, hl)
 | 
						bsh := handlers.NewBaseSessionHandler(cfg, rs, stateStore, userdataStore, rp, hl)
 | 
				
			||||||
	sh := at.NewATSessionHandler(bsh)
 | 
						sh := at.NewATSessionHandler(bsh)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -21,8 +21,8 @@ import (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var (
 | 
					var (
 | 
				
			||||||
	logg      = logging.NewVanilla()
 | 
						logg          = logging.NewVanilla()
 | 
				
			||||||
	scriptDir = path.Join("services", "registration")
 | 
						scriptDir     = path.Join("services", "registration")
 | 
				
			||||||
	menuSeparator = ": "
 | 
						menuSeparator = ": "
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -35,7 +35,7 @@ type asyncRequestParser struct {
 | 
				
			|||||||
	input     []byte
 | 
						input     []byte
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (p *asyncRequestParser) GetSessionId(r any) (string, error) {
 | 
					func (p *asyncRequestParser) GetSessionId(ctx context.Context, r any) (string, error) {
 | 
				
			||||||
	return p.sessionId, nil
 | 
						return p.sessionId, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										34
									
								
								cmd/ssh/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										34
									
								
								cmd/ssh/README.md
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,34 @@
 | 
				
			|||||||
 | 
					# URDT-USSD SSH server
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					An SSH server entry point for the vise engine.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## Adding public keys for access
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Map your (client) public key to a session identifier (e.g. phone number)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					go run -v -tags logtrace ./cmd/ssh/sshkey/main.go -i <session_id> [--dbdir <dbpath>] <client_publickey_filepath>
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## Create a private key for the server
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					ssh-keygen -N "" -f <server_privatekey_filepath>
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## Run the server
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					go run -v -tags logtrace ./cmd/ssh/main.go -h <host> -p <port> [--dbdir <dbpath>] <server_privatekey_filepath>
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## Connect to the server
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					ssh [-v] -T -p <port> -i <client_publickey_filepath> <host>
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
							
								
								
									
										115
									
								
								cmd/ssh/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										115
									
								
								cmd/ssh/main.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,115 @@
 | 
				
			|||||||
 | 
					package main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
 | 
						"flag"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"path"
 | 
				
			||||||
 | 
						"os"
 | 
				
			||||||
 | 
						"os/signal"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
						"syscall"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"git.defalsify.org/vise.git/db"
 | 
				
			||||||
 | 
						"git.defalsify.org/vise.git/engine"
 | 
				
			||||||
 | 
						"git.defalsify.org/vise.git/logging"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"git.grassecon.net/urdt/ussd/internal/ssh"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var (
 | 
				
			||||||
 | 
						wg sync.WaitGroup
 | 
				
			||||||
 | 
						keyStore db.Db
 | 
				
			||||||
 | 
						logg      = logging.NewVanilla()
 | 
				
			||||||
 | 
						scriptDir = path.Join("services", "registration")
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func main() {
 | 
				
			||||||
 | 
						var dbDir string
 | 
				
			||||||
 | 
						var resourceDir string
 | 
				
			||||||
 | 
						var size uint
 | 
				
			||||||
 | 
						var engineDebug bool
 | 
				
			||||||
 | 
						var stateDebug bool
 | 
				
			||||||
 | 
						var host string
 | 
				
			||||||
 | 
						var port uint
 | 
				
			||||||
 | 
						flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
 | 
				
			||||||
 | 
						flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir")
 | 
				
			||||||
 | 
						flag.BoolVar(&engineDebug, "engine-debug", false, "use engine debug output")
 | 
				
			||||||
 | 
						flag.BoolVar(&stateDebug, "state-debug", false, "use engine debug output")
 | 
				
			||||||
 | 
						flag.UintVar(&size, "s", 160, "max size of output")
 | 
				
			||||||
 | 
						flag.StringVar(&host, "h", "127.0.0.1", "http host")
 | 
				
			||||||
 | 
						flag.UintVar(&port, "p", 7122, "http port")
 | 
				
			||||||
 | 
						flag.Parse()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						sshKeyFile := flag.Arg(0)
 | 
				
			||||||
 | 
						_, err := os.Stat(sshKeyFile)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							fmt.Fprintf(os.Stderr, "cannot open ssh server private key file: %v\n", err)
 | 
				
			||||||
 | 
							os.Exit(1)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						ctx := context.Background()
 | 
				
			||||||
 | 
						logg.WarnCtxf(ctx, "!!!!! WARNING WARNING WARNING")
 | 
				
			||||||
 | 
						logg.WarnCtxf(ctx, "!!!!! =======================")
 | 
				
			||||||
 | 
						logg.WarnCtxf(ctx, "!!!!! This is not a production ready server!")
 | 
				
			||||||
 | 
						logg.WarnCtxf(ctx, "!!!!! Do not expose to internet and only use with tunnel!")
 | 
				
			||||||
 | 
						logg.WarnCtxf(ctx, "!!!!! (See ssh -L <...>)")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						logg.Infof("start command", "dbdir", dbDir, "resourcedir", resourceDir, "outputsize", size, "keyfile", sshKeyFile, "host", host, "port", port)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pfp := path.Join(scriptDir, "pp.csv")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						cfg := engine.Config{
 | 
				
			||||||
 | 
							Root:       "root",
 | 
				
			||||||
 | 
							OutputSize: uint32(size),
 | 
				
			||||||
 | 
							FlagCount:  uint32(16),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if stateDebug {
 | 
				
			||||||
 | 
							cfg.StateDebug = true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if engineDebug {
 | 
				
			||||||
 | 
							cfg.EngineDebug = true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						authKeyStore, err := ssh.NewSshKeyStore(ctx, dbDir)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							fmt.Fprintf(os.Stderr, "keystore file open error: %v", err)
 | 
				
			||||||
 | 
							os.Exit(1)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer func () {
 | 
				
			||||||
 | 
							logg.TraceCtxf(ctx, "shutdown auth key store reached")
 | 
				
			||||||
 | 
							err = authKeyStore.Close()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								logg.ErrorCtxf(ctx, "keystore close error", "err", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						cint := make(chan os.Signal)
 | 
				
			||||||
 | 
						cterm := make(chan os.Signal)
 | 
				
			||||||
 | 
						signal.Notify(cint, os.Interrupt, syscall.SIGINT)
 | 
				
			||||||
 | 
						signal.Notify(cterm, os.Interrupt, syscall.SIGTERM)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						runner := &ssh.SshRunner{
 | 
				
			||||||
 | 
							Cfg: cfg,
 | 
				
			||||||
 | 
							Debug: engineDebug,
 | 
				
			||||||
 | 
							FlagFile: pfp,
 | 
				
			||||||
 | 
							DbDir: dbDir,
 | 
				
			||||||
 | 
							ResourceDir: resourceDir,
 | 
				
			||||||
 | 
							SrvKeyFile: sshKeyFile,
 | 
				
			||||||
 | 
							Host: host,
 | 
				
			||||||
 | 
							Port: port,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							select {
 | 
				
			||||||
 | 
							case _ = <-cint:
 | 
				
			||||||
 | 
							case _ = <-cterm:
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							logg.TraceCtxf(ctx, "shutdown runner reached")
 | 
				
			||||||
 | 
							err := runner.Stop()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								logg.ErrorCtxf(ctx, "runner stop error", "err", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
						runner.Run(ctx, authKeyStore)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										44
									
								
								cmd/ssh/sshkey/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								cmd/ssh/sshkey/main.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,44 @@
 | 
				
			|||||||
 | 
					package main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
 | 
						"flag"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"os"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"git.grassecon.net/urdt/ussd/internal/ssh"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func main() {
 | 
				
			||||||
 | 
						var dbDir string
 | 
				
			||||||
 | 
						var sessionId string
 | 
				
			||||||
 | 
						flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
 | 
				
			||||||
 | 
						flag.StringVar(&sessionId, "i", "", "session id")
 | 
				
			||||||
 | 
						flag.Parse()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if sessionId == "" {
 | 
				
			||||||
 | 
							fmt.Fprintf(os.Stderr, "empty session id\n")
 | 
				
			||||||
 | 
							os.Exit(1)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						ctx := context.Background()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						sshKeyFile := flag.Arg(0)
 | 
				
			||||||
 | 
						if sshKeyFile == "" {
 | 
				
			||||||
 | 
							fmt.Fprintf(os.Stderr, "missing key file argument\n")
 | 
				
			||||||
 | 
							os.Exit(1)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						store, err := ssh.NewSshKeyStore(ctx, dbDir)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							fmt.Fprintf(os.Stderr, "%v\n", err)
 | 
				
			||||||
 | 
							os.Exit(1)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer store.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = store.AddFromFile(ctx, sshKeyFile, sessionId)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							fmt.Fprintf(os.Stderr, "%v\n", err)
 | 
				
			||||||
 | 
							os.Exit(1)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -6,9 +6,9 @@ import (
 | 
				
			|||||||
	"io"
 | 
						"io"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"git.defalsify.org/vise.git/engine"
 | 
						"git.defalsify.org/vise.git/engine"
 | 
				
			||||||
	"git.defalsify.org/vise.git/resource"
 | 
					 | 
				
			||||||
	"git.defalsify.org/vise.git/persist"
 | 
					 | 
				
			||||||
	"git.defalsify.org/vise.git/logging"
 | 
						"git.defalsify.org/vise.git/logging"
 | 
				
			||||||
 | 
						"git.defalsify.org/vise.git/persist"
 | 
				
			||||||
 | 
						"git.defalsify.org/vise.git/resource"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"git.grassecon.net/urdt/ussd/internal/storage"
 | 
						"git.grassecon.net/urdt/ussd/internal/storage"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@ -20,33 +20,33 @@ var (
 | 
				
			|||||||
var (
 | 
					var (
 | 
				
			||||||
	ErrInvalidRequest = errors.New("invalid request for context")
 | 
						ErrInvalidRequest = errors.New("invalid request for context")
 | 
				
			||||||
	ErrSessionMissing = errors.New("missing session")
 | 
						ErrSessionMissing = errors.New("missing session")
 | 
				
			||||||
	ErrInvalidInput = errors.New("invalid input")
 | 
						ErrInvalidInput   = errors.New("invalid input")
 | 
				
			||||||
	ErrStorage = errors.New("storage retrieval fail")
 | 
						ErrStorage        = errors.New("storage retrieval fail")
 | 
				
			||||||
	ErrEngineType = errors.New("incompatible engine")
 | 
						ErrEngineType     = errors.New("incompatible engine")
 | 
				
			||||||
	ErrEngineInit = errors.New("engine init fail")
 | 
						ErrEngineInit     = errors.New("engine init fail")
 | 
				
			||||||
	ErrEngineExec = errors.New("engine exec fail")
 | 
						ErrEngineExec     = errors.New("engine exec fail")
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type RequestSession struct {
 | 
					type RequestSession struct {
 | 
				
			||||||
	Ctx context.Context
 | 
						Ctx      context.Context
 | 
				
			||||||
	Config engine.Config
 | 
						Config   engine.Config
 | 
				
			||||||
	Engine engine.Engine
 | 
						Engine   engine.Engine
 | 
				
			||||||
	Input []byte
 | 
						Input    []byte
 | 
				
			||||||
	Storage *storage.Storage
 | 
						Storage  *storage.Storage
 | 
				
			||||||
	Writer io.Writer
 | 
						Writer   io.Writer
 | 
				
			||||||
	Continue bool
 | 
						Continue bool
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// TODO: seems like can remove this.
 | 
					// TODO: seems like can remove this.
 | 
				
			||||||
type RequestParser interface {
 | 
					type RequestParser interface {
 | 
				
			||||||
	GetSessionId(rq any) (string, error)
 | 
						GetSessionId(context context.Context, rq any) (string, error)
 | 
				
			||||||
	GetInput(rq any) ([]byte, error)
 | 
						GetInput(rq any) ([]byte, error)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type RequestHandler interface {
 | 
					type RequestHandler interface {
 | 
				
			||||||
	GetConfig() engine.Config
 | 
						GetConfig() engine.Config
 | 
				
			||||||
	GetRequestParser() RequestParser
 | 
						GetRequestParser() RequestParser
 | 
				
			||||||
	GetEngine(cfg engine.Config, rs resource.Resource, pe *persist.Persister) engine.Engine 
 | 
						GetEngine(cfg engine.Config, rs resource.Resource, pe *persist.Persister) engine.Engine
 | 
				
			||||||
	Process(rs RequestSession) (RequestSession, error)
 | 
						Process(rs RequestSession) (RequestSession, error)
 | 
				
			||||||
	Output(rs RequestSession) (RequestSession, error)
 | 
						Output(rs RequestSession) (RequestSession, error)
 | 
				
			||||||
	Reset(rs RequestSession) (RequestSession, error)
 | 
						Reset(rs RequestSession) (RequestSession, error)
 | 
				
			||||||
 | 
				
			|||||||
@ -28,7 +28,7 @@ import (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var (
 | 
					var (
 | 
				
			||||||
	logg           = logging.NewVanilla().WithDomain("ussdmenuhandler").WithContextKey("session-id")
 | 
						logg           = logging.NewVanilla().WithDomain("ussdmenuhandler").WithContextKey("SessionId")
 | 
				
			||||||
	scriptDir      = path.Join("services", "registration")
 | 
						scriptDir      = path.Join("services", "registration")
 | 
				
			||||||
	translationDir = path.Join(scriptDir, "locale")
 | 
						translationDir = path.Join(scriptDir, "locale")
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@ -124,7 +124,7 @@ func (h *Handlers) Init(ctx context.Context, sym string, input []byte) (resource
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	sessionId, ok := ctx.Value("SessionId").(string)
 | 
						sessionId, ok := ctx.Value("SessionId").(string)
 | 
				
			||||||
	if ok {
 | 
						if ok {
 | 
				
			||||||
		context.WithValue(ctx, "session-id", sessionId)
 | 
							ctx = context.WithValue(ctx, "SessionId", sessionId)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	flag_admin_privilege, _ := h.flagManager.GetFlag("flag_admin_privilege")
 | 
						flag_admin_privilege, _ := h.flagManager.GetFlag("flag_admin_privilege")
 | 
				
			||||||
 | 
				
			|||||||
@ -15,16 +15,14 @@ import (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type ATRequestParser struct {
 | 
					type ATRequestParser struct {
 | 
				
			||||||
	Context context.Context
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (arp *ATRequestParser) GetSessionId(rq any) (string, error) {
 | 
					func (arp *ATRequestParser) GetSessionId(ctx context.Context, rq any) (string, error) {
 | 
				
			||||||
	rqv, ok := rq.(*http.Request)
 | 
						rqv, ok := rq.(*http.Request)
 | 
				
			||||||
	if !ok {
 | 
						if !ok {
 | 
				
			||||||
		logg.Warnf("got an invalid request", "req", rq)
 | 
							logg.Warnf("got an invalid request", "req", rq)
 | 
				
			||||||
		return "", handlers.ErrInvalidRequest
 | 
							return "", handlers.ErrInvalidRequest
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					 | 
				
			||||||
	// Capture body (if any) for logging
 | 
						// Capture body (if any) for logging
 | 
				
			||||||
	body, err := io.ReadAll(rqv.Body)
 | 
						body, err := io.ReadAll(rqv.Body)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
@ -43,9 +41,9 @@ func (arp *ATRequestParser) GetSessionId(rq any) (string, error) {
 | 
				
			|||||||
		decodedStr := string(logBytes)
 | 
							decodedStr := string(logBytes)
 | 
				
			||||||
		sessionId, err := extractATSessionId(decodedStr)
 | 
							sessionId, err := extractATSessionId(decodedStr)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			context.WithValue(arp.Context, "at-session-id", sessionId)
 | 
								ctx = context.WithValue(ctx, "AT-SessionId", sessionId)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		logg.Debugf("Received request:", decodedStr)
 | 
							logg.DebugCtxf(ctx, "Received request:", decodedStr)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err := rqv.ParseForm(); err != nil {
 | 
						if err := rqv.ParseForm(); err != nil {
 | 
				
			||||||
 | 
				
			|||||||
@ -10,7 +10,7 @@ import (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var (
 | 
					var (
 | 
				
			||||||
	logg = logging.NewVanilla().WithDomain("atserver")
 | 
						logg = logging.NewVanilla().WithDomain("atserver").WithContextKey("SessionId").WithContextKey("AT-SessionId")
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type ATSessionHandler struct {
 | 
					type ATSessionHandler struct {
 | 
				
			||||||
@ -34,7 +34,7 @@ func (ash *ATSessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request)
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	rp := ash.GetRequestParser()
 | 
						rp := ash.GetRequestParser()
 | 
				
			||||||
	cfg := ash.GetConfig()
 | 
						cfg := ash.GetConfig()
 | 
				
			||||||
	cfg.SessionId, err = rp.GetSessionId(req)
 | 
						cfg.SessionId, err = rp.GetSessionId(req.Context(), req)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err)
 | 
							logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err)
 | 
				
			||||||
		ash.WriteError(w, 400, err)
 | 
							ash.WriteError(w, 400, err)
 | 
				
			||||||
@ -48,7 +48,7 @@ func (ash *ATSessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request)
 | 
				
			|||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	rqs, err = ash.Process(rqs) 
 | 
						rqs, err = ash.Process(rqs)
 | 
				
			||||||
	switch err {
 | 
						switch err {
 | 
				
			||||||
	case nil: // set code to 200 if no err
 | 
						case nil: // set code to 200 if no err
 | 
				
			||||||
		code = 200
 | 
							code = 200
 | 
				
			||||||
 | 
				
			|||||||
@ -1,6 +1,7 @@
 | 
				
			|||||||
package http
 | 
					package http
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
	"io/ioutil"
 | 
						"io/ioutil"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -10,7 +11,7 @@ import (
 | 
				
			|||||||
type DefaultRequestParser struct {
 | 
					type DefaultRequestParser struct {
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (rp *DefaultRequestParser) GetSessionId(rq any) (string, error) {
 | 
					func (rp *DefaultRequestParser) GetSessionId(ctx context.Context, rq any) (string, error) {
 | 
				
			||||||
	rqv, ok := rq.(*http.Request)
 | 
						rqv, ok := rq.(*http.Request)
 | 
				
			||||||
	if !ok {
 | 
						if !ok {
 | 
				
			||||||
		return "", handlers.ErrInvalidRequest
 | 
							return "", handlers.ErrInvalidRequest
 | 
				
			||||||
@ -34,5 +35,3 @@ func (rp *DefaultRequestParser) GetInput(rq any) ([]byte, error) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	return v, nil
 | 
						return v, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
				
			|||||||
@ -46,7 +46,7 @@ func (f *SessionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	rp := f.GetRequestParser()
 | 
						rp := f.GetRequestParser()
 | 
				
			||||||
	cfg := f.GetConfig()
 | 
						cfg := f.GetConfig()
 | 
				
			||||||
	cfg.SessionId, err = rp.GetSessionId(req)
 | 
						cfg.SessionId, err = rp.GetSessionId(req.Context(), req)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err)
 | 
							logg.ErrorCtxf(rqs.Ctx, "", "header processing error", err)
 | 
				
			||||||
		f.WriteError(w, 400, err)
 | 
							f.WriteError(w, 400, err)
 | 
				
			||||||
 | 
				
			|||||||
@ -2,6 +2,7 @@ package http
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"bytes"
 | 
						"bytes"
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"net/http/httptest"
 | 
						"net/http/httptest"
 | 
				
			||||||
@ -161,7 +162,7 @@ func TestDefaultRequestParser_GetSessionId(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	for _, tt := range tests {
 | 
						for _, tt := range tests {
 | 
				
			||||||
		t.Run(tt.name, func(t *testing.T) {
 | 
							t.Run(tt.name, func(t *testing.T) {
 | 
				
			||||||
			id, err := parser.GetSessionId(tt.request)
 | 
								id, err := parser.GetSessionId(context.Background(),tt.request)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			if id != tt.expectedID {
 | 
								if id != tt.expectedID {
 | 
				
			||||||
				t.Errorf("Expected session ID %s, got %s", tt.expectedID, id)
 | 
									t.Errorf("Expected session ID %s, got %s", tt.expectedID, id)
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										65
									
								
								internal/ssh/keystore.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										65
									
								
								internal/ssh/keystore.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,65 @@
 | 
				
			|||||||
 | 
					package ssh
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"os"
 | 
				
			||||||
 | 
						"path"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"golang.org/x/crypto/ssh"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"git.defalsify.org/vise.git/db"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"git.grassecon.net/urdt/ussd/internal/storage"
 | 
				
			||||||
 | 
						dbstorage "git.grassecon.net/urdt/ussd/internal/storage/db/gdbm"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type SshKeyStore struct {
 | 
				
			||||||
 | 
						store db.Db
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewSshKeyStore(ctx context.Context, dbDir string) (*SshKeyStore, error) {
 | 
				
			||||||
 | 
						keyStore := &SshKeyStore{}
 | 
				
			||||||
 | 
						keyStoreFile := path.Join(dbDir, "ssh_authorized_keys.gdbm")
 | 
				
			||||||
 | 
						keyStore.store = dbstorage.NewThreadGdbmDb()
 | 
				
			||||||
 | 
						err := keyStore.store.Connect(ctx, keyStoreFile)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return keyStore, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func(s *SshKeyStore) AddFromFile(ctx context.Context, fp string, sessionId string) error {
 | 
				
			||||||
 | 
						_, err := os.Stat(fp)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return fmt.Errorf("cannot open ssh server public key file: %v\n", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						publicBytes, err := os.ReadFile(fp)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return fmt.Errorf("Failed to load public key: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						pubKey, _, _, _, err := ssh.ParseAuthorizedKey(publicBytes)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return fmt.Errorf("Failed to parse public key: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						k := append([]byte{0x01}, pubKey.Marshal()...)
 | 
				
			||||||
 | 
						s.store.SetPrefix(storage.DATATYPE_EXTEND)
 | 
				
			||||||
 | 
						logg.Infof("Added key", "sessionId", sessionId, "public key", string(publicBytes))
 | 
				
			||||||
 | 
						return s.store.Put(ctx, k, []byte(sessionId))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func(s *SshKeyStore) Get(ctx context.Context, pubKey ssh.PublicKey) (string, error) {
 | 
				
			||||||
 | 
						s.store.SetLanguage(nil)
 | 
				
			||||||
 | 
						s.store.SetPrefix(storage.DATATYPE_EXTEND)
 | 
				
			||||||
 | 
						k := append([]byte{0x01}, pubKey.Marshal()...)
 | 
				
			||||||
 | 
						v, err := s.store.Get(ctx, k)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return "", err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return string(v), nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func(s *SshKeyStore) Close() error {
 | 
				
			||||||
 | 
						return s.store.Close()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										287
									
								
								internal/ssh/ssh.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										287
									
								
								internal/ssh/ssh.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,287 @@
 | 
				
			|||||||
 | 
					package ssh
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
 | 
						"encoding/hex"
 | 
				
			||||||
 | 
						"encoding/base64"
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
 | 
						"os"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"golang.org/x/crypto/ssh"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"git.defalsify.org/vise.git/engine"
 | 
				
			||||||
 | 
						"git.defalsify.org/vise.git/logging"
 | 
				
			||||||
 | 
						"git.defalsify.org/vise.git/resource"
 | 
				
			||||||
 | 
						"git.defalsify.org/vise.git/state"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"git.grassecon.net/urdt/ussd/internal/handlers"
 | 
				
			||||||
 | 
						"git.grassecon.net/urdt/ussd/internal/storage"
 | 
				
			||||||
 | 
						"git.grassecon.net/urdt/ussd/remote"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var (
 | 
				
			||||||
 | 
						logg = logging.NewVanilla().WithDomain("ssh")
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type auther struct {
 | 
				
			||||||
 | 
						Ctx context.Context
 | 
				
			||||||
 | 
						keyStore *SshKeyStore
 | 
				
			||||||
 | 
						auth map[string]string
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewAuther(ctx context.Context, keyStore *SshKeyStore) *auther {
 | 
				
			||||||
 | 
						return &auther{
 | 
				
			||||||
 | 
							Ctx: ctx,
 | 
				
			||||||
 | 
							keyStore: keyStore,
 | 
				
			||||||
 | 
							auth: make(map[string]string),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func(a *auther) Check(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
 | 
				
			||||||
 | 
						va, err := a.keyStore.Get(a.Ctx, pubKey)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						ka := hex.EncodeToString(conn.SessionID())
 | 
				
			||||||
 | 
						a.auth[ka] = va 
 | 
				
			||||||
 | 
						fmt.Fprintf(os.Stderr, "connect: %s -> %s\n", ka, va)
 | 
				
			||||||
 | 
						return nil, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func(a *auther) FromConn(c *ssh.ServerConn) (string, error) {
 | 
				
			||||||
 | 
						if c == nil {
 | 
				
			||||||
 | 
							return "", errors.New("nil server conn")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if c.Conn == nil {
 | 
				
			||||||
 | 
							return "", errors.New("nil underlying conn")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return a.Get(c.Conn.SessionID())
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func(a *auther) Get(k []byte) (string, error) {
 | 
				
			||||||
 | 
						ka := hex.EncodeToString(k)
 | 
				
			||||||
 | 
						v, ok := a.auth[ka]
 | 
				
			||||||
 | 
						if !ok {
 | 
				
			||||||
 | 
							return "", errors.New("not found")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return v, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func(s *SshRunner) serve(ctx context.Context, sessionId string, ch ssh.NewChannel, en engine.Engine) error {
 | 
				
			||||||
 | 
						if ch == nil {
 | 
				
			||||||
 | 
							return errors.New("nil channel")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if ch.ChannelType() != "session" {
 | 
				
			||||||
 | 
							ch.Reject(ssh.UnknownChannelType, "that is not the channel you are looking for")
 | 
				
			||||||
 | 
							return errors.New("not a session")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						channel, requests, err := ch.Accept()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							panic(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer channel.Close()
 | 
				
			||||||
 | 
						s.wg.Add(1)
 | 
				
			||||||
 | 
						go func(reqIn <-chan *ssh.Request) {
 | 
				
			||||||
 | 
							defer s.wg.Done()
 | 
				
			||||||
 | 
							for req := range reqIn {
 | 
				
			||||||
 | 
								req.Reply(req.Type == "shell", nil)	
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							_ = requests
 | 
				
			||||||
 | 
						}(requests)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						cont, err := en.Exec(ctx, []byte{})
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return fmt.Errorf("initial engine exec err: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var input [state.INPUT_LIMIT]byte
 | 
				
			||||||
 | 
						for cont {
 | 
				
			||||||
 | 
							c, err := en.Flush(ctx, channel)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return fmt.Errorf("flush err: %v", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							_, err = channel.Write([]byte{0x0a})
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return fmt.Errorf("newline err: %v", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							c, err = channel.Read(input[:])
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return fmt.Errorf("read input fail: %v", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							logg.TraceCtxf(ctx, "input read", "c", c, "input", input[:c-1])
 | 
				
			||||||
 | 
							cont, err = en.Exec(ctx, input[:c-1])
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return fmt.Errorf("engine exec err: %v", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							logg.TraceCtxf(ctx, "exec cont", "cont", cont, "en", en)
 | 
				
			||||||
 | 
							_ = c
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						c, err := en.Flush(ctx, channel)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return fmt.Errorf("last flush err: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						_ = c
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type SshRunner struct {
 | 
				
			||||||
 | 
						Ctx context.Context
 | 
				
			||||||
 | 
						Cfg engine.Config
 | 
				
			||||||
 | 
						FlagFile string
 | 
				
			||||||
 | 
						DbDir string
 | 
				
			||||||
 | 
						ResourceDir string
 | 
				
			||||||
 | 
						Debug bool
 | 
				
			||||||
 | 
						SrvKeyFile string
 | 
				
			||||||
 | 
						Host string
 | 
				
			||||||
 | 
						Port uint
 | 
				
			||||||
 | 
						wg sync.WaitGroup
 | 
				
			||||||
 | 
						lst net.Listener
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func(s *SshRunner) Stop() error {
 | 
				
			||||||
 | 
						return s.lst.Close()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func(s *SshRunner) GetEngine(sessionId string) (engine.Engine, func(), error) {
 | 
				
			||||||
 | 
						ctx := s.Ctx
 | 
				
			||||||
 | 
						menuStorageService := storage.NewMenuStorageService(s.DbDir, s.ResourceDir)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err := menuStorageService.EnsureDbDir()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						rs, err := menuStorageService.GetResource(ctx)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pe, err := menuStorageService.GetPersister(ctx)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						userdatastore, err := menuStorageService.GetUserdataDb(ctx)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						dbResource, ok := rs.(*resource.DbResource)
 | 
				
			||||||
 | 
						if !ok {
 | 
				
			||||||
 | 
							return nil, nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						lhs, err := handlers.NewLocalHandlerService(ctx, s.FlagFile, true, dbResource, s.Cfg, rs)
 | 
				
			||||||
 | 
						lhs.SetDataStore(&userdatastore)
 | 
				
			||||||
 | 
						lhs.SetPersister(pe)
 | 
				
			||||||
 | 
						lhs.Cfg.SessionId = sessionId
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// TODO: clear up why pointer here and by-value other cmds
 | 
				
			||||||
 | 
						accountService := &remote.AccountService{}
 | 
				
			||||||
 | 
						hl, err := lhs.GetHandler(accountService)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						en := lhs.GetEngine()
 | 
				
			||||||
 | 
						en = en.WithFirst(hl.Init)
 | 
				
			||||||
 | 
						if s.Debug {
 | 
				
			||||||
 | 
							en = en.WithDebug(nil)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						// TODO: this is getting very hacky!
 | 
				
			||||||
 | 
						closer := func() {
 | 
				
			||||||
 | 
							err := menuStorageService.Close()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								logg.ErrorCtxf(ctx, "menu storage service cleanup fail", "err", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return en, closer, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// adapted example from crypto/ssh package, NewServerConn doc
 | 
				
			||||||
 | 
					func(s *SshRunner) Run(ctx context.Context, keyStore *SshKeyStore) {
 | 
				
			||||||
 | 
						running := true
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// TODO: waitgroup should probably not be global
 | 
				
			||||||
 | 
						defer s.wg.Wait()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						auth := NewAuther(ctx, keyStore)
 | 
				
			||||||
 | 
						cfg := ssh.ServerConfig{
 | 
				
			||||||
 | 
							PublicKeyCallback: auth.Check,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						privateBytes, err := os.ReadFile(s.SrvKeyFile)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							logg.ErrorCtxf(ctx, "Failed to load private key", "err", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						private, err := ssh.ParsePrivateKey(privateBytes)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							logg.ErrorCtxf(ctx, "Failed to parse private key", "err", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						srvPub := private.PublicKey()
 | 
				
			||||||
 | 
						srvPubStr := base64.StdEncoding.EncodeToString(srvPub.Marshal())
 | 
				
			||||||
 | 
						logg.InfoCtxf(ctx, "have server key", "type", srvPub.Type(), "public", srvPubStr)
 | 
				
			||||||
 | 
						cfg.AddHostKey(private)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						s.lst, err = net.Listen("tcp", fmt.Sprintf("%s:%d", s.Host, s.Port))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							panic(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for running {
 | 
				
			||||||
 | 
							conn, err := s.lst.Accept()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								logg.ErrorCtxf(ctx, "ssh accept error", "err", err)
 | 
				
			||||||
 | 
								running = false
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							go func(conn net.Conn) {
 | 
				
			||||||
 | 
								defer conn.Close()
 | 
				
			||||||
 | 
								for true {
 | 
				
			||||||
 | 
									srvConn, nC, rC, err := ssh.NewServerConn(conn, &cfg)
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										logg.InfoCtxf(ctx, "rejected client", "err", err)
 | 
				
			||||||
 | 
										return
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									logg.DebugCtxf(ctx, "ssh client connected", "conn", srvConn)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									s.wg.Add(1)
 | 
				
			||||||
 | 
									go func() {
 | 
				
			||||||
 | 
										ssh.DiscardRequests(rC)
 | 
				
			||||||
 | 
										s.wg.Done()
 | 
				
			||||||
 | 
									}()
 | 
				
			||||||
 | 
									
 | 
				
			||||||
 | 
									sessionId, err := auth.FromConn(srvConn)
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										logg.ErrorCtxf(ctx, "Cannot find authentication")
 | 
				
			||||||
 | 
										return
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									en, closer, err := s.GetEngine(sessionId)
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										logg.ErrorCtxf(ctx, "engine won't start", "err", err)
 | 
				
			||||||
 | 
										return
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									defer func() {
 | 
				
			||||||
 | 
										err := en.Finish()
 | 
				
			||||||
 | 
										if err != nil {
 | 
				
			||||||
 | 
											logg.ErrorCtxf(ctx, "engine won't stop", "err", err)
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
										closer()
 | 
				
			||||||
 | 
									}()
 | 
				
			||||||
 | 
									for ch := range nC {
 | 
				
			||||||
 | 
										err = s.serve(ctx, sessionId, ch, en)
 | 
				
			||||||
 | 
										logg.ErrorCtxf(ctx, "ssh server finish", "err", err)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}(conn)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -5,6 +5,10 @@ import (
 | 
				
			|||||||
	"git.defalsify.org/vise.git/persist"
 | 
						"git.defalsify.org/vise.git/persist"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						DATATYPE_EXTEND = 128
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type Storage struct {
 | 
					type Storage struct {
 | 
				
			||||||
	Persister *persist.Persister
 | 
						Persister *persist.Persister
 | 
				
			||||||
	UserdataDb db.Db	
 | 
						UserdataDb db.Db	
 | 
				
			||||||
 | 
				
			|||||||
@ -1,12 +1,14 @@
 | 
				
			|||||||
package httpmocks
 | 
					package httpmocks
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import "context"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// MockRequestParser implements the handlers.RequestParser interface for testing
 | 
					// MockRequestParser implements the handlers.RequestParser interface for testing
 | 
				
			||||||
type MockRequestParser struct {
 | 
					type MockRequestParser struct {
 | 
				
			||||||
	GetSessionIdFunc func(any) (string, error)
 | 
						GetSessionIdFunc func(any) (string, error)
 | 
				
			||||||
	GetInputFunc     func(any) ([]byte, error)
 | 
						GetInputFunc     func(any) ([]byte, error)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (m *MockRequestParser) GetSessionId(rq any) (string, error) {
 | 
					func (m *MockRequestParser) GetSessionId(ctx context.Context, rq any) (string, error) {
 | 
				
			||||||
	return m.GetSessionIdFunc(rq)
 | 
						return m.GetSessionIdFunc(rq)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user