Type annotate _pytest.unittest

This commit is contained in:
Ran Benita 2020-05-01 14:40:16 +03:00
parent db52928684
commit b51ea4f1a5
1 changed files with 60 additions and 26 deletions

View File

@ -1,17 +1,23 @@
""" discovery and running of std-library "unittest" style tests. """ """ discovery and running of std-library "unittest" style tests. """
import sys import sys
import traceback import traceback
import types
from typing import Any from typing import Any
from typing import Callable
from typing import Generator from typing import Generator
from typing import Iterable from typing import Iterable
from typing import List
from typing import Optional from typing import Optional
from typing import Tuple
from typing import Union from typing import Union
import _pytest._code import _pytest._code
import pytest import pytest
from _pytest.compat import getimfunc from _pytest.compat import getimfunc
from _pytest.compat import is_async_function from _pytest.compat import is_async_function
from _pytest.compat import TYPE_CHECKING
from _pytest.config import hookimpl from _pytest.config import hookimpl
from _pytest.fixtures import FixtureRequest
from _pytest.nodes import Collector from _pytest.nodes import Collector
from _pytest.nodes import Item from _pytest.nodes import Item
from _pytest.outcomes import exit from _pytest.outcomes import exit
@ -25,6 +31,17 @@ from _pytest.runner import CallInfo
from _pytest.skipping import skipped_by_mark_key from _pytest.skipping import skipped_by_mark_key
from _pytest.skipping import unexpectedsuccess_key from _pytest.skipping import unexpectedsuccess_key
if TYPE_CHECKING:
import unittest
from typing import Type
from _pytest.fixtures import _Scope
_SysExcInfoType = Union[
Tuple[Type[BaseException], BaseException, types.TracebackType],
Tuple[None, None, None],
]
def pytest_pycollect_makeitem( def pytest_pycollect_makeitem(
collector: PyCollector, name: str, obj collector: PyCollector, name: str, obj
@ -78,30 +95,32 @@ class UnitTestCase(Class):
if ut is None or runtest != ut.TestCase.runTest: # type: ignore if ut is None or runtest != ut.TestCase.runTest: # type: ignore
yield TestCaseFunction.from_parent(self, name="runTest") yield TestCaseFunction.from_parent(self, name="runTest")
def _inject_setup_teardown_fixtures(self, cls): def _inject_setup_teardown_fixtures(self, cls: type) -> None:
"""Injects a hidden auto-use fixture to invoke setUpClass/setup_method and corresponding """Injects a hidden auto-use fixture to invoke setUpClass/setup_method and corresponding
teardown functions (#517)""" teardown functions (#517)"""
class_fixture = _make_xunit_fixture( class_fixture = _make_xunit_fixture(
cls, "setUpClass", "tearDownClass", scope="class", pass_self=False cls, "setUpClass", "tearDownClass", scope="class", pass_self=False
) )
if class_fixture: if class_fixture:
cls.__pytest_class_setup = class_fixture cls.__pytest_class_setup = class_fixture # type: ignore[attr-defined] # noqa: F821
method_fixture = _make_xunit_fixture( method_fixture = _make_xunit_fixture(
cls, "setup_method", "teardown_method", scope="function", pass_self=True cls, "setup_method", "teardown_method", scope="function", pass_self=True
) )
if method_fixture: if method_fixture:
cls.__pytest_method_setup = method_fixture cls.__pytest_method_setup = method_fixture # type: ignore[attr-defined] # noqa: F821
def _make_xunit_fixture(obj, setup_name, teardown_name, scope, pass_self): def _make_xunit_fixture(
obj: type, setup_name: str, teardown_name: str, scope: "_Scope", pass_self: bool
):
setup = getattr(obj, setup_name, None) setup = getattr(obj, setup_name, None)
teardown = getattr(obj, teardown_name, None) teardown = getattr(obj, teardown_name, None)
if setup is None and teardown is None: if setup is None and teardown is None:
return None return None
@pytest.fixture(scope=scope, autouse=True) @pytest.fixture(scope=scope, autouse=True)
def fixture(self, request): def fixture(self, request: FixtureRequest) -> Generator[None, None, None]:
if _is_skipped(self): if _is_skipped(self):
reason = self.__unittest_skip_why__ reason = self.__unittest_skip_why__
pytest.skip(reason) pytest.skip(reason)
@ -122,32 +141,33 @@ def _make_xunit_fixture(obj, setup_name, teardown_name, scope, pass_self):
class TestCaseFunction(Function): class TestCaseFunction(Function):
nofuncargs = True nofuncargs = True
_excinfo = None _excinfo = None # type: Optional[List[_pytest._code.ExceptionInfo]]
_testcase = None _testcase = None # type: Optional[unittest.TestCase]
def setup(self): def setup(self) -> None:
# a bound method to be called during teardown() if set (see 'runtest()') # a bound method to be called during teardown() if set (see 'runtest()')
self._explicit_tearDown = None self._explicit_tearDown = None # type: Optional[Callable[[], None]]
self._testcase = self.parent.obj(self.name) assert self.parent is not None
self._testcase = self.parent.obj(self.name) # type: ignore[attr-defined] # noqa: F821
self._obj = getattr(self._testcase, self.name) self._obj = getattr(self._testcase, self.name)
if hasattr(self, "_request"): if hasattr(self, "_request"):
self._request._fillfixtures() self._request._fillfixtures()
def teardown(self): def teardown(self) -> None:
if self._explicit_tearDown is not None: if self._explicit_tearDown is not None:
self._explicit_tearDown() self._explicit_tearDown()
self._explicit_tearDown = None self._explicit_tearDown = None
self._testcase = None self._testcase = None
self._obj = None self._obj = None
def startTest(self, testcase): def startTest(self, testcase: "unittest.TestCase") -> None:
pass pass
def _addexcinfo(self, rawexcinfo): def _addexcinfo(self, rawexcinfo: "_SysExcInfoType") -> None:
# unwrap potential exception info (see twisted trial support below) # unwrap potential exception info (see twisted trial support below)
rawexcinfo = getattr(rawexcinfo, "_rawexcinfo", rawexcinfo) rawexcinfo = getattr(rawexcinfo, "_rawexcinfo", rawexcinfo)
try: try:
excinfo = _pytest._code.ExceptionInfo(rawexcinfo) excinfo = _pytest._code.ExceptionInfo(rawexcinfo) # type: ignore[arg-type] # noqa: F821
# invoke the attributes to trigger storing the traceback # invoke the attributes to trigger storing the traceback
# trial causes some issue there # trial causes some issue there
excinfo.value excinfo.value
@ -176,7 +196,9 @@ class TestCaseFunction(Function):
excinfo = _pytest._code.ExceptionInfo.from_current() excinfo = _pytest._code.ExceptionInfo.from_current()
self.__dict__.setdefault("_excinfo", []).append(excinfo) self.__dict__.setdefault("_excinfo", []).append(excinfo)
def addError(self, testcase, rawexcinfo): def addError(
self, testcase: "unittest.TestCase", rawexcinfo: "_SysExcInfoType"
) -> None:
try: try:
if isinstance(rawexcinfo[1], exit.Exception): if isinstance(rawexcinfo[1], exit.Exception):
exit(rawexcinfo[1].msg) exit(rawexcinfo[1].msg)
@ -184,29 +206,38 @@ class TestCaseFunction(Function):
pass pass
self._addexcinfo(rawexcinfo) self._addexcinfo(rawexcinfo)
def addFailure(self, testcase, rawexcinfo): def addFailure(
self, testcase: "unittest.TestCase", rawexcinfo: "_SysExcInfoType"
) -> None:
self._addexcinfo(rawexcinfo) self._addexcinfo(rawexcinfo)
def addSkip(self, testcase, reason): def addSkip(self, testcase: "unittest.TestCase", reason: str) -> None:
try: try:
skip(reason) skip(reason)
except skip.Exception: except skip.Exception:
self._store[skipped_by_mark_key] = True self._store[skipped_by_mark_key] = True
self._addexcinfo(sys.exc_info()) self._addexcinfo(sys.exc_info())
def addExpectedFailure(self, testcase, rawexcinfo, reason=""): def addExpectedFailure(
self,
testcase: "unittest.TestCase",
rawexcinfo: "_SysExcInfoType",
reason: str = "",
) -> None:
try: try:
xfail(str(reason)) xfail(str(reason))
except xfail.Exception: except xfail.Exception:
self._addexcinfo(sys.exc_info()) self._addexcinfo(sys.exc_info())
def addUnexpectedSuccess(self, testcase, reason=""): def addUnexpectedSuccess(
self, testcase: "unittest.TestCase", reason: str = ""
) -> None:
self._store[unexpectedsuccess_key] = reason self._store[unexpectedsuccess_key] = reason
def addSuccess(self, testcase): def addSuccess(self, testcase: "unittest.TestCase") -> None:
pass pass
def stopTest(self, testcase): def stopTest(self, testcase: "unittest.TestCase") -> None:
pass pass
def _expecting_failure(self, test_method) -> bool: def _expecting_failure(self, test_method) -> bool:
@ -218,14 +249,17 @@ class TestCaseFunction(Function):
expecting_failure_class = getattr(self, "__unittest_expecting_failure__", False) expecting_failure_class = getattr(self, "__unittest_expecting_failure__", False)
return bool(expecting_failure_class or expecting_failure_method) return bool(expecting_failure_class or expecting_failure_method)
def runtest(self): def runtest(self) -> None:
from _pytest.debugging import maybe_wrap_pytest_function_for_tracing from _pytest.debugging import maybe_wrap_pytest_function_for_tracing
assert self._testcase is not None
maybe_wrap_pytest_function_for_tracing(self) maybe_wrap_pytest_function_for_tracing(self)
# let the unittest framework handle async functions # let the unittest framework handle async functions
if is_async_function(self.obj): if is_async_function(self.obj):
self._testcase(self) # Type ignored because self acts as the TestResult, but is not actually one.
self._testcase(result=self) # type: ignore[arg-type] # noqa: F821
else: else:
# when --pdb is given, we want to postpone calling tearDown() otherwise # when --pdb is given, we want to postpone calling tearDown() otherwise
# when entering the pdb prompt, tearDown() would have probably cleaned up # when entering the pdb prompt, tearDown() would have probably cleaned up
@ -241,11 +275,11 @@ class TestCaseFunction(Function):
# wrap_pytest_function_for_tracing replaces self.obj by a wrapper # wrap_pytest_function_for_tracing replaces self.obj by a wrapper
setattr(self._testcase, self.name, self.obj) setattr(self._testcase, self.name, self.obj)
try: try:
self._testcase(result=self) self._testcase(result=self) # type: ignore[arg-type] # noqa: F821
finally: finally:
delattr(self._testcase, self.name) delattr(self._testcase, self.name)
def _prunetraceback(self, excinfo): def _prunetraceback(self, excinfo: _pytest._code.ExceptionInfo) -> None:
Function._prunetraceback(self, excinfo) Function._prunetraceback(self, excinfo)
traceback = excinfo.traceback.filter( traceback = excinfo.traceback.filter(
lambda x: not x.frame.f_globals.get("__unittest") lambda x: not x.frame.f_globals.get("__unittest")
@ -313,7 +347,7 @@ def pytest_runtest_protocol(item: Item) -> Generator[None, None, None]:
yield yield
def check_testcase_implements_trial_reporter(done=[]): def check_testcase_implements_trial_reporter(done: List[int] = []) -> None:
if done: if done:
return return
from zope.interface import classImplements from zope.interface import classImplements