Improve performance of assertion rewriting. Fixes #3918

This commit is contained in:
Fabio Zadrozny 2018-08-31 12:27:08 -03:00
parent 4345efaffc
commit d53e449296
5 changed files with 63 additions and 17 deletions

3
.gitignore vendored
View File

@ -38,3 +38,6 @@ env/
.ropeproject .ropeproject
.idea .idea
.hypothesis .hypothesis
.pydevproject
.project
.settings

View File

@ -71,6 +71,7 @@ Endre Galaczi
Eric Hunsberger Eric Hunsberger
Eric Siegerman Eric Siegerman
Erik M. Bray Erik M. Bray
Fabio Zadrozny
Feng Ma Feng Ma
Florian Bruhin Florian Bruhin
Floris Bruynooghe Floris Bruynooghe

View File

@ -0,0 +1 @@
Improve performance of assertion rewriting.

View File

@ -67,14 +67,20 @@ class AssertionRewritingHook(object):
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file, # flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
# which might result in infinite recursion (#3506) # which might result in infinite recursion (#3506)
self._writing_pyc = False self._writing_pyc = False
self._basenames_to_check_rewrite = set('conftest',)
self._marked_for_rewrite_cache = {}
self._session_paths_checked = False
def set_session(self, session): def set_session(self, session):
self.session = session self.session = session
self._session_paths_checked = False
def find_module(self, name, path=None): def find_module(self, name, path=None):
if self._writing_pyc: if self._writing_pyc:
return None return None
state = self.config._assertstate state = self.config._assertstate
if self._early_rewrite_bailout(name, state):
return None
state.trace("find_module called for: %s" % name) state.trace("find_module called for: %s" % name)
names = name.rsplit(".", 1) names = name.rsplit(".", 1)
lastname = names[-1] lastname = names[-1]
@ -166,6 +172,41 @@ class AssertionRewritingHook(object):
self.modules[name] = co, pyc self.modules[name] = co, pyc
return self return self
def _early_rewrite_bailout(self, name, state):
"""
This is a fast way to get out of rewriting modules. Profiling has
shown that the call to imp.find_module (inside of the find_module
from this class) is a major slowdown, so, this method tries to
filter what we're sure won't be rewritten before getting to it.
"""
if not self._session_paths_checked and self.session is not None \
and hasattr(self.session, '_initialpaths'):
self._session_paths_checked = True
for path in self.session._initialpaths:
# Make something as c:/projects/my_project/path.py ->
# ['c:', 'projects', 'my_project', 'path.py']
parts = str(path).split(os.path.sep)
# add 'path' to basenames to be checked.
self._basenames_to_check_rewrite.add(os.path.splitext(parts[-1])[0])
# Note: conftest already by default in _basenames_to_check_rewrite.
parts = name.split('.')
if parts[-1] in self._basenames_to_check_rewrite:
return False
# For matching the name it must be as if it was a filename.
parts[-1] = parts[-1] + '.py'
fn_pypath = py.path.local(os.path.sep.join(parts))
for pat in self.fnpats:
if fn_pypath.fnmatch(pat):
return False
if self._is_marked_for_rewrite(name, state):
return False
state.trace("early skip of rewriting module: %s" % (name,))
return True
def _should_rewrite(self, name, fn_pypath, state): def _should_rewrite(self, name, fn_pypath, state):
# always rewrite conftest files # always rewrite conftest files
fn = str(fn_pypath) fn = str(fn_pypath)
@ -185,12 +226,20 @@ class AssertionRewritingHook(object):
state.trace("matched test file %r" % (fn,)) state.trace("matched test file %r" % (fn,))
return True return True
for marked in self._must_rewrite: return self._is_marked_for_rewrite(name, state)
if name == marked or name.startswith(marked + "."):
state.trace("matched marked file %r (from %r)" % (name, marked))
return True
return False def _is_marked_for_rewrite(self, name, state):
try:
return self._marked_for_rewrite_cache[name]
except KeyError:
for marked in self._must_rewrite:
if name == marked or name.startswith(marked + "."):
state.trace("matched marked file %r (from %r)" % (name, marked))
self._marked_for_rewrite_cache[name] = True
return True
self._marked_for_rewrite_cache[name] = False
return False
def mark_rewrite(self, *names): def mark_rewrite(self, *names):
"""Mark import names as needing to be rewritten. """Mark import names as needing to be rewritten.
@ -207,6 +256,7 @@ class AssertionRewritingHook(object):
): ):
self._warn_already_imported(name) self._warn_already_imported(name)
self._must_rewrite.update(names) self._must_rewrite.update(names)
self._marked_for_rewrite_cache.clear()
def _warn_already_imported(self, name): def _warn_already_imported(self, name):
self.config.warn( self.config.warn(
@ -239,16 +289,6 @@ class AssertionRewritingHook(object):
raise raise
return sys.modules[name] return sys.modules[name]
def is_package(self, name):
try:
fd, fn, desc = imp.find_module(name)
except ImportError:
return False
if fd is not None:
fd.close()
tp = desc[2]
return tp == imp.PKG_DIRECTORY
@classmethod @classmethod
def _register_with_pkg_resources(cls): def _register_with_pkg_resources(cls):
""" """

View File

@ -441,13 +441,14 @@ class Session(nodes.FSCollector):
self.trace("perform_collect", self, args) self.trace("perform_collect", self, args)
self.trace.root.indent += 1 self.trace.root.indent += 1
self._notfound = [] self._notfound = []
self._initialpaths = set() initialpaths = []
self._initialparts = [] self._initialparts = []
self.items = items = [] self.items = items = []
for arg in args: for arg in args:
parts = self._parsearg(arg) parts = self._parsearg(arg)
self._initialparts.append(parts) self._initialparts.append(parts)
self._initialpaths.add(parts[0]) initialpaths.append(parts[0])
self._initialpaths = frozenset(initialpaths)
rep = collect_one_node(self) rep = collect_one_node(self)
self.ihook.pytest_collectreport(report=rep) self.ihook.pytest_collectreport(report=rep)
self.trace.root.indent -= 1 self.trace.root.indent -= 1