diff --git a/src/_pytest/mark/expression.py b/src/_pytest/mark/expression.py index 008192d4a..04c73411a 100644 --- a/src/_pytest/mark/expression.py +++ b/src/_pytest/mark/expression.py @@ -15,10 +15,13 @@ The semantics are: - ident evaluates to True of False according to a provided matcher function. - or/and/not evaluate according to the usual boolean semantics. """ +import ast import enum import re +import types from typing import Callable from typing import Iterator +from typing import Mapping from typing import Optional from typing import Sequence @@ -31,7 +34,7 @@ if TYPE_CHECKING: __all__ = [ - "evaluate", + "Expression", "ParseError", ] @@ -124,50 +127,92 @@ class Scanner: ) -def expression(s: Scanner, matcher: Callable[[str], bool]) -> bool: +def expression(s: Scanner) -> ast.Expression: if s.accept(TokenType.EOF): - return False - ret = expr(s, matcher) - s.accept(TokenType.EOF, reject=True) - return ret + ret = ast.NameConstant(False) # type: ast.expr + else: + ret = expr(s) + s.accept(TokenType.EOF, reject=True) + return ast.fix_missing_locations(ast.Expression(ret)) -def expr(s: Scanner, matcher: Callable[[str], bool]) -> bool: - ret = and_expr(s, matcher) +def expr(s: Scanner) -> ast.expr: + ret = and_expr(s) while s.accept(TokenType.OR): - rhs = and_expr(s, matcher) - ret = ret or rhs + rhs = and_expr(s) + ret = ast.BoolOp(ast.Or(), [ret, rhs]) return ret -def and_expr(s: Scanner, matcher: Callable[[str], bool]) -> bool: - ret = not_expr(s, matcher) +def and_expr(s: Scanner) -> ast.expr: + ret = not_expr(s) while s.accept(TokenType.AND): - rhs = not_expr(s, matcher) - ret = ret and rhs + rhs = not_expr(s) + ret = ast.BoolOp(ast.And(), [ret, rhs]) return ret -def not_expr(s: Scanner, matcher: Callable[[str], bool]) -> bool: +def not_expr(s: Scanner) -> ast.expr: if s.accept(TokenType.NOT): - return not not_expr(s, matcher) + return ast.UnaryOp(ast.Not(), not_expr(s)) if s.accept(TokenType.LPAREN): - ret = expr(s, matcher) + ret = expr(s) s.accept(TokenType.RPAREN, reject=True) return ret ident = s.accept(TokenType.IDENT) if ident: - return matcher(ident.value) + return ast.Name(ident.value, ast.Load()) s.reject((TokenType.NOT, TokenType.LPAREN, TokenType.IDENT)) -def evaluate(input: str, matcher: Callable[[str], bool]) -> bool: - """Evaluate a match expression as used by -k and -m. +class MatcherAdapter(Mapping[str, bool]): + """Adapts a matcher function to a locals mapping as required by eval().""" - :param input: The input expression - one line. - :param matcher: Given an identifier, should return whether it matches or not. - Should be prepared to handle arbitrary strings as input. + def __init__(self, matcher: Callable[[str], bool]) -> None: + self.matcher = matcher - Returns whether the entire expression matches or not. + def __getitem__(self, key: str) -> bool: + return self.matcher(key) + + def __iter__(self) -> Iterator[str]: + raise NotImplementedError() + + def __len__(self) -> int: + raise NotImplementedError() + + +class Expression: + """A compiled match expression as used by -k and -m. + + The expression can be evaulated against different matchers. """ - return expression(Scanner(input), matcher) + + __slots__ = ("code",) + + def __init__(self, code: types.CodeType) -> None: + self.code = code + + @classmethod + def compile(self, input: str) -> "Expression": + """Compile a match expression. + + :param input: The input expression - one line. + """ + astexpr = expression(Scanner(input)) + code = compile( + astexpr, filename="", mode="eval", + ) # type: types.CodeType + return Expression(code) + + def evaluate(self, matcher: Callable[[str], bool]) -> bool: + """Evaluate the match expression. + + :param matcher: Given an identifier, should return whether it matches or not. + Should be prepared to handle arbitrary strings as input. + + Returns whether the expression matches or not. + """ + ret = eval( + self.code, {"__builtins__": {}}, MatcherAdapter(matcher) + ) # type: bool + return ret diff --git a/src/_pytest/mark/legacy.py b/src/_pytest/mark/legacy.py index 1a9fdee8d..ed707fcc7 100644 --- a/src/_pytest/mark/legacy.py +++ b/src/_pytest/mark/legacy.py @@ -8,7 +8,7 @@ import attr from _pytest.compat import TYPE_CHECKING from _pytest.config import UsageError -from _pytest.mark.expression import evaluate +from _pytest.mark.expression import Expression from _pytest.mark.expression import ParseError if TYPE_CHECKING: @@ -77,11 +77,12 @@ class KeywordMatcher: def matchmark(colitem, markexpr: str) -> bool: """Tries to match on any marker names, attached to the given colitem.""" try: - return evaluate(markexpr, MarkMatcher.from_item(colitem)) + expression = Expression.compile(markexpr) except ParseError as e: raise UsageError( "Wrong expression passed to '-m': {}: {}".format(markexpr, e) ) from None + return expression.evaluate(MarkMatcher.from_item(colitem)) def matchkeyword(colitem, keywordexpr: str) -> bool: @@ -94,8 +95,9 @@ def matchkeyword(colitem, keywordexpr: str) -> bool: any item, as well as names directly assigned to test functions. """ try: - return evaluate(keywordexpr, KeywordMatcher.from_item(colitem)) + expression = Expression.compile(keywordexpr) except ParseError as e: raise UsageError( "Wrong expression passed to '-k': {}: {}".format(keywordexpr, e) ) from None + return expression.evaluate(KeywordMatcher.from_item(colitem)) diff --git a/testing/test_mark_expression.py b/testing/test_mark_expression.py index 74576786d..335888618 100644 --- a/testing/test_mark_expression.py +++ b/testing/test_mark_expression.py @@ -1,8 +1,14 @@ +from typing import Callable + import pytest -from _pytest.mark.expression import evaluate +from _pytest.mark.expression import Expression from _pytest.mark.expression import ParseError +def evaluate(input: str, matcher: Callable[[str], bool]) -> bool: + return Expression.compile(input).evaluate(matcher) + + def test_empty_is_false() -> None: assert not evaluate("", lambda ident: False) assert not evaluate("", lambda ident: True)