forked from urdt/ussd
Compare commits
8 Commits
lash/ssh-4
...
lash/stale
| Author | SHA1 | Date | |
|---|---|---|---|
| 0e61945cad | |||
|
|
94551ba37f
|
||
|
|
973a69455e | ||
|
|
0af7379ae4 | ||
|
|
ce30cb740e
|
||
|
|
659fd00c53
|
||
|
|
9b3ed0d6ae
|
||
|
|
fbcde2f322
|
@@ -1,34 +0,0 @@
|
||||
# URDT-USSD SSH server
|
||||
|
||||
An SSH server entry point for the vise engine.
|
||||
|
||||
|
||||
## Adding public keys for access
|
||||
|
||||
Map your (client) public key to a session identifier (e.g. phone number)
|
||||
|
||||
```
|
||||
go run -v -tags logtrace ./cmd/ssh/sshkey/main.go -i <session_id> [--dbdir <dbpath>] <client_publickey_filepath>
|
||||
```
|
||||
|
||||
|
||||
## Create a private key for the server
|
||||
|
||||
```
|
||||
ssh-keygen -N "" -f <server_privatekey_filepath>
|
||||
```
|
||||
|
||||
|
||||
## Run the server
|
||||
|
||||
|
||||
```
|
||||
go run -v -tags logtrace ./cmd/ssh/main.go -h <host> -p <port> [--dbdir <dbpath>] <server_privatekey_filepath>
|
||||
```
|
||||
|
||||
|
||||
## Connect to the server
|
||||
|
||||
```
|
||||
ssh [-v] -T -p <port> -i <client_publickey_filepath> <host>
|
||||
```
|
||||
115
cmd/ssh/main.go
115
cmd/ssh/main.go
@@ -1,115 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"path"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"git.defalsify.org/vise.git/db"
|
||||
"git.defalsify.org/vise.git/engine"
|
||||
"git.defalsify.org/vise.git/logging"
|
||||
|
||||
"git.grassecon.net/urdt/ussd/internal/ssh"
|
||||
)
|
||||
|
||||
var (
|
||||
wg sync.WaitGroup
|
||||
keyStore db.Db
|
||||
logg = logging.NewVanilla()
|
||||
scriptDir = path.Join("services", "registration")
|
||||
)
|
||||
|
||||
func main() {
|
||||
var dbDir string
|
||||
var resourceDir string
|
||||
var size uint
|
||||
var engineDebug bool
|
||||
var stateDebug bool
|
||||
var host string
|
||||
var port uint
|
||||
flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
|
||||
flag.StringVar(&resourceDir, "resourcedir", path.Join("services", "registration"), "resource dir")
|
||||
flag.BoolVar(&engineDebug, "engine-debug", false, "use engine debug output")
|
||||
flag.BoolVar(&stateDebug, "state-debug", false, "use engine debug output")
|
||||
flag.UintVar(&size, "s", 160, "max size of output")
|
||||
flag.StringVar(&host, "h", "127.0.0.1", "http host")
|
||||
flag.UintVar(&port, "p", 7122, "http port")
|
||||
flag.Parse()
|
||||
|
||||
sshKeyFile := flag.Arg(0)
|
||||
_, err := os.Stat(sshKeyFile)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "cannot open ssh server private key file: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
logg.WarnCtxf(ctx, "!!!!! WARNING WARNING WARNING")
|
||||
logg.WarnCtxf(ctx, "!!!!! =======================")
|
||||
logg.WarnCtxf(ctx, "!!!!! This is not a production ready server!")
|
||||
logg.WarnCtxf(ctx, "!!!!! Do not expose to internet and only use with tunnel!")
|
||||
logg.WarnCtxf(ctx, "!!!!! (See ssh -L <...>)")
|
||||
|
||||
logg.Infof("start command", "dbdir", dbDir, "resourcedir", resourceDir, "outputsize", size, "keyfile", sshKeyFile, "host", host, "port", port)
|
||||
|
||||
pfp := path.Join(scriptDir, "pp.csv")
|
||||
|
||||
cfg := engine.Config{
|
||||
Root: "root",
|
||||
OutputSize: uint32(size),
|
||||
FlagCount: uint32(16),
|
||||
}
|
||||
if stateDebug {
|
||||
cfg.StateDebug = true
|
||||
}
|
||||
if engineDebug {
|
||||
cfg.EngineDebug = true
|
||||
}
|
||||
|
||||
authKeyStore, err := ssh.NewSshKeyStore(ctx, dbDir)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "keystore file open error: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer func () {
|
||||
logg.TraceCtxf(ctx, "shutdown auth key store reached")
|
||||
err = authKeyStore.Close()
|
||||
if err != nil {
|
||||
logg.ErrorCtxf(ctx, "keystore close error", "err", err)
|
||||
}
|
||||
}()
|
||||
|
||||
cint := make(chan os.Signal)
|
||||
cterm := make(chan os.Signal)
|
||||
signal.Notify(cint, os.Interrupt, syscall.SIGINT)
|
||||
signal.Notify(cterm, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
runner := &ssh.SshRunner{
|
||||
Cfg: cfg,
|
||||
Debug: engineDebug,
|
||||
FlagFile: pfp,
|
||||
DbDir: dbDir,
|
||||
ResourceDir: resourceDir,
|
||||
SrvKeyFile: sshKeyFile,
|
||||
Host: host,
|
||||
Port: port,
|
||||
}
|
||||
go func() {
|
||||
select {
|
||||
case _ = <-cint:
|
||||
case _ = <-cterm:
|
||||
}
|
||||
logg.TraceCtxf(ctx, "shutdown runner reached")
|
||||
err := runner.Stop()
|
||||
if err != nil {
|
||||
logg.ErrorCtxf(ctx, "runner stop error", "err", err)
|
||||
}
|
||||
|
||||
}()
|
||||
runner.Run(ctx, authKeyStore)
|
||||
}
|
||||
@@ -1,44 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"git.grassecon.net/urdt/ussd/internal/ssh"
|
||||
)
|
||||
|
||||
func main() {
|
||||
var dbDir string
|
||||
var sessionId string
|
||||
flag.StringVar(&dbDir, "dbdir", ".state", "database dir to read from")
|
||||
flag.StringVar(&sessionId, "i", "", "session id")
|
||||
flag.Parse()
|
||||
|
||||
if sessionId == "" {
|
||||
fmt.Fprintf(os.Stderr, "empty session id\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
sshKeyFile := flag.Arg(0)
|
||||
if sshKeyFile == "" {
|
||||
fmt.Fprintf(os.Stderr, "missing key file argument\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
store, err := ssh.NewSshKeyStore(ctx, dbDir)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
err = store.AddFromFile(ctx, sshKeyFile, sessionId)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
4
go.mod
4
go.mod
@@ -5,14 +5,14 @@ go 1.23.0
|
||||
toolchain go1.23.2
|
||||
|
||||
require (
|
||||
git.defalsify.org/vise.git v0.2.1-0.20241017112704-307fa6fcdc6b
|
||||
git.defalsify.org/vise.git v0.2.1-0.20241031204035-b588301738ed
|
||||
github.com/alecthomas/assert/v2 v2.2.2
|
||||
github.com/grassrootseconomics/eth-custodial v1.3.0-beta
|
||||
github.com/peteole/testdata-loader v0.3.0
|
||||
gopkg.in/leonelquinteros/gotext.v1 v1.3.1
|
||||
)
|
||||
|
||||
require github.com/grassrootseconomics/ussd-data-service v0.0.0-20241003123429-4904b4438a3a // indirect
|
||||
require github.com/grassrootseconomics/ussd-data-service v0.0.0-20241003123429-4904b4438a3a
|
||||
|
||||
require (
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
|
||||
4
go.sum
4
go.sum
@@ -1,5 +1,5 @@
|
||||
git.defalsify.org/vise.git v0.2.1-0.20241017112704-307fa6fcdc6b h1:dxBplsIlzJHV+5EH+gzB+w08Blt7IJbb2jeRe1OEjLU=
|
||||
git.defalsify.org/vise.git v0.2.1-0.20241017112704-307fa6fcdc6b/go.mod h1:jyBMe1qTYUz3mmuoC9JQ/TvFeW0vTanCUcPu3H8p4Ck=
|
||||
git.defalsify.org/vise.git v0.2.1-0.20241031204035-b588301738ed h1:4TrsfbK7NKgsa7KjMPlnV/tjYTkAAXP5PWAZzUfzCdI=
|
||||
git.defalsify.org/vise.git v0.2.1-0.20241031204035-b588301738ed/go.mod h1:jyBMe1qTYUz3mmuoC9JQ/TvFeW0vTanCUcPu3H8p4Ck=
|
||||
github.com/alecthomas/assert/v2 v2.2.2 h1:Z/iVC0xZfWTaFNE6bA3z07T86hd45Xe2eLt6WVy2bbk=
|
||||
github.com/alecthomas/assert/v2 v2.2.2/go.mod h1:pXcQ2Asjp247dahGEmsZ6ru0UVwnkhktn7S0bBDLxvQ=
|
||||
github.com/alecthomas/participle/v2 v2.0.0 h1:Fgrq+MbuSsJwIkw3fEj9h75vDP0Er5JzepJ0/HNHv0g=
|
||||
|
||||
@@ -1,64 +0,0 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"git.defalsify.org/vise.git/db"
|
||||
|
||||
"git.grassecon.net/urdt/ussd/internal/storage"
|
||||
)
|
||||
|
||||
type SshKeyStore struct {
|
||||
store db.Db
|
||||
}
|
||||
|
||||
func NewSshKeyStore(ctx context.Context, dbDir string) (*SshKeyStore, error) {
|
||||
keyStore := &SshKeyStore{}
|
||||
keyStoreFile := path.Join(dbDir, "ssh_authorized_keys.gdbm")
|
||||
keyStore.store = storage.NewThreadGdbmDb()
|
||||
err := keyStore.store.Connect(ctx, keyStoreFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return keyStore, nil
|
||||
}
|
||||
|
||||
func(s *SshKeyStore) AddFromFile(ctx context.Context, fp string, sessionId string) error {
|
||||
_, err := os.Stat(fp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot open ssh server public key file: %v\n", err)
|
||||
}
|
||||
|
||||
publicBytes, err := os.ReadFile(fp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to load public key: %v", err)
|
||||
}
|
||||
pubKey, _, _, _, err := ssh.ParseAuthorizedKey(publicBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to parse public key: %v", err)
|
||||
}
|
||||
k := append([]byte{0x01}, pubKey.Marshal()...)
|
||||
s.store.SetPrefix(storage.DATATYPE_EXTEND)
|
||||
logg.Infof("Added key", "sessionId", sessionId, "public key", string(publicBytes))
|
||||
return s.store.Put(ctx, k, []byte(sessionId))
|
||||
}
|
||||
|
||||
func(s *SshKeyStore) Get(ctx context.Context, pubKey ssh.PublicKey) (string, error) {
|
||||
s.store.SetLanguage(nil)
|
||||
s.store.SetPrefix(storage.DATATYPE_EXTEND)
|
||||
k := append([]byte{0x01}, pubKey.Marshal()...)
|
||||
v, err := s.store.Get(ctx, k)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(v), nil
|
||||
}
|
||||
|
||||
func(s *SshKeyStore) Close() error {
|
||||
return s.store.Close()
|
||||
}
|
||||
@@ -1,284 +0,0 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"git.defalsify.org/vise.git/engine"
|
||||
"git.defalsify.org/vise.git/logging"
|
||||
"git.defalsify.org/vise.git/resource"
|
||||
"git.defalsify.org/vise.git/state"
|
||||
|
||||
"git.grassecon.net/urdt/ussd/internal/handlers"
|
||||
"git.grassecon.net/urdt/ussd/internal/storage"
|
||||
)
|
||||
|
||||
var (
|
||||
logg = logging.NewVanilla().WithDomain("ssh")
|
||||
)
|
||||
|
||||
type auther struct {
|
||||
Ctx context.Context
|
||||
keyStore *SshKeyStore
|
||||
auth map[string]string
|
||||
}
|
||||
|
||||
func NewAuther(ctx context.Context, keyStore *SshKeyStore) *auther {
|
||||
return &auther{
|
||||
Ctx: ctx,
|
||||
keyStore: keyStore,
|
||||
auth: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
func(a *auther) Check(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
|
||||
va, err := a.keyStore.Get(a.Ctx, pubKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ka := hex.EncodeToString(conn.SessionID())
|
||||
a.auth[ka] = va
|
||||
fmt.Fprintf(os.Stderr, "connect: %s -> %s\n", ka, va)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func(a *auther) FromConn(c *ssh.ServerConn) (string, error) {
|
||||
if c == nil {
|
||||
return "", errors.New("nil server conn")
|
||||
}
|
||||
if c.Conn == nil {
|
||||
return "", errors.New("nil underlying conn")
|
||||
}
|
||||
return a.Get(c.Conn.SessionID())
|
||||
}
|
||||
|
||||
|
||||
func(a *auther) Get(k []byte) (string, error) {
|
||||
ka := hex.EncodeToString(k)
|
||||
v, ok := a.auth[ka]
|
||||
if !ok {
|
||||
return "", errors.New("not found")
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func(s *SshRunner) serve(ctx context.Context, sessionId string, ch ssh.NewChannel, en engine.Engine) error {
|
||||
if ch == nil {
|
||||
return errors.New("nil channel")
|
||||
}
|
||||
if ch.ChannelType() != "session" {
|
||||
ch.Reject(ssh.UnknownChannelType, "that is not the channel you are looking for")
|
||||
return errors.New("not a session")
|
||||
}
|
||||
channel, requests, err := ch.Accept()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer channel.Close()
|
||||
s.wg.Add(1)
|
||||
go func(reqIn <-chan *ssh.Request) {
|
||||
defer s.wg.Done()
|
||||
for req := range reqIn {
|
||||
req.Reply(req.Type == "shell", nil)
|
||||
}
|
||||
_ = requests
|
||||
}(requests)
|
||||
|
||||
cont, err := en.Exec(ctx, []byte{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("initial engine exec err: %v", err)
|
||||
}
|
||||
|
||||
var input [state.INPUT_LIMIT]byte
|
||||
for cont {
|
||||
c, err := en.Flush(ctx, channel)
|
||||
if err != nil {
|
||||
return fmt.Errorf("flush err: %v", err)
|
||||
}
|
||||
_, err = channel.Write([]byte{0x0a})
|
||||
if err != nil {
|
||||
return fmt.Errorf("newline err: %v", err)
|
||||
}
|
||||
c, err = channel.Read(input[:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("read input fail: %v", err)
|
||||
}
|
||||
logg.TraceCtxf(ctx, "input read", "c", c, "input", input[:c-1])
|
||||
cont, err = en.Exec(ctx, input[:c-1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("engine exec err: %v", err)
|
||||
}
|
||||
logg.TraceCtxf(ctx, "exec cont", "cont", cont, "en", en)
|
||||
_ = c
|
||||
}
|
||||
c, err := en.Flush(ctx, channel)
|
||||
if err != nil {
|
||||
return fmt.Errorf("last flush err: %v", err)
|
||||
}
|
||||
_ = c
|
||||
return nil
|
||||
}
|
||||
|
||||
type SshRunner struct {
|
||||
Ctx context.Context
|
||||
Cfg engine.Config
|
||||
FlagFile string
|
||||
DbDir string
|
||||
ResourceDir string
|
||||
Debug bool
|
||||
SrvKeyFile string
|
||||
Host string
|
||||
Port uint
|
||||
wg sync.WaitGroup
|
||||
lst net.Listener
|
||||
}
|
||||
|
||||
func(s *SshRunner) Stop() error {
|
||||
return s.lst.Close()
|
||||
}
|
||||
|
||||
func(s *SshRunner) GetEngine(sessionId string) (engine.Engine, func(), error) {
|
||||
ctx := s.Ctx
|
||||
menuStorageService := storage.NewMenuStorageService(s.DbDir, s.ResourceDir)
|
||||
|
||||
err := menuStorageService.EnsureDbDir()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
rs, err := menuStorageService.GetResource(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pe, err := menuStorageService.GetPersister(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
userdatastore, err := menuStorageService.GetUserdataDb(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
dbResource, ok := rs.(*resource.DbResource)
|
||||
if !ok {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
lhs, err := handlers.NewLocalHandlerService(s.FlagFile, true, dbResource, s.Cfg, rs)
|
||||
lhs.SetDataStore(&userdatastore)
|
||||
lhs.SetPersister(pe)
|
||||
lhs.Cfg.SessionId = sessionId
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
hl, err := lhs.GetHandler()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
en := lhs.GetEngine()
|
||||
en = en.WithFirst(hl.Init)
|
||||
if s.Debug {
|
||||
en = en.WithDebug(nil)
|
||||
}
|
||||
// TODO: this is getting very hacky!
|
||||
closer := func() {
|
||||
err := menuStorageService.Close()
|
||||
if err != nil {
|
||||
logg.ErrorCtxf(ctx, "menu storage service cleanup fail", "err", err)
|
||||
}
|
||||
}
|
||||
return en, closer, nil
|
||||
}
|
||||
|
||||
// adapted example from crypto/ssh package, NewServerConn doc
|
||||
func(s *SshRunner) Run(ctx context.Context, keyStore *SshKeyStore) {
|
||||
running := true
|
||||
|
||||
// TODO: waitgroup should probably not be global
|
||||
defer s.wg.Wait()
|
||||
|
||||
auth := NewAuther(ctx, keyStore)
|
||||
cfg := ssh.ServerConfig{
|
||||
PublicKeyCallback: auth.Check,
|
||||
}
|
||||
|
||||
privateBytes, err := os.ReadFile(s.SrvKeyFile)
|
||||
if err != nil {
|
||||
logg.ErrorCtxf(ctx, "Failed to load private key", "err", err)
|
||||
}
|
||||
private, err := ssh.ParsePrivateKey(privateBytes)
|
||||
if err != nil {
|
||||
logg.ErrorCtxf(ctx, "Failed to parse private key", "err", err)
|
||||
}
|
||||
srvPub := private.PublicKey()
|
||||
srvPubStr := base64.StdEncoding.EncodeToString(srvPub.Marshal())
|
||||
logg.InfoCtxf(ctx, "have server key", "type", srvPub.Type(), "public", srvPubStr)
|
||||
cfg.AddHostKey(private)
|
||||
|
||||
s.lst, err = net.Listen("tcp", fmt.Sprintf("%s:%d", s.Host, s.Port))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
for running {
|
||||
conn, err := s.lst.Accept()
|
||||
if err != nil {
|
||||
logg.ErrorCtxf(ctx, "ssh accept error", "err", err)
|
||||
running = false
|
||||
continue
|
||||
}
|
||||
|
||||
go func(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
for true {
|
||||
srvConn, nC, rC, err := ssh.NewServerConn(conn, &cfg)
|
||||
if err != nil {
|
||||
logg.InfoCtxf(ctx, "rejected client", "err", err)
|
||||
return
|
||||
}
|
||||
logg.DebugCtxf(ctx, "ssh client connected", "conn", srvConn)
|
||||
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
ssh.DiscardRequests(rC)
|
||||
s.wg.Done()
|
||||
}()
|
||||
|
||||
sessionId, err := auth.FromConn(srvConn)
|
||||
if err != nil {
|
||||
logg.ErrorCtxf(ctx, "Cannot find authentication")
|
||||
return
|
||||
}
|
||||
en, closer, err := s.GetEngine(sessionId)
|
||||
if err != nil {
|
||||
logg.ErrorCtxf(ctx, "engine won't start", "err", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
err := en.Finish()
|
||||
if err != nil {
|
||||
logg.ErrorCtxf(ctx, "engine won't stop", "err", err)
|
||||
}
|
||||
closer()
|
||||
}()
|
||||
for ch := range nC {
|
||||
err = s.serve(ctx, sessionId, ch, en)
|
||||
logg.ErrorCtxf(ctx, "ssh server finish", "err", err)
|
||||
}
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}
|
||||
@@ -10,6 +10,10 @@ const (
|
||||
DATATYPE_USERSUB = 64
|
||||
)
|
||||
|
||||
const (
|
||||
SUBPREFIX_TIME = uint16(1)
|
||||
)
|
||||
|
||||
// PrefixDb interface abstracts the database operations.
|
||||
type PrefixDb interface {
|
||||
Get(ctx context.Context, key []byte) ([]byte, error)
|
||||
@@ -30,8 +34,12 @@ func NewSubPrefixDb(store db.Db, pfx []byte) *SubPrefixDb {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SubPrefixDb) toKey(k []byte) []byte {
|
||||
return append(s.pfx, k...)
|
||||
func(s *SubPrefixDb) SetSession(sessionId string) {
|
||||
s.store.SetSession(sessionId)
|
||||
}
|
||||
|
||||
func(s *SubPrefixDb) toKey(k []byte) []byte {
|
||||
return append(s.pfx, k...)
|
||||
}
|
||||
|
||||
func (s *SubPrefixDb) Get(ctx context.Context, key []byte) ([]byte, error) {
|
||||
|
||||
@@ -12,6 +12,7 @@ var (
|
||||
dbC map[string]chan db.Db
|
||||
)
|
||||
|
||||
|
||||
type ThreadGdbmDb struct {
|
||||
db db.Db
|
||||
connStr string
|
||||
|
||||
@@ -5,10 +5,6 @@ import (
|
||||
"git.defalsify.org/vise.git/persist"
|
||||
)
|
||||
|
||||
const (
|
||||
DATATYPE_EXTEND = 128
|
||||
)
|
||||
|
||||
type Storage struct {
|
||||
Persister *persist.Persister
|
||||
UserdataDb db.Db
|
||||
|
||||
109
internal/storage/timed.go
Normal file
109
internal/storage/timed.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"time"
|
||||
"encoding/binary"
|
||||
|
||||
"git.defalsify.org/vise.git/db"
|
||||
)
|
||||
|
||||
type TimedDb struct {
|
||||
db.Db
|
||||
tdb *SubPrefixDb
|
||||
ttl time.Duration
|
||||
parentPfx uint8
|
||||
parentSession []byte
|
||||
matchPfx map[uint8][][]byte
|
||||
}
|
||||
|
||||
func NewTimedDb(db db.Db, ttl time.Duration) *TimedDb {
|
||||
var b [2]byte
|
||||
binary.BigEndian.PutUint16(b[:], SUBPREFIX_TIME)
|
||||
sdb := NewSubPrefixDb(db, b[:])
|
||||
return &TimedDb{
|
||||
Db: db,
|
||||
tdb: sdb,
|
||||
ttl: ttl,
|
||||
}
|
||||
}
|
||||
|
||||
func(tib *TimedDb) WithMatch(pfx uint8, keyPart []byte) *TimedDb {
|
||||
if tib.matchPfx == nil {
|
||||
tib.matchPfx = make(map[uint8][][]byte)
|
||||
}
|
||||
tib.matchPfx[pfx] = append(tib.matchPfx[pfx], keyPart)
|
||||
return tib
|
||||
}
|
||||
|
||||
func(tib *TimedDb) checkPrefix(pfx uint8, key []byte) bool {
|
||||
var v []byte
|
||||
if tib.matchPfx == nil {
|
||||
return true
|
||||
}
|
||||
for _, v = range(tib.matchPfx[pfx]) {
|
||||
l := len(v)
|
||||
if l > len(key) {
|
||||
continue
|
||||
}
|
||||
if bytes.Equal(v, key[:l]) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func(tib *TimedDb) SetPrefix(pfx uint8) {
|
||||
tib.Db.SetPrefix(pfx)
|
||||
tib.parentPfx = pfx
|
||||
}
|
||||
|
||||
func(tib *TimedDb) SetSession(session string) {
|
||||
tib.Db.SetSession(session)
|
||||
tib.parentSession = []byte(session)
|
||||
}
|
||||
|
||||
func(tib *TimedDb) Put(ctx context.Context, key []byte, val []byte) error {
|
||||
t := time.Now()
|
||||
b, err := t.MarshalBinary()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = tib.Db.Put(ctx, key, val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
tib.parentPfx = 0
|
||||
tib.parentSession = nil
|
||||
}()
|
||||
if tib.checkPrefix(tib.parentPfx, key) {
|
||||
tib.tdb.SetSession("")
|
||||
k := db.ToSessionKey(tib.parentPfx, []byte(tib.parentSession), key)
|
||||
k = append([]byte{tib.parentPfx}, k...)
|
||||
err = tib.tdb.Put(ctx, k, b)
|
||||
if err != nil {
|
||||
logg.ErrorCtxf(ctx, "failed to update timestamp of record", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func(tib *TimedDb) Stale(ctx context.Context, pfx uint8, sessionId string, key []byte) bool {
|
||||
tib.tdb.SetSession("")
|
||||
b := db.ToSessionKey(pfx, []byte(sessionId), key)
|
||||
b = append([]byte{pfx}, b...)
|
||||
v, err := tib.tdb.Get(ctx, b)
|
||||
if err != nil {
|
||||
logg.WarnCtxf(ctx, "no time entry", "key", key, "b", b)
|
||||
return false
|
||||
}
|
||||
t_now := time.Now()
|
||||
t_then := time.Time{}
|
||||
err = t_then.UnmarshalBinary(v)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return t_now.After(t_then.Add(tib.ttl))
|
||||
}
|
||||
125
internal/storage/timed_test.go
Normal file
125
internal/storage/timed_test.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.defalsify.org/vise.git/db"
|
||||
memdb "git.defalsify.org/vise.git/db/mem"
|
||||
)
|
||||
|
||||
func TestStaleDb(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mdb := memdb.NewMemDb()
|
||||
err := mdb.Connect(ctx, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tdb := NewTimedDb(mdb, time.Duration(time.Millisecond))
|
||||
tdb.SetPrefix(db.DATATYPE_USERDATA)
|
||||
k := []byte("foo")
|
||||
err = tdb.Put(ctx, k, []byte("bar"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if tdb.Stale(ctx, db.DATATYPE_USERDATA, "", k) {
|
||||
t.Fatal("expected not stale")
|
||||
}
|
||||
time.Sleep(time.Millisecond)
|
||||
if !tdb.Stale(ctx, db.DATATYPE_USERDATA, "", k) {
|
||||
t.Fatal("expected stale")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilteredStaleDb(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mdb := memdb.NewMemDb()
|
||||
err := mdb.Connect(ctx, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
k := []byte("foo")
|
||||
tdb := NewTimedDb(mdb, time.Duration(time.Millisecond))
|
||||
tdb = tdb.WithMatch(db.DATATYPE_STATE, []byte("fo"))
|
||||
tdb.SetPrefix(db.DATATYPE_USERDATA)
|
||||
tdb.SetSession("inky")
|
||||
err = tdb.Put(ctx, k, []byte("bar"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tdb.SetPrefix(db.DATATYPE_STATE)
|
||||
tdb.SetSession("inky")
|
||||
err = tdb.Put(ctx, k, []byte("pinky"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tdb.SetSession("blinky")
|
||||
err = tdb.Put(ctx, k, []byte("clyde"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if tdb.Stale(ctx, db.DATATYPE_USERDATA, "inky", k) {
|
||||
t.Fatal("expected not stale")
|
||||
}
|
||||
if tdb.Stale(ctx, db.DATATYPE_STATE, "inky", k) {
|
||||
t.Fatal("expected not stale")
|
||||
}
|
||||
if tdb.Stale(ctx, db.DATATYPE_STATE, "blinky", k) {
|
||||
t.Fatal("expected not stale")
|
||||
}
|
||||
time.Sleep(time.Millisecond)
|
||||
if tdb.Stale(ctx, db.DATATYPE_USERDATA, "inky", k) {
|
||||
t.Fatal("expected not stale")
|
||||
}
|
||||
if !tdb.Stale(ctx, db.DATATYPE_STATE, "inky", k) {
|
||||
t.Fatal("expected stale")
|
||||
}
|
||||
if tdb.Stale(ctx, db.DATATYPE_STATE, "blinky", k) {
|
||||
t.Fatal("expected not stale")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilteredSameKeypartStaleDb(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mdb := memdb.NewMemDb()
|
||||
err := mdb.Connect(ctx, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tdb := NewTimedDb(mdb, time.Duration(time.Millisecond))
|
||||
tdb = tdb.WithMatch(db.DATATYPE_USERDATA, []byte("ba"))
|
||||
tdb.SetPrefix(db.DATATYPE_USERDATA)
|
||||
tdb.SetSession("xyzzy")
|
||||
err = tdb.Put(ctx, []byte("bar"), []byte("inky"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tdb.SetPrefix(db.DATATYPE_USERDATA)
|
||||
tdb.SetSession("xyzzy")
|
||||
err = tdb.Put(ctx, []byte("baz"), []byte("pinky"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tdb.SetPrefix(db.DATATYPE_USERDATA)
|
||||
tdb.SetSession("xyzzy")
|
||||
err = tdb.Put(ctx, []byte("foo"), []byte("blinky"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
time.Sleep(time.Millisecond)
|
||||
if !tdb.Stale(ctx, db.DATATYPE_USERDATA, "xyzzy", []byte("bar")) {
|
||||
t.Fatal("expected stale")
|
||||
}
|
||||
if !tdb.Stale(ctx, db.DATATYPE_USERDATA, "xyzzy", []byte("baz")) {
|
||||
t.Fatal("expected stale")
|
||||
}
|
||||
if tdb.Stale(ctx, db.DATATYPE_USERDATA, "xyzzy", []byte("foo")) {
|
||||
t.Fatal("expected not stale")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user