Merge pull request #5468 from asottile/switch_importlib_to_imp

Switch from deprecated imp to importlib
This commit is contained in:
Anthony Sottile 2019-06-24 11:22:01 -07:00 committed by GitHub
commit 61dcb84f0d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 182 additions and 242 deletions

View File

@ -0,0 +1 @@
Switch from ``imp`` to ``importlib``.

View File

@ -0,0 +1 @@
Honor PEP 235 on case-insensitive file systems.

View File

@ -0,0 +1 @@
Test module is no longer double-imported when using ``--pyargs``.

View File

@ -0,0 +1 @@
Prevent "already imported" warnings from assertion rewriter when invoking pytest in-process multiple times.

View File

@ -0,0 +1 @@
Fix assertion rewriting in packages (``__init__.py``).

View File

@ -1,18 +1,16 @@
"""Rewrite assertion AST to produce nice error messages"""
import ast
import errno
import imp
import importlib.machinery
import importlib.util
import itertools
import marshal
import os
import re
import struct
import sys
import types
from importlib.util import spec_from_file_location
import atomicwrites
import py
from _pytest._io.saferepr import saferepr
from _pytest.assertion import util
@ -23,23 +21,13 @@ from _pytest.pathlib import fnmatch_ex
from _pytest.pathlib import PurePath
# pytest caches rewritten pycs in __pycache__.
if hasattr(imp, "get_tag"):
PYTEST_TAG = imp.get_tag() + "-PYTEST"
else:
if hasattr(sys, "pypy_version_info"):
impl = "pypy"
else:
impl = "cpython"
ver = sys.version_info
PYTEST_TAG = "{}-{}{}-PYTEST".format(impl, ver[0], ver[1])
del ver, impl
PYTEST_TAG = "{}-PYTEST".format(sys.implementation.cache_tag)
PYC_EXT = ".py" + (__debug__ and "c" or "o")
PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
class AssertionRewritingHook:
"""PEP302 Import hook which rewrites asserts."""
"""PEP302/PEP451 import hook which rewrites asserts."""
def __init__(self, config):
self.config = config
@ -48,7 +36,6 @@ class AssertionRewritingHook:
except ValueError:
self.fnpats = ["test_*.py", "*_test.py"]
self.session = None
self.modules = {}
self._rewritten_names = set()
self._must_rewrite = set()
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
@ -62,55 +49,51 @@ class AssertionRewritingHook:
self.session = session
self._session_paths_checked = False
def _imp_find_module(self, name, path=None):
"""Indirection so we can mock calls to find_module originated from the hook during testing"""
return imp.find_module(name, path)
# Indirection so we can mock calls to find_spec originated from the hook during testing
_find_spec = importlib.machinery.PathFinder.find_spec
def find_module(self, name, path=None):
def find_spec(self, name, path=None, target=None):
if self._writing_pyc:
return None
state = self.config._assertstate
if self._early_rewrite_bailout(name, state):
return None
state.trace("find_module called for: %s" % name)
names = name.rsplit(".", 1)
lastname = names[-1]
pth = None
if path is not None:
# Starting with Python 3.3, path is a _NamespacePath(), which
# causes problems if not converted to list.
path = list(path)
if len(path) == 1:
pth = path[0]
if pth is None:
try:
fd, fn, desc = self._imp_find_module(lastname, path)
except ImportError:
return None
if fd is not None:
fd.close()
tp = desc[2]
if tp == imp.PY_COMPILED:
if hasattr(imp, "source_from_cache"):
try:
fn = imp.source_from_cache(fn)
except ValueError:
# Python 3 doesn't like orphaned but still-importable
# .pyc files.
fn = fn[:-1]
else:
fn = fn[:-1]
elif tp != imp.PY_SOURCE:
# Don't know what this is.
return None
else:
fn = os.path.join(pth, name.rpartition(".")[2] + ".py")
fn_pypath = py.path.local(fn)
if not self._should_rewrite(name, fn_pypath, state):
spec = self._find_spec(name, path)
if (
# the import machinery could not find a file to import
spec is None
# this is a namespace package (without `__init__.py`)
# there's nothing to rewrite there
# python3.5 - python3.6: `namespace`
# python3.7+: `None`
or spec.origin in {None, "namespace"}
# if the file doesn't exist, we can't rewrite it
or not os.path.exists(spec.origin)
):
return None
else:
fn = spec.origin
if not self._should_rewrite(name, fn, state):
return None
self._rewritten_names.add(name)
return importlib.util.spec_from_file_location(
name,
fn,
loader=self,
submodule_search_locations=spec.submodule_search_locations,
)
def create_module(self, spec):
return None # default behaviour is fine
def exec_module(self, module):
fn = module.__spec__.origin
state = self.config._assertstate
self._rewritten_names.add(module.__name__)
# The requested module looks like a test file, so rewrite it. This is
# the most magical part of the process: load the source, rewrite the
@ -121,7 +104,7 @@ class AssertionRewritingHook:
# cached pyc is always a complete, valid pyc. Operations on it must be
# atomic. POSIX's atomic rename comes in handy.
write = not sys.dont_write_bytecode
cache_dir = os.path.join(fn_pypath.dirname, "__pycache__")
cache_dir = os.path.join(os.path.dirname(fn), "__pycache__")
if write:
try:
os.mkdir(cache_dir)
@ -132,26 +115,23 @@ class AssertionRewritingHook:
# common case) or it's blocked by a non-dir node. In the
# latter case, we'll ignore it in _write_pyc.
pass
elif e in [errno.ENOENT, errno.ENOTDIR]:
elif e in {errno.ENOENT, errno.ENOTDIR}:
# One of the path components was not a directory, likely
# because we're in a zip file.
write = False
elif e in [errno.EACCES, errno.EROFS, errno.EPERM]:
state.trace("read only directory: %r" % fn_pypath.dirname)
elif e in {errno.EACCES, errno.EROFS, errno.EPERM}:
state.trace("read only directory: %r" % os.path.dirname(fn))
write = False
else:
raise
cache_name = fn_pypath.basename[:-3] + PYC_TAIL
cache_name = os.path.basename(fn)[:-3] + PYC_TAIL
pyc = os.path.join(cache_dir, cache_name)
# Notice that even if we're in a read-only directory, I'm going
# to check for a cached pyc. This may not be optimal...
co = _read_pyc(fn_pypath, pyc, state.trace)
co = _read_pyc(fn, pyc, state.trace)
if co is None:
state.trace("rewriting {!r}".format(fn))
source_stat, co = _rewrite_test(self.config, fn_pypath)
if co is None:
# Probably a SyntaxError in the test.
return None
source_stat, co = _rewrite_test(fn)
if write:
self._writing_pyc = True
try:
@ -160,13 +140,11 @@ class AssertionRewritingHook:
self._writing_pyc = False
else:
state.trace("found cached rewritten pyc for {!r}".format(fn))
self.modules[name] = co, pyc
return self
exec(co, module.__dict__)
def _early_rewrite_bailout(self, name, state):
"""
This is a fast way to get out of rewriting modules. Profiling has
shown that the call to imp.find_module (inside of the find_module
"""This is a fast way to get out of rewriting modules. Profiling has
shown that the call to PathFinder.find_spec (inside of the find_spec
from this class) is a major slowdown, so, this method tries to
filter what we're sure won't be rewritten before getting to it.
"""
@ -201,10 +179,9 @@ class AssertionRewritingHook:
state.trace("early skip of rewriting module: {}".format(name))
return True
def _should_rewrite(self, name, fn_pypath, state):
def _should_rewrite(self, name, fn, state):
# always rewrite conftest files
fn = str(fn_pypath)
if fn_pypath.basename == "conftest.py":
if os.path.basename(fn) == "conftest.py":
state.trace("rewriting conftest file: {!r}".format(fn))
return True
@ -217,8 +194,9 @@ class AssertionRewritingHook:
# modules not passed explicitly on the command line are only
# rewritten if they match the naming convention for test files
fn_path = PurePath(fn)
for pat in self.fnpats:
if fn_pypath.fnmatch(pat):
if fnmatch_ex(pat, fn_path):
state.trace("matched test file {!r}".format(fn))
return True
@ -249,9 +227,10 @@ class AssertionRewritingHook:
set(names).intersection(sys.modules).difference(self._rewritten_names)
)
for name in already_imported:
mod = sys.modules[name]
if not AssertionRewriter.is_rewrite_disabled(
sys.modules[name].__doc__ or ""
):
mod.__doc__ or ""
) and not isinstance(mod.__loader__, type(self)):
self._warn_already_imported(name)
self._must_rewrite.update(names)
self._marked_for_rewrite_cache.clear()
@ -268,45 +247,8 @@ class AssertionRewritingHook:
stacklevel=5,
)
def load_module(self, name):
co, pyc = self.modules.pop(name)
if name in sys.modules:
# If there is an existing module object named 'fullname' in
# sys.modules, the loader must use that existing module. (Otherwise,
# the reload() builtin will not work correctly.)
mod = sys.modules[name]
else:
# I wish I could just call imp.load_compiled here, but __file__ has to
# be set properly. In Python 3.2+, this all would be handled correctly
# by load_compiled.
mod = sys.modules[name] = imp.new_module(name)
try:
mod.__file__ = co.co_filename
# Normally, this attribute is 3.2+.
mod.__cached__ = pyc
mod.__loader__ = self
# Normally, this attribute is 3.4+
mod.__spec__ = spec_from_file_location(name, co.co_filename, loader=self)
exec(co, mod.__dict__)
except: # noqa
if name in sys.modules:
del sys.modules[name]
raise
return sys.modules[name]
def is_package(self, name):
try:
fd, fn, desc = self._imp_find_module(name)
except ImportError:
return False
if fd is not None:
fd.close()
tp = desc[2]
return tp == imp.PKG_DIRECTORY
def get_data(self, pathname):
"""Optional PEP302 get_data API.
"""
"""Optional PEP302 get_data API."""
with open(pathname, "rb") as f:
return f.read()
@ -314,15 +256,13 @@ class AssertionRewritingHook:
def _write_pyc(state, co, source_stat, pyc):
# Technically, we don't have to have the same pyc format as
# (C)Python, since these "pycs" should never be seen by builtin
# import. However, there's little reason deviate, and I hope
# sometime to be able to use imp.load_compiled to load them. (See
# the comment in load_module above.)
# import. However, there's little reason deviate.
try:
with atomicwrites.atomic_write(pyc, mode="wb", overwrite=True) as fp:
fp.write(imp.get_magic())
fp.write(importlib.util.MAGIC_NUMBER)
# as of now, bytecode header expects 32-bit numbers for size and mtime (#4903)
mtime = int(source_stat.mtime) & 0xFFFFFFFF
size = source_stat.size & 0xFFFFFFFF
mtime = int(source_stat.st_mtime) & 0xFFFFFFFF
size = source_stat.st_size & 0xFFFFFFFF
# "<LL" stands for 2 unsigned longs, little-ending
fp.write(struct.pack("<LL", mtime, size))
fp.write(marshal.dumps(co))
@ -335,35 +275,14 @@ def _write_pyc(state, co, source_stat, pyc):
return True
RN = b"\r\n"
N = b"\n"
cookie_re = re.compile(r"^[ \t\f]*#.*coding[:=][ \t]*[-\w.]+")
BOM_UTF8 = "\xef\xbb\xbf"
def _rewrite_test(config, fn):
"""Try to read and rewrite *fn* and return the code object."""
state = config._assertstate
try:
stat = fn.stat()
source = fn.read("rb")
except EnvironmentError:
return None, None
try:
tree = ast.parse(source, filename=fn.strpath)
except SyntaxError:
# Let this pop up again in the real import.
state.trace("failed to parse: {!r}".format(fn))
return None, None
rewrite_asserts(tree, fn, config)
try:
co = compile(tree, fn.strpath, "exec", dont_inherit=True)
except SyntaxError:
# It's possible that this error is from some bug in the
# assertion rewriting, but I don't know of a fast way to tell.
state.trace("failed to compile: {!r}".format(fn))
return None, None
def _rewrite_test(fn):
"""read and rewrite *fn* and return the code object."""
stat = os.stat(fn)
with open(fn, "rb") as f:
source = f.read()
tree = ast.parse(source, filename=fn)
rewrite_asserts(tree, fn)
co = compile(tree, fn, "exec", dont_inherit=True)
return stat, co
@ -378,8 +297,9 @@ def _read_pyc(source, pyc, trace=lambda x: None):
return None
with fp:
try:
mtime = int(source.mtime())
size = source.size()
stat_result = os.stat(source)
mtime = int(stat_result.st_mtime)
size = stat_result.st_size
data = fp.read(12)
except EnvironmentError as e:
trace("_read_pyc({}): EnvironmentError {}".format(source, e))
@ -387,7 +307,7 @@ def _read_pyc(source, pyc, trace=lambda x: None):
# Check for invalid or out of date pyc file.
if (
len(data) != 12
or data[:4] != imp.get_magic()
or data[:4] != importlib.util.MAGIC_NUMBER
or struct.unpack("<LL", data[4:]) != (mtime & 0xFFFFFFFF, size & 0xFFFFFFFF)
):
trace("_read_pyc(%s): invalid or out of date pyc" % source)
@ -403,9 +323,9 @@ def _read_pyc(source, pyc, trace=lambda x: None):
return co
def rewrite_asserts(mod, module_path=None, config=None):
def rewrite_asserts(mod, module_path=None):
"""Rewrite the assert statements in mod."""
AssertionRewriter(module_path, config).run(mod)
AssertionRewriter(module_path).run(mod)
def _saferepr(obj):
@ -586,10 +506,9 @@ class AssertionRewriter(ast.NodeVisitor):
"""
def __init__(self, module_path, config):
def __init__(self, module_path):
super().__init__()
self.module_path = module_path
self.config = config
def run(self, mod):
"""Find all assert statements in *mod* and rewrite them."""
@ -758,7 +677,7 @@ class AssertionRewriter(ast.NodeVisitor):
"assertion is always true, perhaps remove parentheses?"
),
category=None,
filename=str(self.module_path),
filename=self.module_path,
lineno=assert_.lineno,
)
@ -817,7 +736,7 @@ class AssertionRewriter(ast.NodeVisitor):
AST_NONE = ast.parse("None").body[0].value
val_is_none = ast.Compare(node, [ast.Is()], [AST_NONE])
send_warning = ast.parse(
"""
"""\
from _pytest.warning_types import PytestAssertRewriteWarning
from warnings import warn_explicit
warn_explicit(
@ -827,7 +746,7 @@ warn_explicit(
lineno={lineno},
)
""".format(
filename=module_path.strpath, lineno=lineno
filename=module_path, lineno=lineno
)
).body
return ast.If(val_is_none, send_warning, [])

View File

@ -2,8 +2,8 @@
import enum
import fnmatch
import functools
import importlib
import os
import pkgutil
import sys
import warnings
@ -630,21 +630,15 @@ class Session(nodes.FSCollector):
def _tryconvertpyarg(self, x):
"""Convert a dotted module name to path."""
try:
loader = pkgutil.find_loader(x)
except ImportError:
spec = importlib.util.find_spec(x)
except (ValueError, ImportError):
return x
if loader is None:
if spec is None or spec.origin in {None, "namespace"}:
return x
# This method is sometimes invoked when AssertionRewritingHook, which
# does not define a get_filename method, is already in place:
try:
path = loader.get_filename(x)
except AttributeError:
# Retrieve path from AssertionRewritingHook:
path = loader.modules[x][0].co_filename
if loader.is_package(x):
path = os.path.dirname(path)
return path
elif spec.submodule_search_locations:
return os.path.dirname(spec.origin)
else:
return spec.origin
def _parsearg(self, arg):
""" return (fspath, names) tuple after checking the file exists. """

View File

@ -294,6 +294,8 @@ def fnmatch_ex(pattern, path):
name = path.name
else:
name = str(path)
if path.is_absolute() and not os.path.isabs(pattern):
pattern = "*{}{}".format(os.sep, pattern)
return fnmatch.fnmatch(name, pattern)

View File

@ -1,5 +1,6 @@
"""(disabled by default) support for testing pytest and pytest plugins."""
import gc
import importlib
import os
import platform
import re
@ -16,7 +17,6 @@ import py
import pytest
from _pytest._code import Source
from _pytest._io.saferepr import saferepr
from _pytest.assertion.rewrite import AssertionRewritingHook
from _pytest.capture import MultiCapture
from _pytest.capture import SysCapture
from _pytest.main import ExitCode
@ -787,6 +787,11 @@ class Testdir:
:return: a :py:class:`HookRecorder` instance
"""
# (maybe a cpython bug?) the importlib cache sometimes isn't updated
# properly between file creation and inline_run (especially if imports
# are interspersed with file creation)
importlib.invalidate_caches()
plugins = list(plugins)
finalizers = []
try:
@ -796,18 +801,6 @@ class Testdir:
mp_run.setenv(k, v)
finalizers.append(mp_run.undo)
# When running pytest inline any plugins active in the main test
# process are already imported. So this disables the warning which
# will trigger to say they can no longer be rewritten, which is
# fine as they have already been rewritten.
orig_warn = AssertionRewritingHook._warn_already_imported
def revert_warn_already_imported():
AssertionRewritingHook._warn_already_imported = orig_warn
finalizers.append(revert_warn_already_imported)
AssertionRewritingHook._warn_already_imported = lambda *a: None
# Any sys.module or sys.path changes done while running pytest
# inline should be reverted after the test run completes to avoid
# clashing with later inline tests run within the same pytest test,

View File

@ -633,6 +633,19 @@ class TestInvocationVariants:
result.stdout.fnmatch_lines(["collected*0*items*/*1*errors"])
def test_pyargs_only_imported_once(self, testdir):
pkg = testdir.mkpydir("foo")
pkg.join("test_foo.py").write("print('hello from test_foo')\ndef test(): pass")
pkg.join("conftest.py").write(
"def pytest_configure(config): print('configuring')"
)
result = testdir.runpytest("--pyargs", "foo.test_foo", "-s", syspathinsert=True)
# should only import once
assert result.outlines.count("hello from test_foo") == 1
# should only configure once
assert result.outlines.count("configuring") == 1
def test_cmdline_python_package(self, testdir, monkeypatch):
import warnings

View File

@ -137,8 +137,8 @@ class TestImportHookInstallation:
"hamster.py": "",
"test_foo.py": """\
def test_foo(pytestconfig):
assert pytestconfig.pluginmanager.rewrite_hook.find_module('ham') is not None
assert pytestconfig.pluginmanager.rewrite_hook.find_module('hamster') is None
assert pytestconfig.pluginmanager.rewrite_hook.find_spec('ham') is not None
assert pytestconfig.pluginmanager.rewrite_hook.find_spec('hamster') is None
""",
}
testdir.makepyfile(**contents)

View File

@ -1,5 +1,6 @@
import ast
import glob
import importlib
import os
import py_compile
import stat
@ -117,6 +118,37 @@ class TestAssertionRewrite:
result = testdir.runpytest_subprocess()
assert "warnings" not in "".join(result.outlines)
def test_rewrites_plugin_as_a_package(self, testdir):
pkgdir = testdir.mkpydir("plugin")
pkgdir.join("__init__.py").write(
"import pytest\n"
"@pytest.fixture\n"
"def special_asserter():\n"
" def special_assert(x, y):\n"
" assert x == y\n"
" return special_assert\n"
)
testdir.makeconftest('pytest_plugins = ["plugin"]')
testdir.makepyfile("def test(special_asserter): special_asserter(1, 2)\n")
result = testdir.runpytest()
result.stdout.fnmatch_lines(["*assert 1 == 2*"])
def test_honors_pep_235(self, testdir, monkeypatch):
# note: couldn't make it fail on macos with a single `sys.path` entry
# note: these modules are named `test_*` to trigger rewriting
testdir.tmpdir.join("test_y.py").write("x = 1")
xdir = testdir.tmpdir.join("x").ensure_dir()
xdir.join("test_Y").ensure_dir().join("__init__.py").write("x = 2")
testdir.makepyfile(
"import test_y\n"
"import test_Y\n"
"def test():\n"
" assert test_y.x == 1\n"
" assert test_Y.x == 2\n"
)
monkeypatch.syspath_prepend(xdir)
testdir.runpytest().assert_outcomes(passed=1)
def test_name(self, request):
def f():
assert False
@ -831,8 +863,9 @@ def test_rewritten():
monkeypatch.setattr(
hook, "_warn_already_imported", lambda code, msg: warnings.append(msg)
)
hook.find_module("test_remember_rewritten_modules")
hook.load_module("test_remember_rewritten_modules")
spec = hook.find_spec("test_remember_rewritten_modules")
module = importlib.util.module_from_spec(spec)
hook.exec_module(module)
hook.mark_rewrite("test_remember_rewritten_modules")
hook.mark_rewrite("test_remember_rewritten_modules")
assert warnings == []
@ -872,33 +905,6 @@ def test_rewritten():
class TestAssertionRewriteHookDetails:
def test_loader_is_package_false_for_module(self, testdir):
testdir.makepyfile(
test_fun="""
def test_loader():
assert not __loader__.is_package(__name__)
"""
)
result = testdir.runpytest()
result.stdout.fnmatch_lines(["* 1 passed*"])
def test_loader_is_package_true_for_package(self, testdir):
testdir.makepyfile(
test_fun="""
def test_loader():
assert not __loader__.is_package(__name__)
def test_fun():
assert __loader__.is_package('fun')
def test_missing():
assert not __loader__.is_package('pytest_not_there')
"""
)
testdir.mkpydir("fun")
result = testdir.runpytest()
result.stdout.fnmatch_lines(["* 3 passed*"])
def test_sys_meta_path_munged(self, testdir):
testdir.makepyfile(
"""
@ -917,7 +923,7 @@ class TestAssertionRewriteHookDetails:
state = AssertionState(config, "rewrite")
source_path = tmpdir.ensure("source.py")
pycpath = tmpdir.join("pyc").strpath
assert _write_pyc(state, [1], source_path.stat(), pycpath)
assert _write_pyc(state, [1], os.stat(source_path.strpath), pycpath)
@contextmanager
def atomic_write_failed(fn, mode="r", overwrite=False):
@ -979,7 +985,7 @@ class TestAssertionRewriteHookDetails:
assert len(contents) > strip_bytes
pyc.write(contents[:strip_bytes], mode="wb")
assert _read_pyc(source, str(pyc)) is None # no error
assert _read_pyc(str(source), str(pyc)) is None # no error
def test_reload_is_same(self, testdir):
# A file that will be picked up during collecting.
@ -1186,14 +1192,17 @@ def test_rewrite_infinite_recursion(testdir, pytestconfig, monkeypatch):
# make a note that we have called _write_pyc
write_pyc_called.append(True)
# try to import a module at this point: we should not try to rewrite this module
assert hook.find_module("test_bar") is None
assert hook.find_spec("test_bar") is None
return original_write_pyc(*args, **kwargs)
monkeypatch.setattr(rewrite, "_write_pyc", spy_write_pyc)
monkeypatch.setattr(sys, "dont_write_bytecode", False)
hook = AssertionRewritingHook(pytestconfig)
assert hook.find_module("test_foo") is not None
spec = hook.find_spec("test_foo")
assert spec is not None
module = importlib.util.module_from_spec(spec)
hook.exec_module(module)
assert len(write_pyc_called) == 1
@ -1201,11 +1210,11 @@ class TestEarlyRewriteBailout:
@pytest.fixture
def hook(self, pytestconfig, monkeypatch, testdir):
"""Returns a patched AssertionRewritingHook instance so we can configure its initial paths and track
if imp.find_module has been called.
if PathFinder.find_spec has been called.
"""
import imp
import importlib.machinery
self.find_module_calls = []
self.find_spec_calls = []
self.initial_paths = set()
class StubSession:
@ -1214,22 +1223,22 @@ class TestEarlyRewriteBailout:
def isinitpath(self, p):
return p in self._initialpaths
def spy_imp_find_module(name, path):
self.find_module_calls.append(name)
return imp.find_module(name, path)
def spy_find_spec(name, path):
self.find_spec_calls.append(name)
return importlib.machinery.PathFinder.find_spec(name, path)
hook = AssertionRewritingHook(pytestconfig)
# use default patterns, otherwise we inherit pytest's testing config
hook.fnpats[:] = ["test_*.py", "*_test.py"]
monkeypatch.setattr(hook, "_imp_find_module", spy_imp_find_module)
monkeypatch.setattr(hook, "_find_spec", spy_find_spec)
hook.set_session(StubSession())
testdir.syspathinsert()
return hook
def test_basic(self, testdir, hook):
"""
Ensure we avoid calling imp.find_module when we know for sure a certain module will not be rewritten
to optimize assertion rewriting (#3918).
Ensure we avoid calling PathFinder.find_spec when we know for sure a certain
module will not be rewritten to optimize assertion rewriting (#3918).
"""
testdir.makeconftest(
"""
@ -1244,24 +1253,24 @@ class TestEarlyRewriteBailout:
self.initial_paths.add(foobar_path)
# conftest files should always be rewritten
assert hook.find_module("conftest") is not None
assert self.find_module_calls == ["conftest"]
assert hook.find_spec("conftest") is not None
assert self.find_spec_calls == ["conftest"]
# files matching "python_files" mask should always be rewritten
assert hook.find_module("test_foo") is not None
assert self.find_module_calls == ["conftest", "test_foo"]
assert hook.find_spec("test_foo") is not None
assert self.find_spec_calls == ["conftest", "test_foo"]
# file does not match "python_files": early bailout
assert hook.find_module("bar") is None
assert self.find_module_calls == ["conftest", "test_foo"]
assert hook.find_spec("bar") is None
assert self.find_spec_calls == ["conftest", "test_foo"]
# file is an initial path (passed on the command-line): should be rewritten
assert hook.find_module("foobar") is not None
assert self.find_module_calls == ["conftest", "test_foo", "foobar"]
assert hook.find_spec("foobar") is not None
assert self.find_spec_calls == ["conftest", "test_foo", "foobar"]
def test_pattern_contains_subdirectories(self, testdir, hook):
"""If one of the python_files patterns contain subdirectories ("tests/**.py") we can't bailout early
because we need to match with the full path, which can only be found by calling imp.find_module.
because we need to match with the full path, which can only be found by calling PathFinder.find_spec
"""
p = testdir.makepyfile(
**{
@ -1273,8 +1282,8 @@ class TestEarlyRewriteBailout:
)
testdir.syspathinsert(p.dirpath())
hook.fnpats[:] = ["tests/**.py"]
assert hook.find_module("file") is not None
assert self.find_module_calls == ["file"]
assert hook.find_spec("file") is not None
assert self.find_spec_calls == ["file"]
@pytest.mark.skipif(
sys.platform.startswith("win32"), reason="cannot remove cwd on Windows"

View File

@ -1,3 +1,4 @@
import os.path
import sys
import py
@ -53,6 +54,10 @@ class TestPort:
def test_matching(self, match, pattern, path):
assert match(pattern, path)
def test_matching_abspath(self, match):
abspath = os.path.abspath(os.path.join("tests/foo.py"))
assert match("tests/foo.py", abspath)
@pytest.mark.parametrize(
"pattern, path",
[