diff --git a/.gitignore b/.gitignore index afb6bf9fd..2ae5ea752 100644 --- a/.gitignore +++ b/.gitignore @@ -38,3 +38,6 @@ env/ .ropeproject .idea .hypothesis +.pydevproject +.project +.settings diff --git a/AUTHORS b/AUTHORS index c39c0c68a..ed941a7fc 100644 --- a/AUTHORS +++ b/AUTHORS @@ -72,6 +72,7 @@ Endre Galaczi Eric Hunsberger Eric Siegerman Erik M. Bray +Fabio Zadrozny Feng Ma Florian Bruhin Floris Bruynooghe diff --git a/changelog/3918.bugfix.rst b/changelog/3918.bugfix.rst new file mode 100644 index 000000000..7ba811916 --- /dev/null +++ b/changelog/3918.bugfix.rst @@ -0,0 +1 @@ +Improve performance of assertion rewriting. diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index a48a931ac..6f46da8fe 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -67,14 +67,24 @@ class AssertionRewritingHook(object): # flag to guard against trying to rewrite a pyc file while we are already writing another pyc file, # which might result in infinite recursion (#3506) self._writing_pyc = False + self._basenames_to_check_rewrite = {"conftest"} + self._marked_for_rewrite_cache = {} + self._session_paths_checked = False def set_session(self, session): self.session = session + self._session_paths_checked = False + + def _imp_find_module(self, name, path=None): + """Indirection so we can mock calls to find_module originated from the hook during testing""" + return imp.find_module(name, path) def find_module(self, name, path=None): if self._writing_pyc: return None state = self.config._assertstate + if self._early_rewrite_bailout(name, state): + return None state.trace("find_module called for: %s" % name) names = name.rsplit(".", 1) lastname = names[-1] @@ -87,7 +97,7 @@ class AssertionRewritingHook(object): pth = path[0] if pth is None: try: - fd, fn, desc = imp.find_module(lastname, path) + fd, fn, desc = self._imp_find_module(lastname, path) except ImportError: return None if fd is not None: @@ -166,6 +176,44 @@ class AssertionRewritingHook(object): self.modules[name] = co, pyc return self + def _early_rewrite_bailout(self, name, state): + """ + This is a fast way to get out of rewriting modules. Profiling has + shown that the call to imp.find_module (inside of the find_module + from this class) is a major slowdown, so, this method tries to + filter what we're sure won't be rewritten before getting to it. + """ + if self.session is not None and not self._session_paths_checked: + self._session_paths_checked = True + for path in self.session._initialpaths: + # Make something as c:/projects/my_project/path.py -> + # ['c:', 'projects', 'my_project', 'path.py'] + parts = str(path).split(os.path.sep) + # add 'path' to basenames to be checked. + self._basenames_to_check_rewrite.add(os.path.splitext(parts[-1])[0]) + + # Note: conftest already by default in _basenames_to_check_rewrite. + parts = name.split(".") + if parts[-1] in self._basenames_to_check_rewrite: + return False + + # For matching the name it must be as if it was a filename. + parts[-1] = parts[-1] + ".py" + fn_pypath = py.path.local(os.path.sep.join(parts)) + for pat in self.fnpats: + # if the pattern contains subdirectories ("tests/**.py" for example) we can't bail out based + # on the name alone because we need to match against the full path + if os.path.dirname(pat): + return False + if fn_pypath.fnmatch(pat): + return False + + if self._is_marked_for_rewrite(name, state): + return False + + state.trace("early skip of rewriting module: %s" % (name,)) + return True + def _should_rewrite(self, name, fn_pypath, state): # always rewrite conftest files fn = str(fn_pypath) @@ -185,12 +233,20 @@ class AssertionRewritingHook(object): state.trace("matched test file %r" % (fn,)) return True - for marked in self._must_rewrite: - if name == marked or name.startswith(marked + "."): - state.trace("matched marked file %r (from %r)" % (name, marked)) - return True + return self._is_marked_for_rewrite(name, state) - return False + def _is_marked_for_rewrite(self, name, state): + try: + return self._marked_for_rewrite_cache[name] + except KeyError: + for marked in self._must_rewrite: + if name == marked or name.startswith(marked + "."): + state.trace("matched marked file %r (from %r)" % (name, marked)) + self._marked_for_rewrite_cache[name] = True + return True + + self._marked_for_rewrite_cache[name] = False + return False def mark_rewrite(self, *names): """Mark import names as needing to be rewritten. @@ -207,6 +263,7 @@ class AssertionRewritingHook(object): ): self._warn_already_imported(name) self._must_rewrite.update(names) + self._marked_for_rewrite_cache.clear() def _warn_already_imported(self, name): self.config.warn( @@ -241,7 +298,7 @@ class AssertionRewritingHook(object): def is_package(self, name): try: - fd, fn, desc = imp.find_module(name) + fd, fn, desc = self._imp_find_module(name) except ImportError: return False if fd is not None: diff --git a/src/_pytest/main.py b/src/_pytest/main.py index 947c6aa4b..f5078b9e7 100644 --- a/src/_pytest/main.py +++ b/src/_pytest/main.py @@ -383,6 +383,7 @@ class Session(nodes.FSCollector): self.trace = config.trace.root.get("collection") self._norecursepatterns = config.getini("norecursedirs") self.startdir = py.path.local() + self._initialpaths = frozenset() # Keep track of any collected nodes in here, so we don't duplicate fixtures self._node_cache = {} @@ -441,13 +442,14 @@ class Session(nodes.FSCollector): self.trace("perform_collect", self, args) self.trace.root.indent += 1 self._notfound = [] - self._initialpaths = set() + initialpaths = [] self._initialparts = [] self.items = items = [] for arg in args: parts = self._parsearg(arg) self._initialparts.append(parts) - self._initialpaths.add(parts[0]) + initialpaths.append(parts[0]) + self._initialpaths = frozenset(initialpaths) rep = collect_one_node(self) self.ihook.pytest_collectreport(report=rep) self.trace.root.indent -= 1 @@ -564,7 +566,6 @@ class Session(nodes.FSCollector): """Convert a dotted module name to path. """ - try: with _patched_find_module(): loader = pkgutil.find_loader(x) diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index c436ab0de..b70b50607 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1106,22 +1106,21 @@ class TestIssue925(object): class TestIssue2121: - def test_simple(self, testdir): - testdir.tmpdir.join("tests/file.py").ensure().write( - """ -def test_simple_failure(): - assert 1 + 1 == 3 -""" - ) - testdir.tmpdir.join("pytest.ini").write( - textwrap.dedent( + def test_rewrite_python_files_contain_subdirs(self, testdir): + testdir.makepyfile( + **{ + "tests/file.py": """ + def test_simple_failure(): + assert 1 + 1 == 3 """ - [pytest] - python_files = tests/**.py - """ - ) + } + ) + testdir.makeini( + """ + [pytest] + python_files = tests/**.py + """ ) - result = testdir.runpytest() result.stdout.fnmatch_lines("*E*assert (1 + 1) == 3") @@ -1153,3 +1152,83 @@ def test_rewrite_infinite_recursion(testdir, pytestconfig, monkeypatch): hook = AssertionRewritingHook(pytestconfig) assert hook.find_module("test_foo") is not None assert len(write_pyc_called) == 1 + + +class TestEarlyRewriteBailout(object): + @pytest.fixture + def hook(self, pytestconfig, monkeypatch, testdir): + """Returns a patched AssertionRewritingHook instance so we can configure its initial paths and track + if imp.find_module has been called. + """ + import imp + + self.find_module_calls = [] + self.initial_paths = set() + + class StubSession(object): + _initialpaths = self.initial_paths + + def isinitpath(self, p): + return p in self._initialpaths + + def spy_imp_find_module(name, path): + self.find_module_calls.append(name) + return imp.find_module(name, path) + + hook = AssertionRewritingHook(pytestconfig) + # use default patterns, otherwise we inherit pytest's testing config + hook.fnpats[:] = ["test_*.py", "*_test.py"] + monkeypatch.setattr(hook, "_imp_find_module", spy_imp_find_module) + hook.set_session(StubSession()) + testdir.syspathinsert() + return hook + + def test_basic(self, testdir, hook): + """ + Ensure we avoid calling imp.find_module when we know for sure a certain module will not be rewritten + to optimize assertion rewriting (#3918). + """ + testdir.makeconftest( + """ + import pytest + @pytest.fixture + def fix(): return 1 + """ + ) + testdir.makepyfile(test_foo="def test_foo(): pass") + testdir.makepyfile(bar="def bar(): pass") + foobar_path = testdir.makepyfile(foobar="def foobar(): pass") + self.initial_paths.add(foobar_path) + + # conftest files should always be rewritten + assert hook.find_module("conftest") is not None + assert self.find_module_calls == ["conftest"] + + # files matching "python_files" mask should always be rewritten + assert hook.find_module("test_foo") is not None + assert self.find_module_calls == ["conftest", "test_foo"] + + # file does not match "python_files": early bailout + assert hook.find_module("bar") is None + assert self.find_module_calls == ["conftest", "test_foo"] + + # file is an initial path (passed on the command-line): should be rewritten + assert hook.find_module("foobar") is not None + assert self.find_module_calls == ["conftest", "test_foo", "foobar"] + + def test_pattern_contains_subdirectories(self, testdir, hook): + """If one of the python_files patterns contain subdirectories ("tests/**.py") we can't bailout early + because we need to match with the full path, which can only be found by calling imp.find_module. + """ + p = testdir.makepyfile( + **{ + "tests/file.py": """ + def test_simple_failure(): + assert 1 + 1 == 3 + """ + } + ) + testdir.syspathinsert(p.dirpath()) + hook.fnpats[:] = ["tests/**.py"] + assert hook.find_module("file") is not None + assert self.find_module_calls == ["file"]