diff --git a/src/_pytest/assertion/__init__.py b/src/_pytest/assertion/__init__.py index cdb034703..0a57f6afa 100644 --- a/src/_pytest/assertion/__init__.py +++ b/src/_pytest/assertion/__init__.py @@ -7,6 +7,7 @@ from typing import Optional from _pytest.assertion import rewrite from _pytest.assertion import truncate from _pytest.assertion import util +from _pytest.assertion.rewrite import assertstate_key from _pytest.compat import TYPE_CHECKING from _pytest.config import hookimpl @@ -82,13 +83,13 @@ class AssertionState: def install_importhook(config): """Try to install the rewrite hook, raise SystemError if it fails.""" - config._assertstate = AssertionState(config, "rewrite") - config._assertstate.hook = hook = rewrite.AssertionRewritingHook(config) + config._store[assertstate_key] = AssertionState(config, "rewrite") + config._store[assertstate_key].hook = hook = rewrite.AssertionRewritingHook(config) sys.meta_path.insert(0, hook) - config._assertstate.trace("installed rewrite import hook") + config._store[assertstate_key].trace("installed rewrite import hook") def undo(): - hook = config._assertstate.hook + hook = config._store[assertstate_key].hook if hook is not None and hook in sys.meta_path: sys.meta_path.remove(hook) @@ -100,7 +101,7 @@ def pytest_collection(session: "Session") -> None: # this hook is only called when test modules are collected # so for example not in the master process of pytest-xdist # (which does not collect test modules) - assertstate = getattr(session.config, "_assertstate", None) + assertstate = session.config._store.get(assertstate_key, None) if assertstate: if assertstate.hook is not None: assertstate.hook.set_session(session) @@ -163,7 +164,7 @@ def pytest_runtest_protocol(item): def pytest_sessionfinish(session): - assertstate = getattr(session.config, "_assertstate", None) + assertstate = session.config._store.get(assertstate_key, None) if assertstate: if assertstate.hook is not None: assertstate.hook.set_session(None) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 00bdcfc3e..f84127dca 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -26,9 +26,18 @@ from _pytest.assertion.util import ( # noqa: F401 format_explanation as _format_explanation, ) from _pytest.compat import fspath +from _pytest.compat import TYPE_CHECKING from _pytest.pathlib import fnmatch_ex from _pytest.pathlib import Path from _pytest.pathlib import PurePath +from _pytest.store import StoreKey + +if TYPE_CHECKING: + from _pytest.assertion import AssertionState # noqa: F401 + + +assertstate_key = StoreKey["AssertionState"]() + # pytest caches rewritten pycs in pycache dirs PYTEST_TAG = "{}-pytest-{}".format(sys.implementation.cache_tag, version) @@ -65,7 +74,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader) def find_spec(self, name, path=None, target=None): if self._writing_pyc: return None - state = self.config._assertstate + state = self.config._store[assertstate_key] if self._early_rewrite_bailout(name, state): return None state.trace("find_module called for: %s" % name) @@ -104,7 +113,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader) def exec_module(self, module): fn = Path(module.__spec__.origin) - state = self.config._assertstate + state = self.config._store[assertstate_key] self._rewritten_names.add(module.__name__)