Merge pull request #1787 from nicoddemus/fix-rewrite-conftest
Rewrite asserts in test-modules loaded very early in the startup
This commit is contained in:
commit
6e3105dc8f
|
@ -31,10 +31,11 @@ def pytest_namespace():
|
|||
def register_assert_rewrite(*names):
|
||||
"""Register a module name to be rewritten on import.
|
||||
|
||||
This function will make sure that the module will get it's assert
|
||||
statements rewritten when it is imported. Thus you should make
|
||||
sure to call this before the module is actually imported, usually
|
||||
in your __init__.py if you are a plugin using a package.
|
||||
This function will make sure that this module or all modules inside
|
||||
the package will get their assert statements rewritten.
|
||||
Thus you should make sure to call this before the module is
|
||||
actually imported, usually in your __init__.py if you are a plugin
|
||||
using a package.
|
||||
"""
|
||||
for hook in sys.meta_path:
|
||||
if isinstance(hook, rewrite.AssertionRewritingHook):
|
||||
|
|
|
@ -11,6 +11,7 @@ import re
|
|||
import struct
|
||||
import sys
|
||||
import types
|
||||
from fnmatch import fnmatch
|
||||
|
||||
import py
|
||||
from _pytest.assertion import util
|
||||
|
@ -144,28 +145,29 @@ class AssertionRewritingHook(object):
|
|||
if fn_pypath.basename == 'conftest.py':
|
||||
state.trace("rewriting conftest file: %r" % (fn,))
|
||||
return True
|
||||
elif self.session is not None:
|
||||
|
||||
if self.session is not None:
|
||||
if self.session.isinitpath(fn):
|
||||
state.trace("matched test file (was specified on cmdline): %r" %
|
||||
(fn,))
|
||||
return True
|
||||
else:
|
||||
# modules not passed explicitly on the command line are only
|
||||
# rewritten if they match the naming convention for test files
|
||||
session = self.session # avoid a cycle here
|
||||
self.session = None
|
||||
try:
|
||||
for pat in self.fnpats:
|
||||
if fn_pypath.fnmatch(pat):
|
||||
state.trace("matched test file %r" % (fn,))
|
||||
return True
|
||||
finally:
|
||||
self.session = session
|
||||
del session
|
||||
else:
|
||||
for marked in self._must_rewrite:
|
||||
if marked.startswith(name):
|
||||
return True
|
||||
|
||||
# modules not passed explicitly on the command line are only
|
||||
# rewritten if they match the naming convention for test files
|
||||
for pat in self.fnpats:
|
||||
# use fnmatch instead of fn_pypath.fnmatch because the
|
||||
# latter might trigger an import to fnmatch.fnmatch
|
||||
# internally, which would cause this method to be
|
||||
# called recursively
|
||||
if fnmatch(fn_pypath.basename, pat):
|
||||
state.trace("matched test file %r" % (fn,))
|
||||
return True
|
||||
|
||||
for marked in self._must_rewrite:
|
||||
if name.startswith(marked):
|
||||
state.trace("matched marked file %r (from %r)" % (name, marked))
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def mark_rewrite(self, *names):
|
||||
|
|
|
@ -533,6 +533,16 @@ def test_rewritten():
|
|||
hook.mark_rewrite('_pytest')
|
||||
assert '_pytest' in warnings[0][1]
|
||||
|
||||
def test_rewrite_module_imported_from_conftest(self, testdir):
|
||||
testdir.makeconftest('''
|
||||
import test_rewrite_module_imported
|
||||
''')
|
||||
testdir.makepyfile(test_rewrite_module_imported='''
|
||||
def test_rewritten():
|
||||
assert "@py_builtins" in globals()
|
||||
''')
|
||||
assert testdir.runpytest_subprocess().ret == 0
|
||||
|
||||
|
||||
class TestAssertionRewriteHookDetails(object):
|
||||
def test_loader_is_package_false_for_module(self, testdir):
|
||||
|
|
Loading…
Reference in New Issue