mark/expression: support compiling once and reusing for multiple evaluations

In current pytest, the same expression is matched against every item, but
it is re-parsed for each match.

Add support for "compiling" an expression once and reusing the result.
Parse errors can only occur during compilation, not during evaluation.

This is done by parsing the expression into a Python `ast.Expression`,
then `compile()`ing it into a code object. Evaluation is then done using
`eval()`.

Note: historically, we used `eval` directly on the user input. That is
not the case here: the expression is entirely under our control,
according to our grammar; we just JIT-compile it to Python as a
(completely safe) optimization.
This commit is contained in:
Ran Benita 2020-05-11 11:20:43 +03:00
parent 952762207a
commit 622c4ce02e
3 changed files with 82 additions and 29 deletions

View File

@ -15,10 +15,13 @@ The semantics are:
- ident evaluates to True or False according to a provided matcher function. - ident evaluates to True or False according to a provided matcher function.
- or/and/not evaluate according to the usual boolean semantics. - or/and/not evaluate according to the usual boolean semantics.
""" """
import ast
import enum import enum
import re import re
import types
from typing import Callable from typing import Callable
from typing import Iterator from typing import Iterator
from typing import Mapping
from typing import Optional from typing import Optional
from typing import Sequence from typing import Sequence
@ -31,7 +34,7 @@ if TYPE_CHECKING:
__all__ = [ __all__ = [
"evaluate", "Expression",
"ParseError", "ParseError",
] ]
@ -124,50 +127,92 @@ class Scanner:
) )
def expression(s: Scanner, matcher: Callable[[str], bool]) -> bool: def expression(s: Scanner) -> ast.Expression:
if s.accept(TokenType.EOF): if s.accept(TokenType.EOF):
return False ret = ast.NameConstant(False) # type: ast.expr
ret = expr(s, matcher) else:
s.accept(TokenType.EOF, reject=True) ret = expr(s)
return ret s.accept(TokenType.EOF, reject=True)
return ast.fix_missing_locations(ast.Expression(ret))
def expr(s: Scanner, matcher: Callable[[str], bool]) -> bool: def expr(s: Scanner) -> ast.expr:
ret = and_expr(s, matcher) ret = and_expr(s)
while s.accept(TokenType.OR): while s.accept(TokenType.OR):
rhs = and_expr(s, matcher) rhs = and_expr(s)
ret = ret or rhs ret = ast.BoolOp(ast.Or(), [ret, rhs])
return ret return ret
def and_expr(s: Scanner, matcher: Callable[[str], bool]) -> bool: def and_expr(s: Scanner) -> ast.expr:
ret = not_expr(s, matcher) ret = not_expr(s)
while s.accept(TokenType.AND): while s.accept(TokenType.AND):
rhs = not_expr(s, matcher) rhs = not_expr(s)
ret = ret and rhs ret = ast.BoolOp(ast.And(), [ret, rhs])
return ret return ret
def not_expr(s: Scanner, matcher: Callable[[str], bool]) -> bool: def not_expr(s: Scanner) -> ast.expr:
if s.accept(TokenType.NOT): if s.accept(TokenType.NOT):
return not not_expr(s, matcher) return ast.UnaryOp(ast.Not(), not_expr(s))
if s.accept(TokenType.LPAREN): if s.accept(TokenType.LPAREN):
ret = expr(s, matcher) ret = expr(s)
s.accept(TokenType.RPAREN, reject=True) s.accept(TokenType.RPAREN, reject=True)
return ret return ret
ident = s.accept(TokenType.IDENT) ident = s.accept(TokenType.IDENT)
if ident: if ident:
return matcher(ident.value) return ast.Name(ident.value, ast.Load())
s.reject((TokenType.NOT, TokenType.LPAREN, TokenType.IDENT)) s.reject((TokenType.NOT, TokenType.LPAREN, TokenType.IDENT))
def evaluate(input: str, matcher: Callable[[str], bool]) -> bool: class MatcherAdapter(Mapping[str, bool]):
"""Evaluate a match expression as used by -k and -m. """Adapts a matcher function to a locals mapping as required by eval()."""
:param input: The input expression - one line. def __init__(self, matcher: Callable[[str], bool]) -> None:
:param matcher: Given an identifier, should return whether it matches or not. self.matcher = matcher
Should be prepared to handle arbitrary strings as input.
Returns whether the entire expression matches or not. def __getitem__(self, key: str) -> bool:
return self.matcher(key)
def __iter__(self) -> Iterator[str]:
raise NotImplementedError()
def __len__(self) -> int:
raise NotImplementedError()
class Expression:
"""A compiled match expression as used by -k and -m.
The expression can be evaluated against different matchers.
""" """
return expression(Scanner(input), matcher)
__slots__ = ("code",)
def __init__(self, code: types.CodeType) -> None:
self.code = code
@classmethod
def compile(self, input: str) -> "Expression":
"""Compile a match expression.
:param input: The input expression - one line.
"""
astexpr = expression(Scanner(input))
code = compile(
astexpr, filename="<pytest match expression>", mode="eval",
) # type: types.CodeType
return Expression(code)
def evaluate(self, matcher: Callable[[str], bool]) -> bool:
"""Evaluate the match expression.
:param matcher: Given an identifier, should return whether it matches or not.
Should be prepared to handle arbitrary strings as input.
Returns whether the expression matches or not.
"""
ret = eval(
self.code, {"__builtins__": {}}, MatcherAdapter(matcher)
) # type: bool
return ret

View File

@ -8,7 +8,7 @@ import attr
from _pytest.compat import TYPE_CHECKING from _pytest.compat import TYPE_CHECKING
from _pytest.config import UsageError from _pytest.config import UsageError
from _pytest.mark.expression import evaluate from _pytest.mark.expression import Expression
from _pytest.mark.expression import ParseError from _pytest.mark.expression import ParseError
if TYPE_CHECKING: if TYPE_CHECKING:
@ -77,11 +77,12 @@ class KeywordMatcher:
def matchmark(colitem, markexpr: str) -> bool: def matchmark(colitem, markexpr: str) -> bool:
"""Tries to match on any marker names, attached to the given colitem.""" """Tries to match on any marker names, attached to the given colitem."""
try: try:
return evaluate(markexpr, MarkMatcher.from_item(colitem)) expression = Expression.compile(markexpr)
except ParseError as e: except ParseError as e:
raise UsageError( raise UsageError(
"Wrong expression passed to '-m': {}: {}".format(markexpr, e) "Wrong expression passed to '-m': {}: {}".format(markexpr, e)
) from None ) from None
return expression.evaluate(MarkMatcher.from_item(colitem))
def matchkeyword(colitem, keywordexpr: str) -> bool: def matchkeyword(colitem, keywordexpr: str) -> bool:
@ -94,8 +95,9 @@ def matchkeyword(colitem, keywordexpr: str) -> bool:
any item, as well as names directly assigned to test functions. any item, as well as names directly assigned to test functions.
""" """
try: try:
return evaluate(keywordexpr, KeywordMatcher.from_item(colitem)) expression = Expression.compile(keywordexpr)
except ParseError as e: except ParseError as e:
raise UsageError( raise UsageError(
"Wrong expression passed to '-k': {}: {}".format(keywordexpr, e) "Wrong expression passed to '-k': {}: {}".format(keywordexpr, e)
) from None ) from None
return expression.evaluate(KeywordMatcher.from_item(colitem))

View File

@ -1,8 +1,14 @@
from typing import Callable
import pytest import pytest
from _pytest.mark.expression import evaluate from _pytest.mark.expression import Expression
from _pytest.mark.expression import ParseError from _pytest.mark.expression import ParseError
def evaluate(input: str, matcher: Callable[[str], bool]) -> bool:
return Expression.compile(input).evaluate(matcher)
def test_empty_is_false() -> None: def test_empty_is_false() -> None:
assert not evaluate("", lambda ident: False) assert not evaluate("", lambda ident: False)
assert not evaluate("", lambda ident: True) assert not evaluate("", lambda ident: True)