Compare commits
6 Commits
master
...
lash/ssh-4
Author | SHA1 | Date | |
---|---|---|---|
9758fd4941 | |||
b02f4bc97e | |||
|
967e53d83b | ||
|
d246cdee51 | ||
|
d518a76536 | ||
|
6f65c33be4 |
@ -29,10 +29,10 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
logg = logging.NewVanilla()
|
||||
scriptDir = path.Join("services", "registration")
|
||||
build = "dev"
|
||||
menuSeparator = ": "
|
||||
logg = logging.NewVanilla()
|
||||
scriptDir = path.Join("services", "registration")
|
||||
|
||||
build = "dev"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@ -130,10 +130,9 @@ func main() {
|
||||
pfp := path.Join(scriptDir, "pp.csv")
|
||||
|
||||
cfg := engine.Config{
|
||||
Root: "root",
|
||||
OutputSize: uint32(size),
|
||||
FlagCount: uint32(128),
|
||||
MenuSeparator: menuSeparator,
|
||||
Root: "root",
|
||||
OutputSize: uint32(size),
|
||||
FlagCount: uint32(128),
|
||||
}
|
||||
|
||||
if engineDebug {
|
||||
|
@ -23,7 +23,6 @@ import (
|
||||
var (
|
||||
logg = logging.NewVanilla()
|
||||
scriptDir = path.Join("services", "registration")
|
||||
menuSeparator = ": "
|
||||
)
|
||||
|
||||
func init() {
|
||||
@ -71,10 +70,9 @@ func main() {
|
||||
pfp := path.Join(scriptDir, "pp.csv")
|
||||
|
||||
cfg := engine.Config{
|
||||
Root: "root",
|
||||
OutputSize: uint32(size),
|
||||
FlagCount: uint32(128),
|
||||
MenuSeparator: menuSeparator,
|
||||
Root: "root",
|
||||
OutputSize: uint32(size),
|
||||
FlagCount: uint32(128),
|
||||
}
|
||||
|
||||
if engineDebug {
|
||||
|
@ -26,7 +26,6 @@ import (
|
||||
var (
|
||||
logg = logging.NewVanilla()
|
||||
scriptDir = path.Join("services", "registration")
|
||||
menuSeparator = ": "
|
||||
)
|
||||
|
||||
func init() {
|
||||
@ -59,10 +58,9 @@ func main() {
|
||||
pfp := path.Join(scriptDir, "pp.csv")
|
||||
|
||||
cfg := engine.Config{
|
||||
Root: "root",
|
||||
OutputSize: uint32(size),
|
||||
FlagCount: uint32(128),
|
||||
MenuSeparator: menuSeparator,
|
||||
Root: "root",
|
||||
OutputSize: uint32(size),
|
||||
FlagCount: uint32(128),
|
||||
}
|
||||
|
||||
if engineDebug {
|
||||
|
14
cmd/main.go
14
cmd/main.go
@ -18,9 +18,8 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
logg = logging.NewVanilla()
|
||||
scriptDir = path.Join("services", "registration")
|
||||
menuSeparator = ": "
|
||||
logg = logging.NewVanilla()
|
||||
scriptDir = path.Join("services", "registration")
|
||||
)
|
||||
|
||||
func init() {
|
||||
@ -50,11 +49,10 @@ func main() {
|
||||
pfp := path.Join(scriptDir, "pp.csv")
|
||||
|
||||
cfg := engine.Config{
|
||||
Root: "root",
|
||||
SessionId: sessionId,
|
||||
OutputSize: uint32(size),
|
||||
FlagCount: uint32(128),
|
||||
MenuSeparator: menuSeparator,
|
||||
Root: "root",
|
||||
SessionId: sessionId,
|
||||
OutputSize: uint32(size),
|
||||
FlagCount: uint32(128),
|
||||
}
|
||||
|
||||
resourceDir := scriptDir
|
||||
|
34
cmd/ssh/README.md
Normal file
34
cmd/ssh/README.md
Normal file
@ -0,0 +1,34 @@
|
||||
# URDT-USSD SSH server
|
||||
|
||||
An SSH server entry point for the vise engine.
|
||||
|
||||
|
||||
## Adding public keys for access
|
||||
|
||||
Map your (client) public key to a session identifier (e.g. phone number)
|
||||
|
||||
```
|
||||
go run -v -tags logtrace ./cmd/ssh/sshkey/main.go -i <session_id> [--dbdir <dbpath>] <client_publickey_filepath>
|
||||
```
|
||||
|
||||
|
||||
## Create a private key for the server
|
||||
|
||||
```
|
||||
ssh-keygen -N "" -f <server_privatekey_filepath>
|
||||
```
|
||||
|
||||
|
||||
## Run the server
|
||||
|
||||
|
||||
```
|
||||
go run -v -tags logtrace ./cmd/ssh/main.go -h <host> -p <port> [--dbdir <dbpath>] <server_privatekey_filepath>
|
||||
```
|
||||
|
||||
|
||||
## Connect to the server
|
||||
|
||||
```
|
||||
ssh [-v] -T -p <port> -i <client_publickey_filepath> <host>
|
||||
```
|
115
cmd/ssh/main.go
Normal file
115
cmd/ssh/main.go
Normal file
@ -0,0 +1,115 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"path"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"git.defalsify.org/vise.git/db"
|
||||
"git.defalsify.org/vise.git/engine"
|
||||
"git.defalsify.org/vise.git/logging"
|
||||
|
||||
"git.grassecon.net/urdt/ussd/internal/ssh"
|
||||
)
|
||||
|
||||
var (
|
||||
wg sync.WaitGroup
|
||||
keyStore db.Db
|
||||
logg = logging.NewVanilla()
|
||||
scriptDir = path.Join("services", "registration")
|
||||
)
|
||||
|
||||
func main() {
|
||||
var dbDir string
|
||||
var resourceDir string
|
||||
var size uint
|
||||
var engineDebug bool
|
||||
var stateDebug bool
|
||||
var host string
|
||||
var port uint
|
||||
flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
|
||||
flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir")
|
||||
flag.BoolVar(&engineDebug, "engine-debug", false, "use engine debug output")
|
||||
flag.BoolVar(&stateDebug, "state-debug", false, "use engine debug output")
|
||||
flag.UintVar(&size, "s", 160, "max size of output")
|
||||
flag.StringVar(&host, "h", "127.0.0.1", "http host")
|
||||
flag.UintVar(&port, "p", 7122, "http port")
|
||||
flag.Parse()
|
||||
|
||||
sshKeyFile := flag.Arg(0)
|
||||
_, err := os.Stat(sshKeyFile)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "cannot open ssh server private key file: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
logg.WarnCtxf(ctx, "!!!!! WARNING WARNING WARNING")
|
||||
logg.WarnCtxf(ctx, "!!!!! =======================")
|
||||
logg.WarnCtxf(ctx, "!!!!! This is not a production ready server!")
|
||||
logg.WarnCtxf(ctx, "!!!!! Do not expose to internet and only use with tunnel!")
|
||||
logg.WarnCtxf(ctx, "!!!!! (See ssh -L <...>)")
|
||||
|
||||
logg.Infof("start command", "dbdir", dbDir, "resourcedir", resourceDir, "outputsize", size, "keyfile", sshKeyFile, "host", host, "port", port)
|
||||
|
||||
pfp := path.Join(scriptDir, "pp.csv")
|
||||
|
||||
cfg := engine.Config{
|
||||
Root: "root",
|
||||
OutputSize: uint32(size),
|
||||
FlagCount: uint32(16),
|
||||
}
|
||||
if stateDebug {
|
||||
cfg.StateDebug = true
|
||||
}
|
||||
if engineDebug {
|
||||
cfg.EngineDebug = true
|
||||
}
|
||||
|
||||
authKeyStore, err := ssh.NewSshKeyStore(ctx, dbDir)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "keystore file open error: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer func () {
|
||||
logg.TraceCtxf(ctx, "shutdown auth key store reached")
|
||||
err = authKeyStore.Close()
|
||||
if err != nil {
|
||||
logg.ErrorCtxf(ctx, "keystore close error", "err", err)
|
||||
}
|
||||
}()
|
||||
|
||||
cint := make(chan os.Signal)
|
||||
cterm := make(chan os.Signal)
|
||||
signal.Notify(cint, os.Interrupt, syscall.SIGINT)
|
||||
signal.Notify(cterm, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
runner := &ssh.SshRunner{
|
||||
Cfg: cfg,
|
||||
Debug: engineDebug,
|
||||
FlagFile: pfp,
|
||||
DbDir: dbDir,
|
||||
ResourceDir: resourceDir,
|
||||
SrvKeyFile: sshKeyFile,
|
||||
Host: host,
|
||||
Port: port,
|
||||
}
|
||||
go func() {
|
||||
select {
|
||||
case _ = <-cint:
|
||||
case _ = <-cterm:
|
||||
}
|
||||
logg.TraceCtxf(ctx, "shutdown runner reached")
|
||||
err := runner.Stop()
|
||||
if err != nil {
|
||||
logg.ErrorCtxf(ctx, "runner stop error", "err", err)
|
||||
}
|
||||
|
||||
}()
|
||||
runner.Run(ctx, authKeyStore)
|
||||
}
|
44
cmd/ssh/sshkey/main.go
Normal file
44
cmd/ssh/sshkey/main.go
Normal file
@ -0,0 +1,44 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"git.grassecon.net/urdt/ussd/internal/ssh"
|
||||
)
|
||||
|
||||
func main() {
|
||||
var dbDir string
|
||||
var sessionId string
|
||||
flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
|
||||
flag.StringVar(&sessionId, "i", "", "session id")
|
||||
flag.Parse()
|
||||
|
||||
if sessionId == "" {
|
||||
fmt.Fprintf(os.Stderr, "empty session id\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
sshKeyFile := flag.Arg(0)
|
||||
if sshKeyFile == "" {
|
||||
fmt.Fprintf(os.Stderr, "missing key file argument\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
store, err := ssh.NewSshKeyStore(ctx, dbDir)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
err = store.AddFromFile(ctx, sshKeyFile, sessionId)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
@ -84,18 +84,18 @@ func GetTransferData(ctx context.Context, db storage.PrefixDb, publicKey string,
|
||||
|
||||
// Adjust for 0-based indexing
|
||||
i := index - 1
|
||||
transactionType := "Received"
|
||||
party := fmt.Sprintf("From: %s", strings.TrimSpace(senders[i]))
|
||||
transactionType := "received"
|
||||
party := fmt.Sprintf("from: %s", strings.TrimSpace(senders[i]))
|
||||
if strings.TrimSpace(senders[i]) == publicKey {
|
||||
transactionType = "Sent"
|
||||
party = fmt.Sprintf("To: %s", strings.TrimSpace(recipients[i]))
|
||||
transactionType = "sent"
|
||||
party = fmt.Sprintf("to: %s", strings.TrimSpace(recipients[i]))
|
||||
}
|
||||
|
||||
formattedDate := formatDate(strings.TrimSpace(dates[i]))
|
||||
|
||||
// Build the full transaction detail
|
||||
detail := fmt.Sprintf(
|
||||
"%s %s %s\n%s\nContract address: %s\nTxhash: %s\nDate: %s",
|
||||
"%s %s %s\n%s\ncontract address: %s\ntxhash: %s\ndate: %s",
|
||||
transactionType,
|
||||
strings.TrimSpace(values[i]),
|
||||
strings.TrimSpace(syms[i]),
|
||||
|
@ -2,7 +2,6 @@ package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"git.defalsify.org/vise.git/asm"
|
||||
"git.defalsify.org/vise.git/db"
|
||||
@ -65,11 +64,7 @@ func (ls *LocalHandlerService) SetDataStore(db *db.Db) {
|
||||
}
|
||||
|
||||
func (ls *LocalHandlerService) GetHandler(accountService remote.AccountServiceInterface) (*ussd.Handlers, error) {
|
||||
replaceSeparatorFunc := func(input string) string {
|
||||
return strings.ReplaceAll(input, ":", ls.Cfg.MenuSeparator)
|
||||
}
|
||||
|
||||
ussdHandlers, err := ussd.NewHandlers(ls.Parser, *ls.UserdataStore, ls.AdminStore, accountService, replaceSeparatorFunc)
|
||||
ussdHandlers, err := ussd.NewHandlers(ls.Parser, *ls.UserdataStore, ls.AdminStore, accountService)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -69,19 +69,18 @@ func (fm *FlagManager) GetFlag(label string) (uint32, error) {
|
||||
}
|
||||
|
||||
type Handlers struct {
|
||||
pe *persist.Persister
|
||||
st *state.State
|
||||
ca cache.Memory
|
||||
userdataStore common.DataStore
|
||||
adminstore *utils.AdminStore
|
||||
flagManager *asm.FlagParser
|
||||
accountService remote.AccountServiceInterface
|
||||
prefixDb storage.PrefixDb
|
||||
profile *models.Profile
|
||||
ReplaceSeparatorFunc func(string) string
|
||||
pe *persist.Persister
|
||||
st *state.State
|
||||
ca cache.Memory
|
||||
userdataStore common.DataStore
|
||||
adminstore *utils.AdminStore
|
||||
flagManager *asm.FlagParser
|
||||
accountService remote.AccountServiceInterface
|
||||
prefixDb storage.PrefixDb
|
||||
profile *models.Profile
|
||||
}
|
||||
|
||||
func NewHandlers(appFlags *asm.FlagParser, userdataStore db.Db, adminstore *utils.AdminStore, accountService remote.AccountServiceInterface, replaceSeparatorFunc func(string) string) (*Handlers, error) {
|
||||
func NewHandlers(appFlags *asm.FlagParser, userdataStore db.Db, adminstore *utils.AdminStore, accountService remote.AccountServiceInterface) (*Handlers, error) {
|
||||
if userdataStore == nil {
|
||||
return nil, fmt.Errorf("cannot create handler with nil userdata store")
|
||||
}
|
||||
@ -94,13 +93,12 @@ func NewHandlers(appFlags *asm.FlagParser, userdataStore db.Db, adminstore *util
|
||||
prefixDb := storage.NewSubPrefixDb(userdataStore, prefix)
|
||||
|
||||
h := &Handlers{
|
||||
userdataStore: userDb,
|
||||
flagManager: appFlags,
|
||||
adminstore: adminstore,
|
||||
accountService: accountService,
|
||||
prefixDb: prefixDb,
|
||||
profile: &models.Profile{Max: 6},
|
||||
ReplaceSeparatorFunc: replaceSeparatorFunc,
|
||||
userdataStore: userDb,
|
||||
flagManager: appFlags,
|
||||
adminstore: adminstore,
|
||||
accountService: accountService,
|
||||
prefixDb: prefixDb,
|
||||
profile: &models.Profile{Max: 6},
|
||||
}
|
||||
return h, nil
|
||||
}
|
||||
@ -1685,9 +1683,7 @@ func (h *Handlers) GetVoucherList(ctx context.Context, sym string, input []byte)
|
||||
return res, err
|
||||
}
|
||||
|
||||
formattedData := h.ReplaceSeparatorFunc(string(voucherData))
|
||||
|
||||
res.Content = string(formattedData)
|
||||
res.Content = string(voucherData)
|
||||
|
||||
return res, nil
|
||||
}
|
||||
@ -1850,14 +1846,13 @@ func (h *Handlers) CheckTransactions(ctx context.Context, sym string, input []by
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// GetTransactionsList reads the list of transactions from the db and formats them
|
||||
// GetTransactionsList fetches the list of transactions and formats them
|
||||
func (h *Handlers) GetTransactionsList(ctx context.Context, sym string, input []byte) (resource.Result, error) {
|
||||
var res resource.Result
|
||||
sessionId, ok := ctx.Value("SessionId").(string)
|
||||
if !ok {
|
||||
return res, fmt.Errorf("missing session")
|
||||
}
|
||||
|
||||
store := h.userdataStore
|
||||
publicKey, err := store.ReadEntry(ctx, sessionId, common.DATA_PUBLIC_KEY)
|
||||
if err != nil {
|
||||
@ -1900,14 +1895,12 @@ func (h *Handlers) GetTransactionsList(ctx context.Context, sym string, input []
|
||||
value := strings.TrimSpace(values[i])
|
||||
date := strings.Split(strings.TrimSpace(dates[i]), " ")[0]
|
||||
|
||||
status := "Received"
|
||||
status := "received"
|
||||
if sender == string(publicKey) {
|
||||
status = "Sent"
|
||||
status = "sent"
|
||||
}
|
||||
|
||||
// Use the ReplaceSeparator function for the menu separator
|
||||
transactionLine := fmt.Sprintf("%d%s%s %s %s %s", i+1, h.ReplaceSeparatorFunc(":"), status, value, sym, date)
|
||||
formattedTransactions = append(formattedTransactions, transactionLine)
|
||||
formattedTransactions = append(formattedTransactions, fmt.Sprintf("%d:%s %s %s %s", i+1, status, value, sym, date))
|
||||
}
|
||||
|
||||
res.Content = strings.Join(formattedTransactions, "\n")
|
||||
|
@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"path"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.defalsify.org/vise.git/lang"
|
||||
@ -33,11 +32,6 @@ var (
|
||||
flagsPath = path.Join(baseDir, "services", "registration", "pp.csv")
|
||||
)
|
||||
|
||||
// mockReplaceSeparator function
|
||||
var mockReplaceSeparator = func(input string) string {
|
||||
return strings.ReplaceAll(input, ":", ": ")
|
||||
}
|
||||
|
||||
// InitializeTestStore sets up and returns an in-memory database and store.
|
||||
func InitializeTestStore(t *testing.T) (context.Context, *common.UserDataStore) {
|
||||
ctx := context.Background()
|
||||
@ -73,15 +67,12 @@ func TestNewHandlers(t *testing.T) {
|
||||
_, store := InitializeTestStore(t)
|
||||
|
||||
fm, err := NewFlagManager(flagsPath)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
accountService := testservice.TestAccountService{}
|
||||
|
||||
// Test case for valid UserDataStore
|
||||
if err != nil {
|
||||
t.Logf(err.Error())
|
||||
}
|
||||
t.Run("Valid UserDataStore", func(t *testing.T) {
|
||||
handlers, err := NewHandlers(fm.parser, store, nil, &accountService, mockReplaceSeparator)
|
||||
handlers, err := NewHandlers(fm.parser, store, nil, &accountService)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
@ -91,30 +82,19 @@ func TestNewHandlers(t *testing.T) {
|
||||
if handlers.userdataStore == nil {
|
||||
t.Fatal("expected userdataStore to be set in handlers")
|
||||
}
|
||||
if handlers.ReplaceSeparatorFunc == nil {
|
||||
t.Fatal("expected ReplaceSeparatorFunc to be set in handlers")
|
||||
}
|
||||
|
||||
// Test ReplaceSeparatorFunc functionality
|
||||
input := "1:Menu item"
|
||||
expectedOutput := "1: Menu item"
|
||||
if handlers.ReplaceSeparatorFunc(input) != expectedOutput {
|
||||
t.Fatalf("ReplaceSeparatorFunc function did not return expected output: got %v, want %v", handlers.ReplaceSeparatorFunc(input), expectedOutput)
|
||||
}
|
||||
})
|
||||
|
||||
// Test case for nil UserDataStore
|
||||
// Test case for nil userdataStore
|
||||
t.Run("Nil UserDataStore", func(t *testing.T) {
|
||||
handlers, err := NewHandlers(fm.parser, nil, nil, &accountService, mockReplaceSeparator)
|
||||
handlers, err := NewHandlers(fm.parser, nil, nil, &accountService)
|
||||
if err == nil {
|
||||
t.Fatal("expected an error, got none")
|
||||
}
|
||||
if handlers != nil {
|
||||
t.Fatal("expected handlers to be nil")
|
||||
}
|
||||
expectedError := "cannot create handler with nil userdata store"
|
||||
if err.Error() != expectedError {
|
||||
t.Fatalf("expected error '%s', got '%v'", expectedError, err)
|
||||
if err.Error() != "cannot create handler with nil userdata store" {
|
||||
t.Fatalf("expected specific error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -2002,31 +1982,26 @@ func TestCheckVouchers(t *testing.T) {
|
||||
|
||||
func TestGetVoucherList(t *testing.T) {
|
||||
sessionId := "session123"
|
||||
|
||||
ctx := context.WithValue(context.Background(), "SessionId", sessionId)
|
||||
|
||||
spdb := InitializeTestSubPrefixDb(t, ctx)
|
||||
|
||||
// Initialize Handlers
|
||||
h := &Handlers{
|
||||
prefixDb: spdb,
|
||||
ReplaceSeparatorFunc: mockReplaceSeparator,
|
||||
prefixDb: spdb,
|
||||
}
|
||||
|
||||
mockSyms := []byte("1:SRF\n2:MILO")
|
||||
expectedSym := []byte("1:SRF\n2:MILO")
|
||||
|
||||
// Put voucher sym data from the store
|
||||
err := spdb.Put(ctx, common.ToBytes(common.DATA_VOUCHER_SYMBOLS), mockSyms)
|
||||
err := spdb.Put(ctx, common.ToBytes(common.DATA_VOUCHER_SYMBOLS), expectedSym)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expectedSyms := []byte("1: SRF\n2: MILO")
|
||||
|
||||
res, err := h.GetVoucherList(ctx, "", []byte(""))
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, res.Content, string(expectedSyms))
|
||||
assert.Equal(t, res.Content, string(expectedSym))
|
||||
}
|
||||
|
||||
func TestViewVoucher(t *testing.T) {
|
||||
|
64
internal/ssh/keystore.go
Normal file
64
internal/ssh/keystore.go
Normal file
@ -0,0 +1,64 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"git.defalsify.org/vise.git/db"
|
||||
|
||||
"git.grassecon.net/urdt/ussd/internal/storage"
|
||||
)
|
||||
|
||||
type SshKeyStore struct {
|
||||
store db.Db
|
||||
}
|
||||
|
||||
func NewSshKeyStore(ctx context.Context, dbDir string) (*SshKeyStore, error) {
|
||||
keyStore := &SshKeyStore{}
|
||||
keyStoreFile := path.Join(dbDir, "ssh_authorized_keys.gdbm")
|
||||
keyStore.store = storage.NewThreadGdbmDb()
|
||||
err := keyStore.store.Connect(ctx, keyStoreFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return keyStore, nil
|
||||
}
|
||||
|
||||
func(s *SshKeyStore) AddFromFile(ctx context.Context, fp string, sessionId string) error {
|
||||
_, err := os.Stat(fp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot open ssh server public key file: %v\n", err)
|
||||
}
|
||||
|
||||
publicBytes, err := os.ReadFile(fp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to load public key: %v", err)
|
||||
}
|
||||
pubKey, _, _, _, err := ssh.ParseAuthorizedKey(publicBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to parse public key: %v", err)
|
||||
}
|
||||
k := append([]byte{0x01}, pubKey.Marshal()...)
|
||||
s.store.SetPrefix(storage.DATATYPE_EXTEND)
|
||||
logg.Infof("Added key", "sessionId", sessionId, "public key", string(publicBytes))
|
||||
return s.store.Put(ctx, k, []byte(sessionId))
|
||||
}
|
||||
|
||||
func(s *SshKeyStore) Get(ctx context.Context, pubKey ssh.PublicKey) (string, error) {
|
||||
s.store.SetLanguage(nil)
|
||||
s.store.SetPrefix(storage.DATATYPE_EXTEND)
|
||||
k := append([]byte{0x01}, pubKey.Marshal()...)
|
||||
v, err := s.store.Get(ctx, k)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(v), nil
|
||||
}
|
||||
|
||||
func(s *SshKeyStore) Close() error {
|
||||
return s.store.Close()
|
||||
}
|
284
internal/ssh/ssh.go
Normal file
284
internal/ssh/ssh.go
Normal file
@ -0,0 +1,284 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"git.defalsify.org/vise.git/engine"
|
||||
"git.defalsify.org/vise.git/logging"
|
||||
"git.defalsify.org/vise.git/resource"
|
||||
"git.defalsify.org/vise.git/state"
|
||||
|
||||
"git.grassecon.net/urdt/ussd/internal/handlers"
|
||||
"git.grassecon.net/urdt/ussd/internal/storage"
|
||||
)
|
||||
|
||||
var (
|
||||
logg = logging.NewVanilla().WithDomain("ssh")
|
||||
)
|
||||
|
||||
type auther struct {
|
||||
Ctx context.Context
|
||||
keyStore *SshKeyStore
|
||||
auth map[string]string
|
||||
}
|
||||
|
||||
func NewAuther(ctx context.Context, keyStore *SshKeyStore) *auther {
|
||||
return &auther{
|
||||
Ctx: ctx,
|
||||
keyStore: keyStore,
|
||||
auth: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
func(a *auther) Check(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
|
||||
va, err := a.keyStore.Get(a.Ctx, pubKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ka := hex.EncodeToString(conn.SessionID())
|
||||
a.auth[ka] = va
|
||||
fmt.Fprintf(os.Stderr, "connect: %s -> %s\n", ka, va)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func(a *auther) FromConn(c *ssh.ServerConn) (string, error) {
|
||||
if c == nil {
|
||||
return "", errors.New("nil server conn")
|
||||
}
|
||||
if c.Conn == nil {
|
||||
return "", errors.New("nil underlying conn")
|
||||
}
|
||||
return a.Get(c.Conn.SessionID())
|
||||
}
|
||||
|
||||
|
||||
func(a *auther) Get(k []byte) (string, error) {
|
||||
ka := hex.EncodeToString(k)
|
||||
v, ok := a.auth[ka]
|
||||
if !ok {
|
||||
return "", errors.New("not found")
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func(s *SshRunner) serve(ctx context.Context, sessionId string, ch ssh.NewChannel, en engine.Engine) error {
|
||||
if ch == nil {
|
||||
return errors.New("nil channel")
|
||||
}
|
||||
if ch.ChannelType() != "session" {
|
||||
ch.Reject(ssh.UnknownChannelType, "that is not the channel you are looking for")
|
||||
return errors.New("not a session")
|
||||
}
|
||||
channel, requests, err := ch.Accept()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer channel.Close()
|
||||
s.wg.Add(1)
|
||||
go func(reqIn <-chan *ssh.Request) {
|
||||
defer s.wg.Done()
|
||||
for req := range reqIn {
|
||||
req.Reply(req.Type == "shell", nil)
|
||||
}
|
||||
_ = requests
|
||||
}(requests)
|
||||
|
||||
cont, err := en.Exec(ctx, []byte{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("initial engine exec err: %v", err)
|
||||
}
|
||||
|
||||
var input [state.INPUT_LIMIT]byte
|
||||
for cont {
|
||||
c, err := en.Flush(ctx, channel)
|
||||
if err != nil {
|
||||
return fmt.Errorf("flush err: %v", err)
|
||||
}
|
||||
_, err = channel.Write([]byte{0x0a})
|
||||
if err != nil {
|
||||
return fmt.Errorf("newline err: %v", err)
|
||||
}
|
||||
c, err = channel.Read(input[:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("read input fail: %v", err)
|
||||
}
|
||||
logg.TraceCtxf(ctx, "input read", "c", c, "input", input[:c-1])
|
||||
cont, err = en.Exec(ctx, input[:c-1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("engine exec err: %v", err)
|
||||
}
|
||||
logg.TraceCtxf(ctx, "exec cont", "cont", cont, "en", en)
|
||||
_ = c
|
||||
}
|
||||
c, err := en.Flush(ctx, channel)
|
||||
if err != nil {
|
||||
return fmt.Errorf("last flush err: %v", err)
|
||||
}
|
||||
_ = c
|
||||
return nil
|
||||
}
|
||||
|
||||
type SshRunner struct {
|
||||
Ctx context.Context
|
||||
Cfg engine.Config
|
||||
FlagFile string
|
||||
DbDir string
|
||||
ResourceDir string
|
||||
Debug bool
|
||||
SrvKeyFile string
|
||||
Host string
|
||||
Port uint
|
||||
wg sync.WaitGroup
|
||||
lst net.Listener
|
||||
}
|
||||
|
||||
func(s *SshRunner) Stop() error {
|
||||
return s.lst.Close()
|
||||
}
|
||||
|
||||
func(s *SshRunner) GetEngine(sessionId string) (engine.Engine, func(), error) {
|
||||
ctx := s.Ctx
|
||||
menuStorageService := storage.NewMenuStorageService(s.DbDir, s.ResourceDir)
|
||||
|
||||
err := menuStorageService.EnsureDbDir()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
rs, err := menuStorageService.GetResource(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pe, err := menuStorageService.GetPersister(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
userdatastore, err := menuStorageService.GetUserdataDb(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
dbResource, ok := rs.(*resource.DbResource)
|
||||
if !ok {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
lhs, err := handlers.NewLocalHandlerService(s.FlagFile, true, dbResource, s.Cfg, rs)
|
||||
lhs.SetDataStore(&userdatastore)
|
||||
lhs.SetPersister(pe)
|
||||
lhs.Cfg.SessionId = sessionId
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
hl, err := lhs.GetHandler()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
en := lhs.GetEngine()
|
||||
en = en.WithFirst(hl.Init)
|
||||
if s.Debug {
|
||||
en = en.WithDebug(nil)
|
||||
}
|
||||
// TODO: this is getting very hacky!
|
||||
closer := func() {
|
||||
err := menuStorageService.Close()
|
||||
if err != nil {
|
||||
logg.ErrorCtxf(ctx, "menu storage service cleanup fail", "err", err)
|
||||
}
|
||||
}
|
||||
return en, closer, nil
|
||||
}
|
||||
|
||||
// adapted example from crypto/ssh package, NewServerConn doc
|
||||
func(s *SshRunner) Run(ctx context.Context, keyStore *SshKeyStore) {
|
||||
running := true
|
||||
|
||||
// TODO: waitgroup should probably not be global
|
||||
defer s.wg.Wait()
|
||||
|
||||
auth := NewAuther(ctx, keyStore)
|
||||
cfg := ssh.ServerConfig{
|
||||
PublicKeyCallback: auth.Check,
|
||||
}
|
||||
|
||||
privateBytes, err := os.ReadFile(s.SrvKeyFile)
|
||||
if err != nil {
|
||||
logg.ErrorCtxf(ctx, "Failed to load private key", "err", err)
|
||||
}
|
||||
private, err := ssh.ParsePrivateKey(privateBytes)
|
||||
if err != nil {
|
||||
logg.ErrorCtxf(ctx, "Failed to parse private key", "err", err)
|
||||
}
|
||||
srvPub := private.PublicKey()
|
||||
srvPubStr := base64.StdEncoding.EncodeToString(srvPub.Marshal())
|
||||
logg.InfoCtxf(ctx, "have server key", "type", srvPub.Type(), "public", srvPubStr)
|
||||
cfg.AddHostKey(private)
|
||||
|
||||
s.lst, err = net.Listen("tcp", fmt.Sprintf("%s:%d", s.Host, s.Port))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
for running {
|
||||
conn, err := s.lst.Accept()
|
||||
if err != nil {
|
||||
logg.ErrorCtxf(ctx, "ssh accept error", "err", err)
|
||||
running = false
|
||||
continue
|
||||
}
|
||||
|
||||
go func(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
for true {
|
||||
srvConn, nC, rC, err := ssh.NewServerConn(conn, &cfg)
|
||||
if err != nil {
|
||||
logg.InfoCtxf(ctx, "rejected client", "err", err)
|
||||
return
|
||||
}
|
||||
logg.DebugCtxf(ctx, "ssh client connected", "conn", srvConn)
|
||||
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
ssh.DiscardRequests(rC)
|
||||
s.wg.Done()
|
||||
}()
|
||||
|
||||
sessionId, err := auth.FromConn(srvConn)
|
||||
if err != nil {
|
||||
logg.ErrorCtxf(ctx, "Cannot find authentication")
|
||||
return
|
||||
}
|
||||
en, closer, err := s.GetEngine(sessionId)
|
||||
if err != nil {
|
||||
logg.ErrorCtxf(ctx, "engine won't start", "err", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
err := en.Finish()
|
||||
if err != nil {
|
||||
logg.ErrorCtxf(ctx, "engine won't stop", "err", err)
|
||||
}
|
||||
closer()
|
||||
}()
|
||||
for ch := range nC {
|
||||
err = s.serve(ctx, sessionId, ch, en)
|
||||
logg.ErrorCtxf(ctx, "ssh server finish", "err", err)
|
||||
}
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}
|
@ -5,6 +5,10 @@ import (
|
||||
"git.defalsify.org/vise.git/persist"
|
||||
)
|
||||
|
||||
const (
|
||||
DATATYPE_EXTEND = 128
|
||||
)
|
||||
|
||||
type Storage struct {
|
||||
Persister *persist.Persister
|
||||
UserdataDb db.Db
|
||||
|
Loading…
Reference in New Issue
Block a user