Skip to content

Commit

Permalink
treematcher: Optimization when checking for safer mode, and test it.
Browse files Browse the repository at this point in the history
  • Loading branch information
jordibc committed Nov 9, 2023
1 parent 9e3c7ff commit 1017ba1
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
5 changes: 3 additions & 2 deletions ete4/treematcher/treematcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def __init__(self, pattern='', children=None, parser=None, safer=False):
# Add the "code" property with its compiled condition.
self.props['code'] = compile(self.name or 'True', '<string>', 'eval')

self.safer = safer # will use to know if to use eval or safer_eval
for node in self.traverse(): # after init, needs to go to every node
node.safer = safer # will use to know if to use eval or safer_eval

def __str__(self):
return self.to_str(show_internal=True, props=['name'])
Expand Down Expand Up @@ -63,7 +64,7 @@ def match(pattern, node):
'any': any, 'all': all, 'len': len,
'sum': sum, 'abs': abs, 'float': float}

evaluate = safer_eval if pattern.root.safer else eval # risky business
evaluate = safer_eval if pattern.safer else eval # risky business
if not evaluate(pattern.props['code'], context):
return False # no match if the condition for this node if false

Expand Down
17 changes: 16 additions & 1 deletion tests/test_treematcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
Tests related to the treematcher module.
"""

from ete4 import Tree
from ete4 import Tree, PhyloTree
import ete4.treematcher as tm

import pytest


def strip(text):
"""Return the given text stripping the empty lines and indentation."""
Expand Down Expand Up @@ -62,3 +64,16 @@ def test_search():
assert ([n.name for n in tm.search(pattern, tree)] ==
[n.name for n in pattern.search(tree)] ==
expected_result)


def test_safer():
t = PhyloTree('(a,(b,c));')

tp_unsafe = tm.TreePattern('("node.get_species()=={\'c\'}",'
' node.species=="b")')
assert list(tp_unsafe.search(t)) == [t.common_ancestor(['b', 'c'])]

tp_safer = tm.TreePattern('("node.get_species()=={\'c\'}",'
' node.species=="b")', safer=True)
with pytest.raises(ValueError):
list(tp_safer.search(t)) # asked for unknown function get_species()

0 comments on commit 1017ba1

Please sign in to comment.