mark/expression: support compiling once and reusing for multiple evaluations

In current pytest, the same expression is matched against all items. But
it is re-parsed for every match.

Add support for "compiling" an expression and reusing the result. Errors
may only occur during compilation.

This is done by parsing the expression into a Python `ast.Expression`,
then `compile()`ing it into a code object. Evaluation is then done using
`eval()`.

Note: historically we used to use `eval` directly on the user input --
this is not the case here, the expression is entirely under our control
according to our grammar, we just JIT-compile it to Python as a
(completely safe) optimization.
This commit is contained in:
Ran Benita 2020-05-11 11:20:43 +03:00
parent 952762207a
commit 622c4ce02e
3 changed files with 82 additions and 29 deletions

View File

@ -15,10 +15,13 @@ The semantics are:
- ident evaluates to True of False according to a provided matcher function.
- or/and/not evaluate according to the usual boolean semantics.
"""
import ast
import enum
import re
import types
from typing import Callable
from typing import Iterator
from typing import Mapping
from typing import Optional
from typing import Sequence
@ -31,7 +34,7 @@ if TYPE_CHECKING:
__all__ = [
"evaluate",
"Expression",
"ParseError",
]
@ -124,50 +127,92 @@ class Scanner:
)
def expression(s: Scanner, matcher: Callable[[str], bool]) -> bool:
def expression(s: Scanner) -> ast.Expression:
if s.accept(TokenType.EOF):
return False
ret = expr(s, matcher)
s.accept(TokenType.EOF, reject=True)
return ret
ret = ast.NameConstant(False) # type: ast.expr
else:
ret = expr(s)
s.accept(TokenType.EOF, reject=True)
return ast.fix_missing_locations(ast.Expression(ret))
def expr(s: Scanner, matcher: Callable[[str], bool]) -> bool:
ret = and_expr(s, matcher)
def expr(s: Scanner) -> ast.expr:
ret = and_expr(s)
while s.accept(TokenType.OR):
rhs = and_expr(s, matcher)
ret = ret or rhs
rhs = and_expr(s)
ret = ast.BoolOp(ast.Or(), [ret, rhs])
return ret
def and_expr(s: Scanner, matcher: Callable[[str], bool]) -> bool:
ret = not_expr(s, matcher)
def and_expr(s: Scanner) -> ast.expr:
ret = not_expr(s)
while s.accept(TokenType.AND):
rhs = not_expr(s, matcher)
ret = ret and rhs
rhs = not_expr(s)
ret = ast.BoolOp(ast.And(), [ret, rhs])
return ret
def not_expr(s: Scanner, matcher: Callable[[str], bool]) -> bool:
def not_expr(s: Scanner) -> ast.expr:
if s.accept(TokenType.NOT):
return not not_expr(s, matcher)
return ast.UnaryOp(ast.Not(), not_expr(s))
if s.accept(TokenType.LPAREN):
ret = expr(s, matcher)
ret = expr(s)
s.accept(TokenType.RPAREN, reject=True)
return ret
ident = s.accept(TokenType.IDENT)
if ident:
return matcher(ident.value)
return ast.Name(ident.value, ast.Load())
s.reject((TokenType.NOT, TokenType.LPAREN, TokenType.IDENT))
def evaluate(input: str, matcher: Callable[[str], bool]) -> bool:
"""Evaluate a match expression as used by -k and -m.
class MatcherAdapter(Mapping[str, bool]):
"""Adapts a matcher function to a locals mapping as required by eval()."""
:param input: The input expression - one line.
:param matcher: Given an identifier, should return whether it matches or not.
Should be prepared to handle arbitrary strings as input.
def __init__(self, matcher: Callable[[str], bool]) -> None:
self.matcher = matcher
Returns whether the entire expression matches or not.
def __getitem__(self, key: str) -> bool:
return self.matcher(key)
def __iter__(self) -> Iterator[str]:
raise NotImplementedError()
def __len__(self) -> int:
raise NotImplementedError()
class Expression:
"""A compiled match expression as used by -k and -m.
The expression can be evaulated against different matchers.
"""
return expression(Scanner(input), matcher)
__slots__ = ("code",)
def __init__(self, code: types.CodeType) -> None:
self.code = code
@classmethod
def compile(self, input: str) -> "Expression":
"""Compile a match expression.
:param input: The input expression - one line.
"""
astexpr = expression(Scanner(input))
code = compile(
astexpr, filename="<pytest match expression>", mode="eval",
) # type: types.CodeType
return Expression(code)
def evaluate(self, matcher: Callable[[str], bool]) -> bool:
"""Evaluate the match expression.
:param matcher: Given an identifier, should return whether it matches or not.
Should be prepared to handle arbitrary strings as input.
Returns whether the expression matches or not.
"""
ret = eval(
self.code, {"__builtins__": {}}, MatcherAdapter(matcher)
) # type: bool
return ret

View File

@ -8,7 +8,7 @@ import attr
from _pytest.compat import TYPE_CHECKING
from _pytest.config import UsageError
from _pytest.mark.expression import evaluate
from _pytest.mark.expression import Expression
from _pytest.mark.expression import ParseError
if TYPE_CHECKING:
@ -77,11 +77,12 @@ class KeywordMatcher:
def matchmark(colitem, markexpr: str) -> bool:
"""Tries to match on any marker names, attached to the given colitem."""
try:
return evaluate(markexpr, MarkMatcher.from_item(colitem))
expression = Expression.compile(markexpr)
except ParseError as e:
raise UsageError(
"Wrong expression passed to '-m': {}: {}".format(markexpr, e)
) from None
return expression.evaluate(MarkMatcher.from_item(colitem))
def matchkeyword(colitem, keywordexpr: str) -> bool:
@ -94,8 +95,9 @@ def matchkeyword(colitem, keywordexpr: str) -> bool:
any item, as well as names directly assigned to test functions.
"""
try:
return evaluate(keywordexpr, KeywordMatcher.from_item(colitem))
expression = Expression.compile(keywordexpr)
except ParseError as e:
raise UsageError(
"Wrong expression passed to '-k': {}: {}".format(keywordexpr, e)
) from None
return expression.evaluate(KeywordMatcher.from_item(colitem))

View File

@ -1,8 +1,14 @@
from typing import Callable
import pytest
from _pytest.mark.expression import evaluate
from _pytest.mark.expression import Expression
from _pytest.mark.expression import ParseError
def evaluate(input: str, matcher: Callable[[str], bool]) -> bool:
return Expression.compile(input).evaluate(matcher)
def test_empty_is_false() -> None:
assert not evaluate("", lambda ident: False)
assert not evaluate("", lambda ident: True)