Merge pull request #5468 from asottile/switch_importlib_to_imp

Switch from deprecated imp to importlib
This commit is contained in:
Anthony Sottile 2019-06-24 11:22:01 -07:00 committed by GitHub
commit 61dcb84f0d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 182 additions and 242 deletions

View File

@ -0,0 +1 @@
Switch from ``imp`` to ``importlib``.

View File

@ -0,0 +1 @@
Honor PEP 235 on case-insensitive file systems.

View File

@ -0,0 +1 @@
Test module is no longer double-imported when using ``--pyargs``.

View File

@ -0,0 +1 @@
Prevent "already imported" warnings from assertion rewriter when invoking pytest in-process multiple times.

View File

@ -0,0 +1 @@
Fix assertion rewriting in packages (``__init__.py``).

View File

@ -1,18 +1,16 @@
"""Rewrite assertion AST to produce nice error messages""" """Rewrite assertion AST to produce nice error messages"""
import ast import ast
import errno import errno
import imp import importlib.machinery
import importlib.util
import itertools import itertools
import marshal import marshal
import os import os
import re
import struct import struct
import sys import sys
import types import types
from importlib.util import spec_from_file_location
import atomicwrites import atomicwrites
import py
from _pytest._io.saferepr import saferepr from _pytest._io.saferepr import saferepr
from _pytest.assertion import util from _pytest.assertion import util
@ -23,23 +21,13 @@ from _pytest.pathlib import fnmatch_ex
from _pytest.pathlib import PurePath from _pytest.pathlib import PurePath
# pytest caches rewritten pycs in __pycache__. # pytest caches rewritten pycs in __pycache__.
if hasattr(imp, "get_tag"): PYTEST_TAG = "{}-PYTEST".format(sys.implementation.cache_tag)
PYTEST_TAG = imp.get_tag() + "-PYTEST"
else:
if hasattr(sys, "pypy_version_info"):
impl = "pypy"
else:
impl = "cpython"
ver = sys.version_info
PYTEST_TAG = "{}-{}{}-PYTEST".format(impl, ver[0], ver[1])
del ver, impl
PYC_EXT = ".py" + (__debug__ and "c" or "o") PYC_EXT = ".py" + (__debug__ and "c" or "o")
PYC_TAIL = "." + PYTEST_TAG + PYC_EXT PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
class AssertionRewritingHook: class AssertionRewritingHook:
"""PEP302 Import hook which rewrites asserts.""" """PEP302/PEP451 import hook which rewrites asserts."""
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
@ -48,7 +36,6 @@ class AssertionRewritingHook:
except ValueError: except ValueError:
self.fnpats = ["test_*.py", "*_test.py"] self.fnpats = ["test_*.py", "*_test.py"]
self.session = None self.session = None
self.modules = {}
self._rewritten_names = set() self._rewritten_names = set()
self._must_rewrite = set() self._must_rewrite = set()
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file, # flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
@ -62,55 +49,51 @@ class AssertionRewritingHook:
self.session = session self.session = session
self._session_paths_checked = False self._session_paths_checked = False
def _imp_find_module(self, name, path=None): # Indirection so we can mock calls to find_spec originated from the hook during testing
"""Indirection so we can mock calls to find_module originated from the hook during testing""" _find_spec = importlib.machinery.PathFinder.find_spec
return imp.find_module(name, path)
def find_module(self, name, path=None): def find_spec(self, name, path=None, target=None):
if self._writing_pyc: if self._writing_pyc:
return None return None
state = self.config._assertstate state = self.config._assertstate
if self._early_rewrite_bailout(name, state): if self._early_rewrite_bailout(name, state):
return None return None
state.trace("find_module called for: %s" % name) state.trace("find_module called for: %s" % name)
names = name.rsplit(".", 1)
lastname = names[-1] spec = self._find_spec(name, path)
pth = None if (
if path is not None: # the import machinery could not find a file to import
# Starting with Python 3.3, path is a _NamespacePath(), which spec is None
# causes problems if not converted to list. # this is a namespace package (without `__init__.py`)
path = list(path) # there's nothing to rewrite there
if len(path) == 1: # python3.5 - python3.6: `namespace`
pth = path[0] # python3.7+: `None`
if pth is None: or spec.origin in {None, "namespace"}
try: # if the file doesn't exist, we can't rewrite it
fd, fn, desc = self._imp_find_module(lastname, path) or not os.path.exists(spec.origin)
except ImportError: ):
return None
if fd is not None:
fd.close()
tp = desc[2]
if tp == imp.PY_COMPILED:
if hasattr(imp, "source_from_cache"):
try:
fn = imp.source_from_cache(fn)
except ValueError:
# Python 3 doesn't like orphaned but still-importable
# .pyc files.
fn = fn[:-1]
else:
fn = fn[:-1]
elif tp != imp.PY_SOURCE:
# Don't know what this is.
return None return None
else: else:
fn = os.path.join(pth, name.rpartition(".")[2] + ".py") fn = spec.origin
fn_pypath = py.path.local(fn) if not self._should_rewrite(name, fn, state):
if not self._should_rewrite(name, fn_pypath, state):
return None return None
self._rewritten_names.add(name) return importlib.util.spec_from_file_location(
name,
fn,
loader=self,
submodule_search_locations=spec.submodule_search_locations,
)
def create_module(self, spec):
return None # default behaviour is fine
def exec_module(self, module):
fn = module.__spec__.origin
state = self.config._assertstate
self._rewritten_names.add(module.__name__)
# The requested module looks like a test file, so rewrite it. This is # The requested module looks like a test file, so rewrite it. This is
# the most magical part of the process: load the source, rewrite the # the most magical part of the process: load the source, rewrite the
@ -121,7 +104,7 @@ class AssertionRewritingHook:
# cached pyc is always a complete, valid pyc. Operations on it must be # cached pyc is always a complete, valid pyc. Operations on it must be
# atomic. POSIX's atomic rename comes in handy. # atomic. POSIX's atomic rename comes in handy.
write = not sys.dont_write_bytecode write = not sys.dont_write_bytecode
cache_dir = os.path.join(fn_pypath.dirname, "__pycache__") cache_dir = os.path.join(os.path.dirname(fn), "__pycache__")
if write: if write:
try: try:
os.mkdir(cache_dir) os.mkdir(cache_dir)
@ -132,26 +115,23 @@ class AssertionRewritingHook:
# common case) or it's blocked by a non-dir node. In the # common case) or it's blocked by a non-dir node. In the
# latter case, we'll ignore it in _write_pyc. # latter case, we'll ignore it in _write_pyc.
pass pass
elif e in [errno.ENOENT, errno.ENOTDIR]: elif e in {errno.ENOENT, errno.ENOTDIR}:
# One of the path components was not a directory, likely # One of the path components was not a directory, likely
# because we're in a zip file. # because we're in a zip file.
write = False write = False
elif e in [errno.EACCES, errno.EROFS, errno.EPERM]: elif e in {errno.EACCES, errno.EROFS, errno.EPERM}:
state.trace("read only directory: %r" % fn_pypath.dirname) state.trace("read only directory: %r" % os.path.dirname(fn))
write = False write = False
else: else:
raise raise
cache_name = fn_pypath.basename[:-3] + PYC_TAIL cache_name = os.path.basename(fn)[:-3] + PYC_TAIL
pyc = os.path.join(cache_dir, cache_name) pyc = os.path.join(cache_dir, cache_name)
# Notice that even if we're in a read-only directory, I'm going # Notice that even if we're in a read-only directory, I'm going
# to check for a cached pyc. This may not be optimal... # to check for a cached pyc. This may not be optimal...
co = _read_pyc(fn_pypath, pyc, state.trace) co = _read_pyc(fn, pyc, state.trace)
if co is None: if co is None:
state.trace("rewriting {!r}".format(fn)) state.trace("rewriting {!r}".format(fn))
source_stat, co = _rewrite_test(self.config, fn_pypath) source_stat, co = _rewrite_test(fn)
if co is None:
# Probably a SyntaxError in the test.
return None
if write: if write:
self._writing_pyc = True self._writing_pyc = True
try: try:
@ -160,13 +140,11 @@ class AssertionRewritingHook:
self._writing_pyc = False self._writing_pyc = False
else: else:
state.trace("found cached rewritten pyc for {!r}".format(fn)) state.trace("found cached rewritten pyc for {!r}".format(fn))
self.modules[name] = co, pyc exec(co, module.__dict__)
return self
def _early_rewrite_bailout(self, name, state): def _early_rewrite_bailout(self, name, state):
""" """This is a fast way to get out of rewriting modules. Profiling has
This is a fast way to get out of rewriting modules. Profiling has shown that the call to PathFinder.find_spec (inside of the find_spec
shown that the call to imp.find_module (inside of the find_module
from this class) is a major slowdown, so, this method tries to from this class) is a major slowdown, so, this method tries to
filter what we're sure won't be rewritten before getting to it. filter what we're sure won't be rewritten before getting to it.
""" """
@ -201,10 +179,9 @@ class AssertionRewritingHook:
state.trace("early skip of rewriting module: {}".format(name)) state.trace("early skip of rewriting module: {}".format(name))
return True return True
def _should_rewrite(self, name, fn_pypath, state): def _should_rewrite(self, name, fn, state):
# always rewrite conftest files # always rewrite conftest files
fn = str(fn_pypath) if os.path.basename(fn) == "conftest.py":
if fn_pypath.basename == "conftest.py":
state.trace("rewriting conftest file: {!r}".format(fn)) state.trace("rewriting conftest file: {!r}".format(fn))
return True return True
@ -217,8 +194,9 @@ class AssertionRewritingHook:
# modules not passed explicitly on the command line are only # modules not passed explicitly on the command line are only
# rewritten if they match the naming convention for test files # rewritten if they match the naming convention for test files
fn_path = PurePath(fn)
for pat in self.fnpats: for pat in self.fnpats:
if fn_pypath.fnmatch(pat): if fnmatch_ex(pat, fn_path):
state.trace("matched test file {!r}".format(fn)) state.trace("matched test file {!r}".format(fn))
return True return True
@ -249,9 +227,10 @@ class AssertionRewritingHook:
set(names).intersection(sys.modules).difference(self._rewritten_names) set(names).intersection(sys.modules).difference(self._rewritten_names)
) )
for name in already_imported: for name in already_imported:
mod = sys.modules[name]
if not AssertionRewriter.is_rewrite_disabled( if not AssertionRewriter.is_rewrite_disabled(
sys.modules[name].__doc__ or "" mod.__doc__ or ""
): ) and not isinstance(mod.__loader__, type(self)):
self._warn_already_imported(name) self._warn_already_imported(name)
self._must_rewrite.update(names) self._must_rewrite.update(names)
self._marked_for_rewrite_cache.clear() self._marked_for_rewrite_cache.clear()
@ -268,45 +247,8 @@ class AssertionRewritingHook:
stacklevel=5, stacklevel=5,
) )
def load_module(self, name):
co, pyc = self.modules.pop(name)
if name in sys.modules:
# If there is an existing module object named 'fullname' in
# sys.modules, the loader must use that existing module. (Otherwise,
# the reload() builtin will not work correctly.)
mod = sys.modules[name]
else:
# I wish I could just call imp.load_compiled here, but __file__ has to
# be set properly. In Python 3.2+, this all would be handled correctly
# by load_compiled.
mod = sys.modules[name] = imp.new_module(name)
try:
mod.__file__ = co.co_filename
# Normally, this attribute is 3.2+.
mod.__cached__ = pyc
mod.__loader__ = self
# Normally, this attribute is 3.4+
mod.__spec__ = spec_from_file_location(name, co.co_filename, loader=self)
exec(co, mod.__dict__)
except: # noqa
if name in sys.modules:
del sys.modules[name]
raise
return sys.modules[name]
def is_package(self, name):
try:
fd, fn, desc = self._imp_find_module(name)
except ImportError:
return False
if fd is not None:
fd.close()
tp = desc[2]
return tp == imp.PKG_DIRECTORY
def get_data(self, pathname): def get_data(self, pathname):
"""Optional PEP302 get_data API. """Optional PEP302 get_data API."""
"""
with open(pathname, "rb") as f: with open(pathname, "rb") as f:
return f.read() return f.read()
@ -314,15 +256,13 @@ class AssertionRewritingHook:
def _write_pyc(state, co, source_stat, pyc): def _write_pyc(state, co, source_stat, pyc):
# Technically, we don't have to have the same pyc format as # Technically, we don't have to have the same pyc format as
# (C)Python, since these "pycs" should never be seen by builtin # (C)Python, since these "pycs" should never be seen by builtin
# import. However, there's little reason deviate, and I hope # import. However, there's little reason deviate.
# sometime to be able to use imp.load_compiled to load them. (See
# the comment in load_module above.)
try: try:
with atomicwrites.atomic_write(pyc, mode="wb", overwrite=True) as fp: with atomicwrites.atomic_write(pyc, mode="wb", overwrite=True) as fp:
fp.write(imp.get_magic()) fp.write(importlib.util.MAGIC_NUMBER)
# as of now, bytecode header expects 32-bit numbers for size and mtime (#4903) # as of now, bytecode header expects 32-bit numbers for size and mtime (#4903)
mtime = int(source_stat.mtime) & 0xFFFFFFFF mtime = int(source_stat.st_mtime) & 0xFFFFFFFF
size = source_stat.size & 0xFFFFFFFF size = source_stat.st_size & 0xFFFFFFFF
# "<LL" stands for 2 unsigned longs, little-ending # "<LL" stands for 2 unsigned longs, little-ending
fp.write(struct.pack("<LL", mtime, size)) fp.write(struct.pack("<LL", mtime, size))
fp.write(marshal.dumps(co)) fp.write(marshal.dumps(co))
@ -335,35 +275,14 @@ def _write_pyc(state, co, source_stat, pyc):
return True return True
RN = b"\r\n" def _rewrite_test(fn):
N = b"\n" """read and rewrite *fn* and return the code object."""
stat = os.stat(fn)
cookie_re = re.compile(r"^[ \t\f]*#.*coding[:=][ \t]*[-\w.]+") with open(fn, "rb") as f:
BOM_UTF8 = "\xef\xbb\xbf" source = f.read()
tree = ast.parse(source, filename=fn)
rewrite_asserts(tree, fn)
def _rewrite_test(config, fn): co = compile(tree, fn, "exec", dont_inherit=True)
"""Try to read and rewrite *fn* and return the code object."""
state = config._assertstate
try:
stat = fn.stat()
source = fn.read("rb")
except EnvironmentError:
return None, None
try:
tree = ast.parse(source, filename=fn.strpath)
except SyntaxError:
# Let this pop up again in the real import.
state.trace("failed to parse: {!r}".format(fn))
return None, None
rewrite_asserts(tree, fn, config)
try:
co = compile(tree, fn.strpath, "exec", dont_inherit=True)
except SyntaxError:
# It's possible that this error is from some bug in the
# assertion rewriting, but I don't know of a fast way to tell.
state.trace("failed to compile: {!r}".format(fn))
return None, None
return stat, co return stat, co
@ -378,8 +297,9 @@ def _read_pyc(source, pyc, trace=lambda x: None):
return None return None
with fp: with fp:
try: try:
mtime = int(source.mtime()) stat_result = os.stat(source)
size = source.size() mtime = int(stat_result.st_mtime)
size = stat_result.st_size
data = fp.read(12) data = fp.read(12)
except EnvironmentError as e: except EnvironmentError as e:
trace("_read_pyc({}): EnvironmentError {}".format(source, e)) trace("_read_pyc({}): EnvironmentError {}".format(source, e))
@ -387,7 +307,7 @@ def _read_pyc(source, pyc, trace=lambda x: None):
# Check for invalid or out of date pyc file. # Check for invalid or out of date pyc file.
if ( if (
len(data) != 12 len(data) != 12
or data[:4] != imp.get_magic() or data[:4] != importlib.util.MAGIC_NUMBER
or struct.unpack("<LL", data[4:]) != (mtime & 0xFFFFFFFF, size & 0xFFFFFFFF) or struct.unpack("<LL", data[4:]) != (mtime & 0xFFFFFFFF, size & 0xFFFFFFFF)
): ):
trace("_read_pyc(%s): invalid or out of date pyc" % source) trace("_read_pyc(%s): invalid or out of date pyc" % source)
@ -403,9 +323,9 @@ def _read_pyc(source, pyc, trace=lambda x: None):
return co return co
def rewrite_asserts(mod, module_path=None, config=None): def rewrite_asserts(mod, module_path=None):
"""Rewrite the assert statements in mod.""" """Rewrite the assert statements in mod."""
AssertionRewriter(module_path, config).run(mod) AssertionRewriter(module_path).run(mod)
def _saferepr(obj): def _saferepr(obj):
@ -586,10 +506,9 @@ class AssertionRewriter(ast.NodeVisitor):
""" """
def __init__(self, module_path, config): def __init__(self, module_path):
super().__init__() super().__init__()
self.module_path = module_path self.module_path = module_path
self.config = config
def run(self, mod): def run(self, mod):
"""Find all assert statements in *mod* and rewrite them.""" """Find all assert statements in *mod* and rewrite them."""
@ -758,7 +677,7 @@ class AssertionRewriter(ast.NodeVisitor):
"assertion is always true, perhaps remove parentheses?" "assertion is always true, perhaps remove parentheses?"
), ),
category=None, category=None,
filename=str(self.module_path), filename=self.module_path,
lineno=assert_.lineno, lineno=assert_.lineno,
) )
@ -817,7 +736,7 @@ class AssertionRewriter(ast.NodeVisitor):
AST_NONE = ast.parse("None").body[0].value AST_NONE = ast.parse("None").body[0].value
val_is_none = ast.Compare(node, [ast.Is()], [AST_NONE]) val_is_none = ast.Compare(node, [ast.Is()], [AST_NONE])
send_warning = ast.parse( send_warning = ast.parse(
""" """\
from _pytest.warning_types import PytestAssertRewriteWarning from _pytest.warning_types import PytestAssertRewriteWarning
from warnings import warn_explicit from warnings import warn_explicit
warn_explicit( warn_explicit(
@ -827,7 +746,7 @@ warn_explicit(
lineno={lineno}, lineno={lineno},
) )
""".format( """.format(
filename=module_path.strpath, lineno=lineno filename=module_path, lineno=lineno
) )
).body ).body
return ast.If(val_is_none, send_warning, []) return ast.If(val_is_none, send_warning, [])

View File

@ -2,8 +2,8 @@
import enum import enum
import fnmatch import fnmatch
import functools import functools
import importlib
import os import os
import pkgutil
import sys import sys
import warnings import warnings
@ -630,21 +630,15 @@ class Session(nodes.FSCollector):
def _tryconvertpyarg(self, x): def _tryconvertpyarg(self, x):
"""Convert a dotted module name to path.""" """Convert a dotted module name to path."""
try: try:
loader = pkgutil.find_loader(x) spec = importlib.util.find_spec(x)
except ImportError: except (ValueError, ImportError):
return x return x
if loader is None: if spec is None or spec.origin in {None, "namespace"}:
return x return x
# This method is sometimes invoked when AssertionRewritingHook, which elif spec.submodule_search_locations:
# does not define a get_filename method, is already in place: return os.path.dirname(spec.origin)
try: else:
path = loader.get_filename(x) return spec.origin
except AttributeError:
# Retrieve path from AssertionRewritingHook:
path = loader.modules[x][0].co_filename
if loader.is_package(x):
path = os.path.dirname(path)
return path
def _parsearg(self, arg): def _parsearg(self, arg):
""" return (fspath, names) tuple after checking the file exists. """ """ return (fspath, names) tuple after checking the file exists. """

View File

@ -294,6 +294,8 @@ def fnmatch_ex(pattern, path):
name = path.name name = path.name
else: else:
name = str(path) name = str(path)
if path.is_absolute() and not os.path.isabs(pattern):
pattern = "*{}{}".format(os.sep, pattern)
return fnmatch.fnmatch(name, pattern) return fnmatch.fnmatch(name, pattern)

View File

@ -1,5 +1,6 @@
"""(disabled by default) support for testing pytest and pytest plugins.""" """(disabled by default) support for testing pytest and pytest plugins."""
import gc import gc
import importlib
import os import os
import platform import platform
import re import re
@ -16,7 +17,6 @@ import py
import pytest import pytest
from _pytest._code import Source from _pytest._code import Source
from _pytest._io.saferepr import saferepr from _pytest._io.saferepr import saferepr
from _pytest.assertion.rewrite import AssertionRewritingHook
from _pytest.capture import MultiCapture from _pytest.capture import MultiCapture
from _pytest.capture import SysCapture from _pytest.capture import SysCapture
from _pytest.main import ExitCode from _pytest.main import ExitCode
@ -787,6 +787,11 @@ class Testdir:
:return: a :py:class:`HookRecorder` instance :return: a :py:class:`HookRecorder` instance
""" """
# (maybe a cpython bug?) the importlib cache sometimes isn't updated
# properly between file creation and inline_run (especially if imports
# are interspersed with file creation)
importlib.invalidate_caches()
plugins = list(plugins) plugins = list(plugins)
finalizers = [] finalizers = []
try: try:
@ -796,18 +801,6 @@ class Testdir:
mp_run.setenv(k, v) mp_run.setenv(k, v)
finalizers.append(mp_run.undo) finalizers.append(mp_run.undo)
# When running pytest inline any plugins active in the main test
# process are already imported. So this disables the warning which
# will trigger to say they can no longer be rewritten, which is
# fine as they have already been rewritten.
orig_warn = AssertionRewritingHook._warn_already_imported
def revert_warn_already_imported():
AssertionRewritingHook._warn_already_imported = orig_warn
finalizers.append(revert_warn_already_imported)
AssertionRewritingHook._warn_already_imported = lambda *a: None
# Any sys.module or sys.path changes done while running pytest # Any sys.module or sys.path changes done while running pytest
# inline should be reverted after the test run completes to avoid # inline should be reverted after the test run completes to avoid
# clashing with later inline tests run within the same pytest test, # clashing with later inline tests run within the same pytest test,

View File

@ -633,6 +633,19 @@ class TestInvocationVariants:
result.stdout.fnmatch_lines(["collected*0*items*/*1*errors"]) result.stdout.fnmatch_lines(["collected*0*items*/*1*errors"])
def test_pyargs_only_imported_once(self, testdir):
pkg = testdir.mkpydir("foo")
pkg.join("test_foo.py").write("print('hello from test_foo')\ndef test(): pass")
pkg.join("conftest.py").write(
"def pytest_configure(config): print('configuring')"
)
result = testdir.runpytest("--pyargs", "foo.test_foo", "-s", syspathinsert=True)
# should only import once
assert result.outlines.count("hello from test_foo") == 1
# should only configure once
assert result.outlines.count("configuring") == 1
def test_cmdline_python_package(self, testdir, monkeypatch): def test_cmdline_python_package(self, testdir, monkeypatch):
import warnings import warnings

View File

@ -137,8 +137,8 @@ class TestImportHookInstallation:
"hamster.py": "", "hamster.py": "",
"test_foo.py": """\ "test_foo.py": """\
def test_foo(pytestconfig): def test_foo(pytestconfig):
assert pytestconfig.pluginmanager.rewrite_hook.find_module('ham') is not None assert pytestconfig.pluginmanager.rewrite_hook.find_spec('ham') is not None
assert pytestconfig.pluginmanager.rewrite_hook.find_module('hamster') is None assert pytestconfig.pluginmanager.rewrite_hook.find_spec('hamster') is None
""", """,
} }
testdir.makepyfile(**contents) testdir.makepyfile(**contents)

View File

@ -1,5 +1,6 @@
import ast import ast
import glob import glob
import importlib
import os import os
import py_compile import py_compile
import stat import stat
@ -117,6 +118,37 @@ class TestAssertionRewrite:
result = testdir.runpytest_subprocess() result = testdir.runpytest_subprocess()
assert "warnings" not in "".join(result.outlines) assert "warnings" not in "".join(result.outlines)
def test_rewrites_plugin_as_a_package(self, testdir):
pkgdir = testdir.mkpydir("plugin")
pkgdir.join("__init__.py").write(
"import pytest\n"
"@pytest.fixture\n"
"def special_asserter():\n"
" def special_assert(x, y):\n"
" assert x == y\n"
" return special_assert\n"
)
testdir.makeconftest('pytest_plugins = ["plugin"]')
testdir.makepyfile("def test(special_asserter): special_asserter(1, 2)\n")
result = testdir.runpytest()
result.stdout.fnmatch_lines(["*assert 1 == 2*"])
def test_honors_pep_235(self, testdir, monkeypatch):
# note: couldn't make it fail on macos with a single `sys.path` entry
# note: these modules are named `test_*` to trigger rewriting
testdir.tmpdir.join("test_y.py").write("x = 1")
xdir = testdir.tmpdir.join("x").ensure_dir()
xdir.join("test_Y").ensure_dir().join("__init__.py").write("x = 2")
testdir.makepyfile(
"import test_y\n"
"import test_Y\n"
"def test():\n"
" assert test_y.x == 1\n"
" assert test_Y.x == 2\n"
)
monkeypatch.syspath_prepend(xdir)
testdir.runpytest().assert_outcomes(passed=1)
def test_name(self, request): def test_name(self, request):
def f(): def f():
assert False assert False
@ -831,8 +863,9 @@ def test_rewritten():
monkeypatch.setattr( monkeypatch.setattr(
hook, "_warn_already_imported", lambda code, msg: warnings.append(msg) hook, "_warn_already_imported", lambda code, msg: warnings.append(msg)
) )
hook.find_module("test_remember_rewritten_modules") spec = hook.find_spec("test_remember_rewritten_modules")
hook.load_module("test_remember_rewritten_modules") module = importlib.util.module_from_spec(spec)
hook.exec_module(module)
hook.mark_rewrite("test_remember_rewritten_modules") hook.mark_rewrite("test_remember_rewritten_modules")
hook.mark_rewrite("test_remember_rewritten_modules") hook.mark_rewrite("test_remember_rewritten_modules")
assert warnings == [] assert warnings == []
@ -872,33 +905,6 @@ def test_rewritten():
class TestAssertionRewriteHookDetails: class TestAssertionRewriteHookDetails:
def test_loader_is_package_false_for_module(self, testdir):
testdir.makepyfile(
test_fun="""
def test_loader():
assert not __loader__.is_package(__name__)
"""
)
result = testdir.runpytest()
result.stdout.fnmatch_lines(["* 1 passed*"])
def test_loader_is_package_true_for_package(self, testdir):
testdir.makepyfile(
test_fun="""
def test_loader():
assert not __loader__.is_package(__name__)
def test_fun():
assert __loader__.is_package('fun')
def test_missing():
assert not __loader__.is_package('pytest_not_there')
"""
)
testdir.mkpydir("fun")
result = testdir.runpytest()
result.stdout.fnmatch_lines(["* 3 passed*"])
def test_sys_meta_path_munged(self, testdir): def test_sys_meta_path_munged(self, testdir):
testdir.makepyfile( testdir.makepyfile(
""" """
@ -917,7 +923,7 @@ class TestAssertionRewriteHookDetails:
state = AssertionState(config, "rewrite") state = AssertionState(config, "rewrite")
source_path = tmpdir.ensure("source.py") source_path = tmpdir.ensure("source.py")
pycpath = tmpdir.join("pyc").strpath pycpath = tmpdir.join("pyc").strpath
assert _write_pyc(state, [1], source_path.stat(), pycpath) assert _write_pyc(state, [1], os.stat(source_path.strpath), pycpath)
@contextmanager @contextmanager
def atomic_write_failed(fn, mode="r", overwrite=False): def atomic_write_failed(fn, mode="r", overwrite=False):
@ -979,7 +985,7 @@ class TestAssertionRewriteHookDetails:
assert len(contents) > strip_bytes assert len(contents) > strip_bytes
pyc.write(contents[:strip_bytes], mode="wb") pyc.write(contents[:strip_bytes], mode="wb")
assert _read_pyc(source, str(pyc)) is None # no error assert _read_pyc(str(source), str(pyc)) is None # no error
def test_reload_is_same(self, testdir): def test_reload_is_same(self, testdir):
# A file that will be picked up during collecting. # A file that will be picked up during collecting.
@ -1186,14 +1192,17 @@ def test_rewrite_infinite_recursion(testdir, pytestconfig, monkeypatch):
# make a note that we have called _write_pyc # make a note that we have called _write_pyc
write_pyc_called.append(True) write_pyc_called.append(True)
# try to import a module at this point: we should not try to rewrite this module # try to import a module at this point: we should not try to rewrite this module
assert hook.find_module("test_bar") is None assert hook.find_spec("test_bar") is None
return original_write_pyc(*args, **kwargs) return original_write_pyc(*args, **kwargs)
monkeypatch.setattr(rewrite, "_write_pyc", spy_write_pyc) monkeypatch.setattr(rewrite, "_write_pyc", spy_write_pyc)
monkeypatch.setattr(sys, "dont_write_bytecode", False) monkeypatch.setattr(sys, "dont_write_bytecode", False)
hook = AssertionRewritingHook(pytestconfig) hook = AssertionRewritingHook(pytestconfig)
assert hook.find_module("test_foo") is not None spec = hook.find_spec("test_foo")
assert spec is not None
module = importlib.util.module_from_spec(spec)
hook.exec_module(module)
assert len(write_pyc_called) == 1 assert len(write_pyc_called) == 1
@ -1201,11 +1210,11 @@ class TestEarlyRewriteBailout:
@pytest.fixture @pytest.fixture
def hook(self, pytestconfig, monkeypatch, testdir): def hook(self, pytestconfig, monkeypatch, testdir):
"""Returns a patched AssertionRewritingHook instance so we can configure its initial paths and track """Returns a patched AssertionRewritingHook instance so we can configure its initial paths and track
if imp.find_module has been called. if PathFinder.find_spec has been called.
""" """
import imp import importlib.machinery
self.find_module_calls = [] self.find_spec_calls = []
self.initial_paths = set() self.initial_paths = set()
class StubSession: class StubSession:
@ -1214,22 +1223,22 @@ class TestEarlyRewriteBailout:
def isinitpath(self, p): def isinitpath(self, p):
return p in self._initialpaths return p in self._initialpaths
def spy_imp_find_module(name, path): def spy_find_spec(name, path):
self.find_module_calls.append(name) self.find_spec_calls.append(name)
return imp.find_module(name, path) return importlib.machinery.PathFinder.find_spec(name, path)
hook = AssertionRewritingHook(pytestconfig) hook = AssertionRewritingHook(pytestconfig)
# use default patterns, otherwise we inherit pytest's testing config # use default patterns, otherwise we inherit pytest's testing config
hook.fnpats[:] = ["test_*.py", "*_test.py"] hook.fnpats[:] = ["test_*.py", "*_test.py"]
monkeypatch.setattr(hook, "_imp_find_module", spy_imp_find_module) monkeypatch.setattr(hook, "_find_spec", spy_find_spec)
hook.set_session(StubSession()) hook.set_session(StubSession())
testdir.syspathinsert() testdir.syspathinsert()
return hook return hook
def test_basic(self, testdir, hook): def test_basic(self, testdir, hook):
""" """
Ensure we avoid calling imp.find_module when we know for sure a certain module will not be rewritten Ensure we avoid calling PathFinder.find_spec when we know for sure a certain
to optimize assertion rewriting (#3918). module will not be rewritten to optimize assertion rewriting (#3918).
""" """
testdir.makeconftest( testdir.makeconftest(
""" """
@ -1244,24 +1253,24 @@ class TestEarlyRewriteBailout:
self.initial_paths.add(foobar_path) self.initial_paths.add(foobar_path)
# conftest files should always be rewritten # conftest files should always be rewritten
assert hook.find_module("conftest") is not None assert hook.find_spec("conftest") is not None
assert self.find_module_calls == ["conftest"] assert self.find_spec_calls == ["conftest"]
# files matching "python_files" mask should always be rewritten # files matching "python_files" mask should always be rewritten
assert hook.find_module("test_foo") is not None assert hook.find_spec("test_foo") is not None
assert self.find_module_calls == ["conftest", "test_foo"] assert self.find_spec_calls == ["conftest", "test_foo"]
# file does not match "python_files": early bailout # file does not match "python_files": early bailout
assert hook.find_module("bar") is None assert hook.find_spec("bar") is None
assert self.find_module_calls == ["conftest", "test_foo"] assert self.find_spec_calls == ["conftest", "test_foo"]
# file is an initial path (passed on the command-line): should be rewritten # file is an initial path (passed on the command-line): should be rewritten
assert hook.find_module("foobar") is not None assert hook.find_spec("foobar") is not None
assert self.find_module_calls == ["conftest", "test_foo", "foobar"] assert self.find_spec_calls == ["conftest", "test_foo", "foobar"]
def test_pattern_contains_subdirectories(self, testdir, hook): def test_pattern_contains_subdirectories(self, testdir, hook):
"""If one of the python_files patterns contain subdirectories ("tests/**.py") we can't bailout early """If one of the python_files patterns contain subdirectories ("tests/**.py") we can't bailout early
because we need to match with the full path, which can only be found by calling imp.find_module. because we need to match with the full path, which can only be found by calling PathFinder.find_spec
""" """
p = testdir.makepyfile( p = testdir.makepyfile(
**{ **{
@ -1273,8 +1282,8 @@ class TestEarlyRewriteBailout:
) )
testdir.syspathinsert(p.dirpath()) testdir.syspathinsert(p.dirpath())
hook.fnpats[:] = ["tests/**.py"] hook.fnpats[:] = ["tests/**.py"]
assert hook.find_module("file") is not None assert hook.find_spec("file") is not None
assert self.find_module_calls == ["file"] assert self.find_spec_calls == ["file"]
@pytest.mark.skipif( @pytest.mark.skipif(
sys.platform.startswith("win32"), reason="cannot remove cwd on Windows" sys.platform.startswith("win32"), reason="cannot remove cwd on Windows"

View File

@ -1,3 +1,4 @@
import os.path
import sys import sys
import py import py
@ -53,6 +54,10 @@ class TestPort:
def test_matching(self, match, pattern, path): def test_matching(self, match, pattern, path):
assert match(pattern, path) assert match(pattern, path)
def test_matching_abspath(self, match):
abspath = os.path.abspath(os.path.join("tests/foo.py"))
assert match("tests/foo.py", abspath)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"pattern, path", "pattern, path",
[ [