diff --git a/src/_pytest/python.py b/src/_pytest/python.py index 45d3384df..e46d498ab 100644 --- a/src/_pytest/python.py +++ b/src/_pytest/python.py @@ -15,6 +15,7 @@ from typing import Callable from typing import Dict from typing import Iterable from typing import List +from typing import Mapping from typing import Optional from typing import Set from typing import Tuple @@ -44,6 +45,7 @@ from _pytest.compat import safe_isclass from _pytest.compat import STRING_TYPES from _pytest.config import Config from _pytest.config import ExitCode +from _pytest.compat import TYPE_CHECKING from _pytest.config import hookimpl from _pytest.config.argparsing import Parser from _pytest.deprecated import FUNCARGNAMES @@ -53,6 +55,7 @@ from _pytest.mark import MARK_GEN from _pytest.mark import ParameterSet from _pytest.mark.structures import get_unpacked_marks from _pytest.mark.structures import Mark +from _pytest.mark.structures import MarkDecorator from _pytest.mark.structures import normalize_mark_list from _pytest.outcomes import fail from _pytest.outcomes import skip @@ -60,6 +63,9 @@ from _pytest.pathlib import parts from _pytest.warning_types import PytestCollectionWarning from _pytest.warning_types import PytestUnhandledCoroutineWarning +if TYPE_CHECKING: + from typing_extensions import Literal + def pytest_addoption(parser: Parser) -> None: group = parser.getgroup("general") @@ -772,16 +778,17 @@ def hasnew(obj): class CallSpec2: - def __init__(self, metafunc): + def __init__(self, metafunc: "Metafunc") -> None: self.metafunc = metafunc - self.funcargs = {} - self._idlist = [] - self.params = {} - self._arg2scopenum = {} # used for sorting parametrized resources - self.marks = [] - self.indices = {} + self.funcargs = {} # type: Dict[str, object] + self._idlist = [] # type: List[str] + self.params = {} # type: Dict[str, object] + # Used for sorting parametrized resources. + self._arg2scopenum = {} # type: Dict[str, int] + self.marks = [] # type: List[Mark] + self.indices = {} # type: Dict[str, int] - def copy(self): + def copy(self) -> "CallSpec2": cs = CallSpec2(self.metafunc) cs.funcargs.update(self.funcargs) cs.params.update(self.params) @@ -791,25 +798,39 @@ class CallSpec2: cs._idlist = list(self._idlist) return cs - def _checkargnotcontained(self, arg): + def _checkargnotcontained(self, arg: str) -> None: if arg in self.params or arg in self.funcargs: raise ValueError("duplicate {!r}".format(arg)) - def getparam(self, name): + def getparam(self, name: str) -> object: try: return self.params[name] except KeyError: raise ValueError(name) @property - def id(self): + def id(self) -> str: return "-".join(map(str, self._idlist)) - def setmulti2(self, valtypes, argnames, valset, id, marks, scopenum, param_index): + def setmulti2( + self, + valtypes: "Mapping[str, Literal['params', 'funcargs']]", + argnames: typing.Sequence[str], + valset: Iterable[object], + id: str, + marks: Iterable[Union[Mark, MarkDecorator]], + scopenum: int, + param_index: int, + ) -> None: for arg, val in zip(argnames, valset): self._checkargnotcontained(arg) valtype_for_arg = valtypes[arg] - getattr(self, valtype_for_arg)[arg] = val + if valtype_for_arg == "params": + self.params[arg] = val + elif valtype_for_arg == "funcargs": + self.funcargs[arg] = val + else: # pragma: no cover + assert False, "Unhandled valtype for arg: {}".format(valtype_for_arg) self.indices[arg] = param_index self._arg2scopenum[arg] = scopenum self._idlist.append(id) @@ -1049,7 +1070,7 @@ class Metafunc: self, argnames: typing.Sequence[str], indirect: Union[bool, typing.Sequence[str]], - ) -> Dict[str, str]: + ) -> Dict[str, "Literal['params', 'funcargs']"]: """Resolves if each parametrized argument must be considered a parameter to a fixture or a "funcarg" to the function, based on the ``indirect`` parameter of the parametrized() call. @@ -1061,7 +1082,9 @@ class Metafunc: * "funcargs" if the argname should be a parameter to the parametrized test function. """ if isinstance(indirect, bool): - valtypes = dict.fromkeys(argnames, "params" if indirect else "funcargs") + valtypes = dict.fromkeys( + argnames, "params" if indirect else "funcargs" + ) # type: Dict[str, Literal["params", "funcargs"]] elif isinstance(indirect, Sequence): valtypes = dict.fromkeys(argnames, "funcargs") for arg in indirect: