unittest: make `obj` work more like `Function`/`Class`

Previously, the `obj` of a `TestCaseFunction` (the unittest plugin item
type) was the unbound method. This is unlike regular `Class` where the
`obj` is a bound method to a fresh instance.

This difference necessitated several special cases in in places outside
of the unittest plugin, such as `FixtureDef` and `FixtureRequest`, and
made things a bit harder to understand.

Instead, match how the python plugin does it, including collecting
fixtures from a fresh instance.

The downside is that now this instance for fixture-collection is kept
around in memory, but it's the same as `Class` so nothing new. Users
should only initialize stuff in `setUp`/`setUpClass` and similar
methods, and not in `__init__` which is generally off-limits in
`TestCase` subclasses.

I am not sure why there was a difference in the first place, though I
will say the previous unittest approach is probably the preferable one,
but first let's get consistency.
This commit is contained in:
Ran Benita 2024-03-07 00:02:51 +02:00
parent 03e54712dd
commit 1a5e0eb71d
5 changed files with 52 additions and 70 deletions

View File

@ -86,7 +86,6 @@ def getfuncargnames(
function: Callable[..., object], function: Callable[..., object],
*, *,
name: str = "", name: str = "",
is_method: bool = False,
cls: type | None = None, cls: type | None = None,
) -> tuple[str, ...]: ) -> tuple[str, ...]:
"""Return the names of a function's mandatory arguments. """Return the names of a function's mandatory arguments.
@ -97,9 +96,8 @@ def getfuncargnames(
* Aren't bound with functools.partial. * Aren't bound with functools.partial.
* Aren't replaced with mocks. * Aren't replaced with mocks.
The is_method and cls arguments indicate that the function should The cls arguments indicate that the function should be treated as a bound
be treated as a bound method even though it's not unless, only in method even though it's not unless the function is a static method.
the case of cls, the function is a static method.
The name parameter should be the original name in which the function was collected. The name parameter should be the original name in which the function was collected.
""" """
@ -137,7 +135,7 @@ def getfuncargnames(
# If this function should be treated as a bound method even though # If this function should be treated as a bound method even though
# it's passed as an unbound method or function, remove the first # it's passed as an unbound method or function, remove the first
# parameter name. # parameter name.
if is_method or ( if (
# Not using `getattr` because we don't want to resolve the staticmethod. # Not using `getattr` because we don't want to resolve the staticmethod.
# Not using `cls.__dict__` because we want to check the entire MRO. # Not using `cls.__dict__` because we want to check the entire MRO.
cls cls

View File

@ -462,10 +462,6 @@ class FixtureRequest(abc.ABC):
@property @property
def instance(self): def instance(self):
"""Instance (can be None) on which test function was collected.""" """Instance (can be None) on which test function was collected."""
# unittest support hack, see _pytest.unittest.TestCaseFunction.
try:
return self._pyfuncitem._testcase # type: ignore[attr-defined]
except AttributeError:
function = getattr(self, "function", None) function = getattr(self, "function", None)
return getattr(function, "__self__", None) return getattr(function, "__self__", None)
@ -965,7 +961,6 @@ class FixtureDef(Generic[FixtureValue]):
func: "_FixtureFunc[FixtureValue]", func: "_FixtureFunc[FixtureValue]",
scope: Union[Scope, _ScopeName, Callable[[str, Config], _ScopeName], None], scope: Union[Scope, _ScopeName, Callable[[str, Config], _ScopeName], None],
params: Optional[Sequence[object]], params: Optional[Sequence[object]],
unittest: bool = False,
ids: Optional[ ids: Optional[
Union[Tuple[Optional[object], ...], Callable[[Any], Optional[object]]] Union[Tuple[Optional[object], ...], Callable[[Any], Optional[object]]]
] = None, ] = None,
@ -1011,9 +1006,7 @@ class FixtureDef(Generic[FixtureValue]):
# a parameter value. # a parameter value.
self.ids: Final = ids self.ids: Final = ids
# The names requested by the fixtures. # The names requested by the fixtures.
self.argnames: Final = getfuncargnames(func, name=argname, is_method=unittest) self.argnames: Final = getfuncargnames(func, name=argname)
# Whether the fixture was collected from a unittest TestCase class.
self.unittest: Final = unittest
# If the fixture was executed, the current value of the fixture. # If the fixture was executed, the current value of the fixture.
# Can change if the fixture is executed with different parameters. # Can change if the fixture is executed with different parameters.
self.cached_result: Optional[_FixtureCachedResult[FixtureValue]] = None self.cached_result: Optional[_FixtureCachedResult[FixtureValue]] = None
@ -1092,11 +1085,6 @@ def resolve_fixture_function(
"""Get the actual callable that can be called to obtain the fixture """Get the actual callable that can be called to obtain the fixture
value, dealing with unittest-specific instances and bound methods.""" value, dealing with unittest-specific instances and bound methods."""
fixturefunc = fixturedef.func fixturefunc = fixturedef.func
if fixturedef.unittest:
if request.instance is not None:
# Bind the unbound method to the TestCase instance.
fixturefunc = fixturedef.func.__get__(request.instance) # type: ignore[union-attr]
else:
# The fixture function needs to be bound to the actual # The fixture function needs to be bound to the actual
# request.instance so that code working with "fixturedef" behaves # request.instance so that code working with "fixturedef" behaves
# as expected. # as expected.
@ -1614,7 +1602,6 @@ class FixtureManager:
Union[Tuple[Optional[object], ...], Callable[[Any], Optional[object]]] Union[Tuple[Optional[object], ...], Callable[[Any], Optional[object]]]
] = None, ] = None,
autouse: bool = False, autouse: bool = False,
unittest: bool = False,
) -> None: ) -> None:
"""Register a fixture """Register a fixture
@ -1635,8 +1622,6 @@ class FixtureManager:
The fixture's IDs. The fixture's IDs.
:param autouse: :param autouse:
Whether this is an autouse fixture. Whether this is an autouse fixture.
:param unittest:
Set this if this is a unittest fixture.
""" """
fixture_def = FixtureDef( fixture_def = FixtureDef(
config=self.config, config=self.config,
@ -1645,7 +1630,6 @@ class FixtureManager:
func=func, func=func,
scope=scope, scope=scope,
params=params, params=params,
unittest=unittest,
ids=ids, ids=ids,
_ispytest=True, _ispytest=True,
) )
@ -1667,8 +1651,6 @@ class FixtureManager:
def parsefactories( def parsefactories(
self, self,
node_or_obj: nodes.Node, node_or_obj: nodes.Node,
*,
unittest: bool = ...,
) -> None: ) -> None:
raise NotImplementedError() raise NotImplementedError()
@ -1677,8 +1659,6 @@ class FixtureManager:
self, self,
node_or_obj: object, node_or_obj: object,
nodeid: Optional[str], nodeid: Optional[str],
*,
unittest: bool = ...,
) -> None: ) -> None:
raise NotImplementedError() raise NotImplementedError()
@ -1686,8 +1666,6 @@ class FixtureManager:
self, self,
node_or_obj: Union[nodes.Node, object], node_or_obj: Union[nodes.Node, object],
nodeid: Union[str, NotSetType, None] = NOTSET, nodeid: Union[str, NotSetType, None] = NOTSET,
*,
unittest: bool = False,
) -> None: ) -> None:
"""Collect fixtures from a collection node or object. """Collect fixtures from a collection node or object.
@ -1739,7 +1717,6 @@ class FixtureManager:
func=func, func=func,
scope=marker.scope, scope=marker.scope,
params=marker.params, params=marker.params,
unittest=unittest,
ids=marker.ids, ids=marker.ids,
autouse=marker.autouse, autouse=marker.autouse,
) )

View File

@ -1314,7 +1314,6 @@ class Metafunc:
func=get_direct_param_fixture_func, func=get_direct_param_fixture_func,
scope=scope_, scope=scope_,
params=None, params=None,
unittest=False,
ids=None, ids=None,
_ispytest=True, _ispytest=True,
) )

View File

@ -15,7 +15,6 @@ from typing import TYPE_CHECKING
from typing import Union from typing import Union
import _pytest._code import _pytest._code
from _pytest.compat import getimfunc
from _pytest.compat import is_async_function from _pytest.compat import is_async_function
from _pytest.config import hookimpl from _pytest.config import hookimpl
from _pytest.fixtures import FixtureRequest from _pytest.fixtures import FixtureRequest
@ -63,6 +62,14 @@ class UnitTestCase(Class):
# to declare that our children do not support funcargs. # to declare that our children do not support funcargs.
nofuncargs = True nofuncargs = True
def newinstance(self):
# TestCase __init__ takes the method (test) name. The TestCase
# constructor treats the name "runTest" as a special no-op, so it can be
# used when a dummy instance is needed. While unittest.TestCase has a
# default, some subclasses omit the default (#9610), so always supply
# it.
return self.obj("runTest")
def collect(self) -> Iterable[Union[Item, Collector]]: def collect(self) -> Iterable[Union[Item, Collector]]:
from unittest import TestLoader from unittest import TestLoader
@ -76,15 +83,15 @@ class UnitTestCase(Class):
self._register_unittest_setup_class_fixture(cls) self._register_unittest_setup_class_fixture(cls)
self._register_setup_class_fixture() self._register_setup_class_fixture()
self.session._fixturemanager.parsefactories(self, unittest=True) self.session._fixturemanager.parsefactories(self.newinstance(), self.nodeid)
loader = TestLoader() loader = TestLoader()
foundsomething = False foundsomething = False
for name in loader.getTestCaseNames(self.obj): for name in loader.getTestCaseNames(self.obj):
x = getattr(self.obj, name) x = getattr(self.obj, name)
if not getattr(x, "__test__", True): if not getattr(x, "__test__", True):
continue continue
funcobj = getimfunc(x) yield TestCaseFunction.from_parent(self, name=name)
yield TestCaseFunction.from_parent(self, name=name, callobj=funcobj)
foundsomething = True foundsomething = True
if not foundsomething: if not foundsomething:
@ -169,23 +176,21 @@ class UnitTestCase(Class):
class TestCaseFunction(Function): class TestCaseFunction(Function):
nofuncargs = True nofuncargs = True
_excinfo: Optional[List[_pytest._code.ExceptionInfo[BaseException]]] = None _excinfo: Optional[List[_pytest._code.ExceptionInfo[BaseException]]] = None
_testcase: Optional["unittest.TestCase"] = None
def _getobj(self): def _getobj(self):
assert self.parent is not None assert isinstance(self.parent, UnitTestCase)
# Unlike a regular Function in a Class, where `item.obj` returns testcase = self.parent.obj(self.name)
# a *bound* method (attached to an instance), TestCaseFunction's return getattr(testcase, self.name)
# `obj` returns an *unbound* method (not attached to an instance).
# This inconsistency is probably not desirable, but needs some # Backward compat for pytest-django; can be removed after pytest-django
# consideration before changing. # updates + some slack.
return getattr(self.parent.obj, self.originalname) # type: ignore[attr-defined] @property
def _testcase(self):
return self._obj.__self__
def setup(self) -> None: def setup(self) -> None:
# A bound method to be called during teardown() if set (see 'runtest()'). # A bound method to be called during teardown() if set (see 'runtest()').
self._explicit_tearDown: Optional[Callable[[], None]] = None self._explicit_tearDown: Optional[Callable[[], None]] = None
assert self.parent is not None
self._testcase = self.parent.obj(self.name) # type: ignore[attr-defined]
self._obj = getattr(self._testcase, self.name)
super().setup() super().setup()
def teardown(self) -> None: def teardown(self) -> None:
@ -193,7 +198,6 @@ class TestCaseFunction(Function):
if self._explicit_tearDown is not None: if self._explicit_tearDown is not None:
self._explicit_tearDown() self._explicit_tearDown()
self._explicit_tearDown = None self._explicit_tearDown = None
self._testcase = None
self._obj = None self._obj = None
def startTest(self, testcase: "unittest.TestCase") -> None: def startTest(self, testcase: "unittest.TestCase") -> None:
@ -292,14 +296,14 @@ class TestCaseFunction(Function):
def runtest(self) -> None: def runtest(self) -> None:
from _pytest.debugging import maybe_wrap_pytest_function_for_tracing from _pytest.debugging import maybe_wrap_pytest_function_for_tracing
assert self._testcase is not None testcase = self.obj.__self__
maybe_wrap_pytest_function_for_tracing(self) maybe_wrap_pytest_function_for_tracing(self)
# Let the unittest framework handle async functions. # Let the unittest framework handle async functions.
if is_async_function(self.obj): if is_async_function(self.obj):
# Type ignored because self acts as the TestResult, but is not actually one. # Type ignored because self acts as the TestResult, but is not actually one.
self._testcase(result=self) # type: ignore[arg-type] testcase(result=self) # type: ignore[arg-type]
else: else:
# When --pdb is given, we want to postpone calling tearDown() otherwise # When --pdb is given, we want to postpone calling tearDown() otherwise
# when entering the pdb prompt, tearDown() would have probably cleaned up # when entering the pdb prompt, tearDown() would have probably cleaned up
@ -311,16 +315,16 @@ class TestCaseFunction(Function):
assert isinstance(self.parent, UnitTestCase) assert isinstance(self.parent, UnitTestCase)
skipped = _is_skipped(self.obj) or _is_skipped(self.parent.obj) skipped = _is_skipped(self.obj) or _is_skipped(self.parent.obj)
if self.config.getoption("usepdb") and not skipped: if self.config.getoption("usepdb") and not skipped:
self._explicit_tearDown = self._testcase.tearDown self._explicit_tearDown = testcase.tearDown
setattr(self._testcase, "tearDown", lambda *args: None) setattr(testcase, "tearDown", lambda *args: None)
# We need to update the actual bound method with self.obj, because # We need to update the actual bound method with self.obj, because
# wrap_pytest_function_for_tracing replaces self.obj by a wrapper. # wrap_pytest_function_for_tracing replaces self.obj by a wrapper.
setattr(self._testcase, self.name, self.obj) setattr(testcase, self.name, self.obj)
try: try:
self._testcase(result=self) # type: ignore[arg-type] testcase(result=self) # type: ignore[arg-type]
finally: finally:
delattr(self._testcase, self.name) delattr(testcase, self.name)
def _traceback_filter( def _traceback_filter(
self, excinfo: _pytest._code.ExceptionInfo[BaseException] self, excinfo: _pytest._code.ExceptionInfo[BaseException]

View File

@ -208,10 +208,14 @@ def test_teardown_issue1649(pytester: Pytester) -> None:
""" """
) )
pytester.inline_run("-s", testpath) pytester.inline_run("-s", testpath)
gc.collect() gc.collect()
# Either already destroyed, or didn't run setUp.
for obj in gc.get_objects(): for obj in gc.get_objects():
assert type(obj).__name__ != "TestCaseObjectsShouldBeCleanedUp" if type(obj).__name__ == "TestCaseObjectsShouldBeCleanedUp":
assert not hasattr(obj, "an_expensive_obj")
def test_unittest_skip_issue148(pytester: Pytester) -> None: def test_unittest_skip_issue148(pytester: Pytester) -> None: