From 636b60f6f6791cc18ba91da4cdecbe3aabc3dca0 Mon Sep 17 00:00:00 2001 From: nolash Date: Wed, 5 Aug 2020 18:14:25 +0200 Subject: [PATCH] Add postgres with encryption --- src/keystore/__init__.py | 1 + src/keystore/postgres.py | 43 ++++++++++++++++++++++ src/signer/defaultsigner.py | 9 +++-- test/test_database.py | 73 +++++++++++++++++++++++++++++++++++++ 4 files changed, 122 insertions(+), 4 deletions(-) create mode 100644 src/keystore/__init__.py create mode 100644 src/keystore/postgres.py create mode 100644 test/test_database.py diff --git a/src/keystore/__init__.py b/src/keystore/__init__.py new file mode 100644 index 0000000..e11b275 --- /dev/null +++ b/src/keystore/__init__.py @@ -0,0 +1 @@ +from keystore.postgres import ReferenceDatabase diff --git a/src/keystore/postgres.py b/src/keystore/postgres.py new file mode 100644 index 0000000..9588080 --- /dev/null +++ b/src/keystore/postgres.py @@ -0,0 +1,43 @@ +import logging +import base64 + +from cryptography.fernet import Fernet +import psycopg2 +from psycopg2 import sql + +logging.basicConfig(level=logging.DEBUG) +logg = logging.getLogger(__file__) + + +class ReferenceDatabase: + + + def __init__(self, dbname, **kwargs): + logg.debug(kwargs) + self.conn = psycopg2.connect('dbname='+dbname) + self.cur = self.conn.cursor() + self.cryptengine = None + if kwargs.get('symmetric_key') != None: + be = kwargs.get('symmetric_key') + self.cryptengine = Fernet(base64.b64encode(be)) + + + def get(self, address): + s = sql.SQL('SELECT key_ciphertext FROM ethereum WHERE wallet_address_hex = %s') + logg.debug(address) + self.cur.execute(s, [ address ] ) + k = self.cur.fetchone()[0] + return self.decrypt(k) + + + def decrypt(self, c): + if self.cryptengine == None: + return c + logg.debug('decryption') + return self.cryptengine.decrypt(c.encode('utf-8')) + + + def __exit__(self): + self.conn + self.cur.close() + self.conn.close() diff --git a/src/signer/defaultsigner.py b/src/signer/defaultsigner.py index 304c268..e4e886a 100644 --- a/src/signer/defaultsigner.py +++ b/src/signer/defaultsigner.py @@ -10,15 +10,19 @@ logg = logging.getLogger(__name__) class Signer: + def __init__(self, keyGetter): self.keyGetter = keyGetter + def signTransaction(self, tx): raise NotImplementedError + class ReferenceSigner(Signer): - + + def __init__(self, keyGetter): super(ReferenceSigner, self).__init__(keyGetter) @@ -34,6 +38,3 @@ class ReferenceSigner(Signer): tx.r = z[:32] tx.s = z[32:64] return z - - - diff --git a/test/test_database.py b/test/test_database.py new file mode 100644 index 0000000..40b16ec --- /dev/null +++ b/test/test_database.py @@ -0,0 +1,73 @@ +#!/usr/bin/python + +import unittest +import logging +import base64 + +import psycopg2 +from psycopg2 import sql +from cryptography.fernet import Fernet + +from keystore import ReferenceDatabase + +logging.basicConfig(level=logging.DEBUG) +logg = logging.getLogger() + + +class TestDatabase(unittest.TestCase): + + conn = None + cur = None + symkey = None + addr = None + db = None + pk = None + + def setUp(self): + # arbitrary value + symk_hex = 'E92431CAEE69313A7BE9E443C4ABEED9BF8157E9A13553B4D5D6E7D51B5021D9' + self.symkey = bytes.fromhex(symk_hex) + f = Fernet(base64.b64encode(self.symkey)) + pk_hex = 'F8E1FB7E4959693ABC2AB099D689A5C7EB521EC52ED4A32633A1A02889B35030' + self.pk = bytes.fromhex(pk_hex) + pk_ciphertext = f.encrypt(self.pk) + self.addr = '9FA61f0E52A5C51b43f0d32404625BC436bb7041' + + kw = { + 'symmetric_key': self.symkey, + } + self.db = ReferenceDatabase('signer_test', **kw) + self.db.cur.execute("""CREATE TABLE ethereum ( + id SERIAL NOT NULL PRIMARY KEY, + key_ciphertext VARCHAR(256) NOT NULL, + wallet_address_hex CHAR(40) NOT NULL + ); +""") + self.db.conn.commit() + self.db.cur.execute("CREATE UNIQUE INDEX ethereum_address_idx ON ethereum ( wallet_address_hex );") + + self.db.cur.execute( + sql.SQL('INSERT INTO ethereum (key_ciphertext, wallet_address_hex) VALUES (%s, %s)'), + [ + pk_ciphertext.decode('utf-8'), + self.addr, + ], + ) + self.db.conn.commit() + + + def tearDown(self): + self.db.conn = psycopg2.connect('dbname=signer_test') + self.db.cur = self.db.conn.cursor() + self.db.cur.execute('DROP INDEX ethereum_address_idx;') + self.db.cur.execute('DROP TABLE ethereum;') + self.db.conn.commit() + + + def test_get_key(self): + pk = self.db.get(self.addr) + self.assertEqual(self.pk, pk) + + +if __name__ == '__main__': + unittest.main()