From fa694c957b19ca1a0b507f6c082d8f7ccd8276b0 Mon Sep 17 00:00:00 2001 From: lash Date: Thu, 17 Aug 2023 10:28:27 +0100 Subject: [PATCH] Start filter test writing --- eth_monitor/rules.py | 24 +++++++++++++++++++----- tests/test_filter.py | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 5 deletions(-) create mode 100644 tests/test_filter.py diff --git a/eth_monitor/rules.py b/eth_monitor/rules.py index 2423154..09074a4 100644 --- a/eth_monitor/rules.py +++ b/eth_monitor/rules.py @@ -71,16 +71,19 @@ class RuleMethod: class RuleSimple: - def __init__(self, outputs, inputs, executables, description=None): + def __init__(self, outputs, inputs, executables, description=None, match_all=False): self.description = description if self.description == None: self.description = str(uuid.uuid4()) self.outputs = outputs self.inputs = inputs self.executables = executables + self.match_all = match_all - + def check(self, sender, recipient, data, tx_hash): + have_fail = False + have_match = False for rule in self.outputs: if rule != None and is_same_address(sender, rule): logg.debug('tx {} rule {} match in SENDER {}'.format(tx_hash, self.description, sender)) @@ -107,10 +110,11 @@ class RuleSimple: class AddressRules: - def __init__(self, include_by_default=False): + def __init__(self, include_by_default=False, match_all=False): self.excludes = [] self.includes = [] self.include_by_default = include_by_default + self.match_all = match_all def exclude(self, rule): @@ -130,17 +134,27 @@ class AddressRules: # TODO: rename def apply_rules_addresses(self, sender, recipient, data, tx_hash): v = self.include_by_default + have_fail = False + have_match = False for rule in self.includes: if rule.check(sender, recipient, data, tx_hash): v = True logg.info('match in includes rule: {}'.format(rule)) + if not self.match_all: + break + elif self.match_all: + v = False break + if not v: + return v + for rule in self.excludes: if rule.check(sender, recipient, data, tx_hash): v = False logg.info('match in excludes rule: {}'.format(rule)) - break - + if not self.match_all: + break + return v diff --git a/tests/test_filter.py b/tests/test_filter.py new file mode 100644 index 0000000..b28c84e --- /dev/null +++ b/tests/test_filter.py @@ -0,0 +1,40 @@ +# standard imports +import logging +import unittest +import os + +# local imports +from eth_monitor.rules import * + +logging.basicConfig(level=logging.DEBUG) +logg = logging.getLogger() + + +class TestRule(unittest.TestCase): + + def setUp(self): + self.alice = os.urandom(20).hex() + self.bob = os.urandom(20).hex() + self.carol = os.urandom(20).hex() + self.dave = os.urandom(20).hex() + self.x = os.urandom(20).hex() + self.y = os.urandom(20).hex() + self.hsh = os.urandom(32).hex() + + + def test_address_include(self): + outs = [self.alice] + ins = [] + execs = [] + rule = RuleSimple(outs, ins, execs) + c = AddressRules() + c.include(rule) + data = b'' + r = c.apply_rules_addresses(self.alice, self.bob, data, self.hsh) + self.assertTrue(r) + r = c.apply_rules_addresses(self.bob, self.alice, data, self.hsh) + self.assertFalse(r) + + +if __name__ == '__main__': + unittest.main()