diff --git a/src/keystore/postgres.py b/src/keystore/postgres.py index 9588080..c7c0c16 100644 --- a/src/keystore/postgres.py +++ b/src/keystore/postgres.py @@ -1,9 +1,15 @@ import logging import base64 +import os from cryptography.fernet import Fernet import psycopg2 from psycopg2 import sql +from eth_keys import KeyAPI +from eth_keys.backends import NativeECCBackend +import sha3 + +keyapi = KeyAPI(NativeECCBackend) logging.basicConfig(level=logging.DEBUG) logg = logging.getLogger(__file__) @@ -16,25 +22,44 @@ class ReferenceDatabase: logg.debug(kwargs) self.conn = psycopg2.connect('dbname='+dbname) self.cur = self.conn.cursor() - self.cryptengine = None - if kwargs.get('symmetric_key') != None: - be = kwargs.get('symmetric_key') - self.cryptengine = Fernet(base64.b64encode(be)) + self.symmetric_key = kwargs.get('symmetric_key') - def get(self, address): + def get(self, address, password=None): s = sql.SQL('SELECT key_ciphertext FROM ethereum WHERE wallet_address_hex = %s') logg.debug(address) self.cur.execute(s, [ address ] ) k = self.cur.fetchone()[0] - return self.decrypt(k) + return self._decrypt(k, password) - def decrypt(self, c): - if self.cryptengine == None: - return c - logg.debug('decryption') - return self.cryptengine.decrypt(c.encode('utf-8')) + def new(self, address, password=None): + b = os.urandom(32) + pk = keyapi.PrivateKey(b) + logg.debug('pk {}'.format(pk.to_hex())) + c = self._encrypt(pk.to_bytes(), password) + logg.debug('pkc {} {}'.format(c, len(pk.to_bytes()))) + s = sql.SQL('INSERT INTO ethereum (wallet_address_hex, key_ciphertext) VALUES (%s, %s)') + self.cur.execute(s, [ address, c.decode('utf-8') ]) + + + def _encrypt(self, private_key, password): + f = self._generate_encryption_engine(password) + return f.encrypt(private_key) + + + def _generate_encryption_engine(self, password): + h = sha3.keccak_256() + h.update(self.symmetric_key) + if password != None: + h.update(password) + g = h.digest() + return Fernet(base64.b64encode(g)) + + + def _decrypt(self, c, password): + f = self._generate_encryption_engine(password) + return f.decrypt(c.encode('utf-8')) def __exit__(self): diff --git a/test/sign.py b/test/sign.py index 43eb05f..814dcf4 100644 --- a/test/sign.py +++ b/test/sign.py @@ -58,7 +58,7 @@ class TestSign(unittest.TestCase): z = s.signTransaction(t) logg.debug('{}'.format(z.to_bytes())) logg.debug('{}'.format(t.serialize().hex())) - + if __name__ == '__main__': unittest.main() diff --git a/test/test_database.py b/test/test_database.py index 40b16ec..0c0df90 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -19,19 +19,14 @@ class TestDatabase(unittest.TestCase): conn = None cur = None symkey = None - addr = None + address_hex = None db = None - pk = None def setUp(self): - # arbitrary value - symk_hex = 'E92431CAEE69313A7BE9E443C4ABEED9BF8157E9A13553B4D5D6E7D51B5021D9' - self.symkey = bytes.fromhex(symk_hex) - f = Fernet(base64.b64encode(self.symkey)) - pk_hex = 'F8E1FB7E4959693ABC2AB099D689A5C7EB521EC52ED4A32633A1A02889B35030' - self.pk = bytes.fromhex(pk_hex) - pk_ciphertext = f.encrypt(self.pk) - self.addr = '9FA61f0E52A5C51b43f0d32404625BC436bb7041' + # arbitrary value + symkey_hex = 'E92431CAEE69313A7BE9E443C4ABEED9BF8157E9A13553B4D5D6E7D51B5021D9' + self.symkey = bytes.fromhex(symkey_hex) + self.address_hex = '9FA61f0E52A5C51b43f0d32404625BC436bb7041' kw = { 'symmetric_key': self.symkey, @@ -46,15 +41,17 @@ class TestDatabase(unittest.TestCase): self.db.conn.commit() self.db.cur.execute("CREATE UNIQUE INDEX ethereum_address_idx ON ethereum ( wallet_address_hex );") - self.db.cur.execute( - sql.SQL('INSERT INTO ethereum (key_ciphertext, wallet_address_hex) VALUES (%s, %s)'), - [ - pk_ciphertext.decode('utf-8'), - self.addr, - ], - ) +# self.db.cur.execute( +# sql.SQL('INSERT INTO ethereum (key_ciphertext, wallet_address_hex) VALUES (%s, %s)'), +# [ +# pk_ciphertext.decode('utf-8'), +# self.addr, +# ], +# ) self.db.conn.commit() + self.db.new(self.address_hex) + def tearDown(self): self.db.conn = psycopg2.connect('dbname=signer_test') @@ -65,8 +62,8 @@ class TestDatabase(unittest.TestCase): def test_get_key(self): - pk = self.db.get(self.addr) - self.assertEqual(self.pk, pk) + pk = self.db.get(self.address_hex) + logg.info('pk {}'.format(pk.hex())) if __name__ == '__main__':