Add wrong password test

This commit is contained in:
nolash 2020-08-05 19:54:19 +02:00
parent a845aecda1
commit fa7b3ca774
Signed by: lash
GPG Key ID: 93EC1C676274C889
2 changed files with 4 additions and 7 deletions

View File

@ -24,7 +24,6 @@ class ReferenceDatabase:
def __init__(self, dbname, **kwargs): def __init__(self, dbname, **kwargs):
logg.debug(kwargs)
self.conn = psycopg2.connect('dbname='+dbname) self.conn = psycopg2.connect('dbname='+dbname)
self.cur = self.conn.cursor() self.cur = self.conn.cursor()
self.symmetric_key = kwargs.get('symmetric_key') self.symmetric_key = kwargs.get('symmetric_key')
@ -32,7 +31,6 @@ class ReferenceDatabase:
def get(self, address, password=None): def get(self, address, password=None):
s = sql.SQL('SELECT key_ciphertext FROM ethereum WHERE wallet_address_hex = %s') s = sql.SQL('SELECT key_ciphertext FROM ethereum WHERE wallet_address_hex = %s')
logg.debug(address)
self.cur.execute(s, [ address ] ) self.cur.execute(s, [ address ] )
k = self.cur.fetchone()[0] k = self.cur.fetchone()[0]
return self._decrypt(k, password) return self._decrypt(k, password)
@ -41,9 +39,7 @@ class ReferenceDatabase:
def new(self, address, password=None): def new(self, address, password=None):
b = os.urandom(32) b = os.urandom(32)
pk = keyapi.PrivateKey(b) pk = keyapi.PrivateKey(b)
logg.debug('pk {}'.format(pk.to_hex()))
c = self._encrypt(pk.to_bytes(), password) c = self._encrypt(pk.to_bytes(), password)
logg.debug('pkc {} {}'.format(c, len(pk.to_bytes())))
s = sql.SQL('INSERT INTO ethereum (wallet_address_hex, key_ciphertext) VALUES (%s, %s)') s = sql.SQL('INSERT INTO ethereum (wallet_address_hex, key_ciphertext) VALUES (%s, %s)')
self.cur.execute(s, [ address, c.decode('utf-8') ]) self.cur.execute(s, [ address, c.decode('utf-8') ])

View File

@ -6,7 +6,7 @@ import base64
import psycopg2 import psycopg2
from psycopg2 import sql from psycopg2 import sql
from cryptography.fernet import Fernet from cryptography.fernet import Fernet, InvalidToken
from keystore import ReferenceDatabase from keystore import ReferenceDatabase
@ -52,8 +52,9 @@ class TestDatabase(unittest.TestCase):
def test_get_key(self): def test_get_key(self):
pk = self.db.get(self.address_hex, 'foo') self.db.get(self.address_hex, 'foo')
logg.info('pk {}'.format(pk.hex())) with self.assertRaises(InvalidToken):
self.db.get(self.address_hex, 'bar')
if __name__ == '__main__': if __name__ == '__main__':