diff --git a/changelog/9169.bugfix.rst b/changelog/9169.bugfix.rst new file mode 100644 index 000000000..83fce0a38 --- /dev/null +++ b/changelog/9169.bugfix.rst @@ -0,0 +1 @@ +Support for the ``files`` API from ``importlib.resources`` within rewritten files. diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 1f7de90d5..5fb076d48 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -64,7 +64,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader) except ValueError: self.fnpats = ["test_*.py", "*_test.py"] self.session: Optional[Session] = None - self._rewritten_names: Set[str] = set() + self._rewritten_names: Dict[str, Path] = {} self._must_rewrite: Set[str] = set() # flag to guard against trying to rewrite a pyc file while we are already writing another pyc file, # which might result in infinite recursion (#3506) @@ -134,7 +134,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader) fn = Path(module.__spec__.origin) state = self.config.stash[assertstate_key] - self._rewritten_names.add(module.__name__) + self._rewritten_names[module.__name__] = fn # The requested module looks like a test file, so rewrite it. This is # the most magical part of the process: load the source, rewrite the @@ -276,6 +276,14 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader) with open(pathname, "rb") as f: return f.read() + if sys.version_info >= (3, 9): + + def get_resource_reader(self, name: str) -> importlib.abc.TraversableResources: # type: ignore + from types import SimpleNamespace + from importlib.readers import FileReader + + return FileReader(SimpleNamespace(path=self._rewritten_names[name])) + def _write_pyc_fp( fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 5e63d61fa..ebe1220bb 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -795,6 +795,35 @@ class TestRewriteOnImport: ) assert pytester.runpytest().ret == ExitCode.NO_TESTS_COLLECTED + @pytest.mark.skipif( + sys.version_info < (3, 9), + reason="importlib.resources.files was introduced in 3.9", + ) + def test_load_resource_via_files_with_rewrite(self, pytester: Pytester) -> None: + example = pytester.path.joinpath("demo") / "example" + init = pytester.path.joinpath("demo") / "__init__.py" + pytester.makepyfile( + **{ + "demo/__init__.py": """ + from importlib.resources import files + + def load(): + return files(__name__) + """, + "test_load": f""" + pytest_plugins = ["demo"] + + def test_load(): + from demo import load + found = {{str(i) for i in load().iterdir() if i.name != "__pycache__"}} + assert found == {{{str(example)!r}, {str(init)!r}}} + """, + } + ) + example.mkdir() + + assert pytester.runpytest("-vv").ret == ExitCode.OK + def test_readonly(self, pytester: Pytester) -> None: sub = pytester.mkdir("testing") sub.joinpath("test_readonly.py").write_bytes(