Type annotate _pytest.mark.evaluate

This commit is contained in:
Ran Benita 2020-05-01 14:40:16 +03:00
parent fc325bc0c3
commit 709bcbf3c4
1 changed files with 21 additions and 16 deletions

View File

@ -4,10 +4,14 @@ import sys
import traceback
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from ..outcomes import fail
from ..outcomes import TEST_OUTCOME
from .structures import Mark
from _pytest.config import Config
from _pytest.nodes import Item
from _pytest.store import StoreKey
@ -28,29 +32,29 @@ def cached_eval(config: Config, expr: str, d: Dict[str, object]) -> Any:
class MarkEvaluator:
def __init__(self, item, name):
def __init__(self, item: Item, name: str) -> None:
self.item = item
self._marks = None
self._mark = None
self._marks = None # type: Optional[List[Mark]]
self._mark = None # type: Optional[Mark]
self._mark_name = name
def __bool__(self):
def __bool__(self) -> bool:
# don't cache here to prevent staleness
return bool(self._get_marks())
def wasvalid(self):
def wasvalid(self) -> bool:
return not hasattr(self, "exc")
def _get_marks(self):
def _get_marks(self) -> List[Mark]:
return list(self.item.iter_markers(name=self._mark_name))
def invalidraise(self, exc):
def invalidraise(self, exc) -> Optional[bool]:
raises = self.get("raises")
if not raises:
return
return None
return not isinstance(exc, raises)
def istrue(self):
def istrue(self) -> bool:
try:
return self._istrue()
except TEST_OUTCOME:
@ -69,25 +73,26 @@ class MarkEvaluator:
pytrace=False,
)
def _getglobals(self):
def _getglobals(self) -> Dict[str, object]:
d = {"os": os, "sys": sys, "platform": platform, "config": self.item.config}
if hasattr(self.item, "obj"):
d.update(self.item.obj.__globals__)
d.update(self.item.obj.__globals__) # type: ignore[attr-defined] # noqa: F821
return d
def _istrue(self):
def _istrue(self) -> bool:
if hasattr(self, "result"):
return self.result
result = getattr(self, "result") # type: bool
return result
self._marks = self._get_marks()
if self._marks:
self.result = False
for mark in self._marks:
self._mark = mark
if "condition" in mark.kwargs:
args = (mark.kwargs["condition"],)
else:
if "condition" not in mark.kwargs:
args = mark.args
else:
args = (mark.kwargs["condition"],)
for expr in args:
self.expr = expr