Merge pull request #3919 from fabioz/master

Improve import performance of assertion rewrite. Fixes #3918.
This commit is contained in:
Ronny Pfannschmidt 2018-09-05 14:33:40 +02:00 committed by GitHub
commit 410d5762c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 166 additions and 24 deletions

3
.gitignore vendored
View File

@ -38,3 +38,6 @@ env/
.ropeproject .ropeproject
.idea .idea
.hypothesis .hypothesis
.pydevproject
.project
.settings

View File

@ -72,6 +72,7 @@ Endre Galaczi
Eric Hunsberger Eric Hunsberger
Eric Siegerman Eric Siegerman
Erik M. Bray Erik M. Bray
Fabio Zadrozny
Feng Ma Feng Ma
Florian Bruhin Florian Bruhin
Floris Bruynooghe Floris Bruynooghe

View File

@ -0,0 +1 @@
Improve performance of assertion rewriting.

View File

@ -67,14 +67,24 @@ class AssertionRewritingHook(object):
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file, # flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
# which might result in infinite recursion (#3506) # which might result in infinite recursion (#3506)
self._writing_pyc = False self._writing_pyc = False
self._basenames_to_check_rewrite = {"conftest"}
self._marked_for_rewrite_cache = {}
self._session_paths_checked = False
def set_session(self, session): def set_session(self, session):
self.session = session self.session = session
self._session_paths_checked = False
def _imp_find_module(self, name, path=None):
    """Indirection so we can mock calls to find_module originating from the hook during testing.

    Forwards ``name`` and ``path`` unchanged to ``imp.find_module`` and
    returns whatever that call returns.
    """
    return imp.find_module(name, path)
def find_module(self, name, path=None): def find_module(self, name, path=None):
if self._writing_pyc: if self._writing_pyc:
return None return None
state = self.config._assertstate state = self.config._assertstate
if self._early_rewrite_bailout(name, state):
return None
state.trace("find_module called for: %s" % name) state.trace("find_module called for: %s" % name)
names = name.rsplit(".", 1) names = name.rsplit(".", 1)
lastname = names[-1] lastname = names[-1]
@ -87,7 +97,7 @@ class AssertionRewritingHook(object):
pth = path[0] pth = path[0]
if pth is None: if pth is None:
try: try:
fd, fn, desc = imp.find_module(lastname, path) fd, fn, desc = self._imp_find_module(lastname, path)
except ImportError: except ImportError:
return None return None
if fd is not None: if fd is not None:
@ -166,6 +176,44 @@ class AssertionRewritingHook(object):
self.modules[name] = co, pyc self.modules[name] = co, pyc
return self return self
def _early_rewrite_bailout(self, name, state):
    """
    Decide, from the module name alone, whether rewriting can be skipped.

    This is a fast way to get out of rewriting modules. Profiling has
    shown that the call to imp.find_module (inside of the find_module
    from this class) is a major slowdown, so this method tries to filter
    out what we're sure won't be rewritten before getting to it.

    :param name: dotted name of the module being imported.
    :param state: assertion state object; only its ``trace`` method is used
        here, and it is handed on to ``_is_marked_for_rewrite``.
    :return: True when the module certainly will not be rewritten, so the
        caller may return early; False when the regular (slower) checks
        still have to run.
    """
    session = self.session
    if session is not None and not self._session_paths_checked:
        initial_paths = session._initialpaths
        # A session starts out with an empty ``_initialpaths`` and only
        # fills it in during collection. Latch the flag only once there is
        # something to read; otherwise an import that happens in between
        # would freeze an empty set and the real paths would never be seen.
        if initial_paths:
            self._session_paths_checked = True
            for path in initial_paths:
                # c:/projects/my_project/path.py -> 'path.py' -> 'path'
                basename = os.path.basename(str(path))
                self._basenames_to_check_rewrite.add(os.path.splitext(basename)[0])

    # Note: conftest is in _basenames_to_check_rewrite by default.
    parts = name.split(".")
    if parts[-1] in self._basenames_to_check_rewrite:
        return False

    # If a pattern contains subdirectories ("tests/**.py" for example) we
    # can't bail out based on the name alone because we need to match
    # against the full path. Checked before building the path object below
    # since the outcome is the same regardless of pattern order.
    if any(os.path.dirname(pat) for pat in self.fnpats):
        return False

    # For matching the name it must be as if it was a filename.
    parts[-1] = parts[-1] + ".py"
    fn_pypath = py.path.local(os.path.sep.join(parts))
    if any(fn_pypath.fnmatch(pat) for pat in self.fnpats):
        return False

    if self._is_marked_for_rewrite(name, state):
        return False

    state.trace("early skip of rewriting module: %s" % (name,))
    return True
def _should_rewrite(self, name, fn_pypath, state): def _should_rewrite(self, name, fn_pypath, state):
# always rewrite conftest files # always rewrite conftest files
fn = str(fn_pypath) fn = str(fn_pypath)
@ -185,12 +233,20 @@ class AssertionRewritingHook(object):
state.trace("matched test file %r" % (fn,)) state.trace("matched test file %r" % (fn,))
return True return True
for marked in self._must_rewrite: return self._is_marked_for_rewrite(name, state)
if name == marked or name.startswith(marked + "."):
state.trace("matched marked file %r (from %r)" % (name, marked))
return True
return False def _is_marked_for_rewrite(self, name, state):
try:
return self._marked_for_rewrite_cache[name]
except KeyError:
for marked in self._must_rewrite:
if name == marked or name.startswith(marked + "."):
state.trace("matched marked file %r (from %r)" % (name, marked))
self._marked_for_rewrite_cache[name] = True
return True
self._marked_for_rewrite_cache[name] = False
return False
def mark_rewrite(self, *names): def mark_rewrite(self, *names):
"""Mark import names as needing to be rewritten. """Mark import names as needing to be rewritten.
@ -207,6 +263,7 @@ class AssertionRewritingHook(object):
): ):
self._warn_already_imported(name) self._warn_already_imported(name)
self._must_rewrite.update(names) self._must_rewrite.update(names)
self._marked_for_rewrite_cache.clear()
def _warn_already_imported(self, name): def _warn_already_imported(self, name):
self.config.warn( self.config.warn(
@ -241,7 +298,7 @@ class AssertionRewritingHook(object):
def is_package(self, name): def is_package(self, name):
try: try:
fd, fn, desc = imp.find_module(name) fd, fn, desc = self._imp_find_module(name)
except ImportError: except ImportError:
return False return False
if fd is not None: if fd is not None:

View File

@ -383,6 +383,7 @@ class Session(nodes.FSCollector):
self.trace = config.trace.root.get("collection") self.trace = config.trace.root.get("collection")
self._norecursepatterns = config.getini("norecursedirs") self._norecursepatterns = config.getini("norecursedirs")
self.startdir = py.path.local() self.startdir = py.path.local()
self._initialpaths = frozenset()
# Keep track of any collected nodes in here, so we don't duplicate fixtures # Keep track of any collected nodes in here, so we don't duplicate fixtures
self._node_cache = {} self._node_cache = {}
@ -441,13 +442,14 @@ class Session(nodes.FSCollector):
self.trace("perform_collect", self, args) self.trace("perform_collect", self, args)
self.trace.root.indent += 1 self.trace.root.indent += 1
self._notfound = [] self._notfound = []
self._initialpaths = set() initialpaths = []
self._initialparts = [] self._initialparts = []
self.items = items = [] self.items = items = []
for arg in args: for arg in args:
parts = self._parsearg(arg) parts = self._parsearg(arg)
self._initialparts.append(parts) self._initialparts.append(parts)
self._initialpaths.add(parts[0]) initialpaths.append(parts[0])
self._initialpaths = frozenset(initialpaths)
rep = collect_one_node(self) rep = collect_one_node(self)
self.ihook.pytest_collectreport(report=rep) self.ihook.pytest_collectreport(report=rep)
self.trace.root.indent -= 1 self.trace.root.indent -= 1
@ -564,7 +566,6 @@ class Session(nodes.FSCollector):
"""Convert a dotted module name to path. """Convert a dotted module name to path.
""" """
try: try:
with _patched_find_module(): with _patched_find_module():
loader = pkgutil.find_loader(x) loader = pkgutil.find_loader(x)

View File

@ -1106,22 +1106,21 @@ class TestIssue925(object):
class TestIssue2121: class TestIssue2121:
def test_simple(self, testdir): def test_rewrite_python_files_contain_subdirs(self, testdir):
testdir.tmpdir.join("tests/file.py").ensure().write( testdir.makepyfile(
""" **{
def test_simple_failure(): "tests/file.py": """
assert 1 + 1 == 3 def test_simple_failure():
""" assert 1 + 1 == 3
)
testdir.tmpdir.join("pytest.ini").write(
textwrap.dedent(
""" """
[pytest] }
python_files = tests/**.py )
""" testdir.makeini(
) """
[pytest]
python_files = tests/**.py
"""
) )
result = testdir.runpytest() result = testdir.runpytest()
result.stdout.fnmatch_lines("*E*assert (1 + 1) == 3") result.stdout.fnmatch_lines("*E*assert (1 + 1) == 3")
@ -1153,3 +1152,83 @@ def test_rewrite_infinite_recursion(testdir, pytestconfig, monkeypatch):
hook = AssertionRewritingHook(pytestconfig) hook = AssertionRewritingHook(pytestconfig)
assert hook.find_module("test_foo") is not None assert hook.find_module("test_foo") is not None
assert len(write_pyc_called) == 1 assert len(write_pyc_called) == 1
class TestEarlyRewriteBailout(object):
    """Tests for the early bail-out in AssertionRewritingHook.find_module.

    Each test records which module names reach the hook's
    ``_imp_find_module`` and compares that list with the expected one.
    """

    @pytest.fixture
    def hook(self, pytestconfig, monkeypatch, testdir):
        """Returns a patched AssertionRewritingHook instance so we can configure its initial paths and track
        if imp.find_module has been called.
        """
        import imp

        # names passed to the spied _imp_find_module, in call order
        self.find_module_calls = []
        # the same set object is exposed by StubSession below, so a test can
        # add initial paths after this fixture has run
        self.initial_paths = set()

        class StubSession(object):
            _initialpaths = self.initial_paths

            def isinitpath(self, p):
                return p in self._initialpaths

        def spy_imp_find_module(name, path):
            # record the lookup, then delegate to the real implementation
            self.find_module_calls.append(name)
            return imp.find_module(name, path)

        hook = AssertionRewritingHook(pytestconfig)
        # use default patterns, otherwise we inherit pytest's testing config
        hook.fnpats[:] = ["test_*.py", "*_test.py"]
        monkeypatch.setattr(hook, "_imp_find_module", spy_imp_find_module)
        hook.set_session(StubSession())
        testdir.syspathinsert()
        return hook

    def test_basic(self, testdir, hook):
        """
        Ensure we avoid calling imp.find_module when we know for sure a certain module will not be rewritten
        to optimize assertion rewriting (#3918).
        """
        testdir.makeconftest(
            """
import pytest
@pytest.fixture
def fix(): return 1
"""
        )
        testdir.makepyfile(test_foo="def test_foo(): pass")
        testdir.makepyfile(bar="def bar(): pass")
        foobar_path = testdir.makepyfile(foobar="def foobar(): pass")
        # registered before the first find_module call below
        self.initial_paths.add(foobar_path)

        # conftest files should always be rewritten
        assert hook.find_module("conftest") is not None
        assert self.find_module_calls == ["conftest"]

        # files matching "python_files" mask should always be rewritten
        assert hook.find_module("test_foo") is not None
        assert self.find_module_calls == ["conftest", "test_foo"]

        # file does not match "python_files": early bailout
        assert hook.find_module("bar") is None
        assert self.find_module_calls == ["conftest", "test_foo"]

        # file is an initial path (passed on the command-line): should be rewritten
        assert hook.find_module("foobar") is not None
        assert self.find_module_calls == ["conftest", "test_foo", "foobar"]

    def test_pattern_contains_subdirectories(self, testdir, hook):
        """If one of the python_files patterns contains subdirectories ("tests/**.py") we can't bail out early
        because we need to match with the full path, which can only be found by calling imp.find_module.
        """
        # NOTE(review): the literal below is reproduced exactly as it appears
        # in this view, with no leading whitespace; the ``assert`` line
        # presumably has to be indented under the ``def`` - confirm against
        # the repository.
        p = testdir.makepyfile(
            **{
                "tests/file.py": """
def test_simple_failure():
assert 1 + 1 == 3
"""
            }
        )
        testdir.syspathinsert(p.dirpath())
        hook.fnpats[:] = ["tests/**.py"]
        assert hook.find_module("file") is not None
        assert self.find_module_calls == ["file"]