Type annotate _pytest.mark.evaluate

This commit is contained in:
Ran Benita 2020-05-01 14:40:16 +03:00
parent fc325bc0c3
commit 709bcbf3c4
1 changed files with 21 additions and 16 deletions

View File

@ -4,10 +4,14 @@ import sys
import traceback import traceback
from typing import Any from typing import Any
from typing import Dict from typing import Dict
from typing import List
from typing import Optional
from ..outcomes import fail from ..outcomes import fail
from ..outcomes import TEST_OUTCOME from ..outcomes import TEST_OUTCOME
from .structures import Mark
from _pytest.config import Config from _pytest.config import Config
from _pytest.nodes import Item
from _pytest.store import StoreKey from _pytest.store import StoreKey
@ -28,29 +32,29 @@ def cached_eval(config: Config, expr: str, d: Dict[str, object]) -> Any:
class MarkEvaluator: class MarkEvaluator:
def __init__(self, item, name): def __init__(self, item: Item, name: str) -> None:
self.item = item self.item = item
self._marks = None self._marks = None # type: Optional[List[Mark]]
self._mark = None self._mark = None # type: Optional[Mark]
self._mark_name = name self._mark_name = name
def __bool__(self): def __bool__(self) -> bool:
# don't cache here to prevent staleness # don't cache here to prevent staleness
return bool(self._get_marks()) return bool(self._get_marks())
def wasvalid(self): def wasvalid(self) -> bool:
return not hasattr(self, "exc") return not hasattr(self, "exc")
def _get_marks(self): def _get_marks(self) -> List[Mark]:
return list(self.item.iter_markers(name=self._mark_name)) return list(self.item.iter_markers(name=self._mark_name))
def invalidraise(self, exc): def invalidraise(self, exc) -> Optional[bool]:
raises = self.get("raises") raises = self.get("raises")
if not raises: if not raises:
return return None
return not isinstance(exc, raises) return not isinstance(exc, raises)
def istrue(self): def istrue(self) -> bool:
try: try:
return self._istrue() return self._istrue()
except TEST_OUTCOME: except TEST_OUTCOME:
@ -69,25 +73,26 @@ class MarkEvaluator:
pytrace=False, pytrace=False,
) )
def _getglobals(self): def _getglobals(self) -> Dict[str, object]:
d = {"os": os, "sys": sys, "platform": platform, "config": self.item.config} d = {"os": os, "sys": sys, "platform": platform, "config": self.item.config}
if hasattr(self.item, "obj"): if hasattr(self.item, "obj"):
d.update(self.item.obj.__globals__) d.update(self.item.obj.__globals__) # type: ignore[attr-defined] # noqa: F821
return d return d
def _istrue(self): def _istrue(self) -> bool:
if hasattr(self, "result"): if hasattr(self, "result"):
return self.result result = getattr(self, "result") # type: bool
return result
self._marks = self._get_marks() self._marks = self._get_marks()
if self._marks: if self._marks:
self.result = False self.result = False
for mark in self._marks: for mark in self._marks:
self._mark = mark self._mark = mark
if "condition" in mark.kwargs: if "condition" not in mark.kwargs:
args = (mark.kwargs["condition"],)
else:
args = mark.args args = mark.args
else:
args = (mark.kwargs["condition"],)
for expr in args: for expr in args:
self.expr = expr self.expr = expr