diff --git a/_pytest/assertion/__init__.py b/_pytest/assertion/__init__.py index bca626931..3c42910d5 100644 --- a/_pytest/assertion/__init__.py +++ b/_pytest/assertion/__init__.py @@ -31,10 +31,11 @@ def pytest_namespace(): def register_assert_rewrite(*names): """Register a module name to be rewritten on import. - This function will make sure that the module will get it's assert - statements rewritten when it is imported. Thus you should make - sure to call this before the module is actually imported, usually - in your __init__.py if you are a plugin using a package. + This function will make sure that this module or all modules inside + the package will get their assert statements rewritten. + Thus you should make sure to call this before the module is + actually imported, usually in your __init__.py if you are a plugin + using a package. """ for hook in sys.meta_path: if isinstance(hook, rewrite.AssertionRewritingHook): diff --git a/_pytest/assertion/rewrite.py b/_pytest/assertion/rewrite.py index 50d8062ae..80d6ee3ba 100644 --- a/_pytest/assertion/rewrite.py +++ b/_pytest/assertion/rewrite.py @@ -11,6 +11,7 @@ import re import struct import sys import types +from fnmatch import fnmatch import py from _pytest.assertion import util @@ -144,28 +145,29 @@ class AssertionRewritingHook(object): if fn_pypath.basename == 'conftest.py': state.trace("rewriting conftest file: %r" % (fn,)) return True - elif self.session is not None: + + if self.session is not None: if self.session.isinitpath(fn): state.trace("matched test file (was specified on cmdline): %r" % (fn,)) return True - else: - # modules not passed explicitly on the command line are only - # rewritten if they match the naming convention for test files - session = self.session # avoid a cycle here - self.session = None - try: - for pat in self.fnpats: - if fn_pypath.fnmatch(pat): - state.trace("matched test file %r" % (fn,)) - return True - finally: - self.session = session - del session - else: - for marked in self._must_rewrite: - if marked.startswith(name): - return True + + # modules not passed explicitly on the command line are only + # rewritten if they match the naming convention for test files + for pat in self.fnpats: + # use fnmatch instead of fn_pypath.fnmatch because the + # latter might trigger an import to fnmatch.fnmatch + # internally, which would cause this method to be + # called recursively + if fnmatch(fn_pypath.basename, pat): + state.trace("matched test file %r" % (fn,)) + return True + + for marked in self._must_rewrite: + if name.startswith(marked): + state.trace("matched marked file %r (from %r)" % (name, marked)) + return True + return False def mark_rewrite(self, *names): diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 496034c23..cedd435f8 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -533,6 +533,16 @@ def test_rewritten(): hook.mark_rewrite('_pytest') assert '_pytest' in warnings[0][1] + def test_rewrite_module_imported_from_conftest(self, testdir): + testdir.makeconftest(''' + import test_rewrite_module_imported + ''') + testdir.makepyfile(test_rewrite_module_imported=''' + def test_rewritten(): + assert "@py_builtins" in globals() + ''') + assert testdir.runpytest_subprocess().ret == 0 + class TestAssertionRewriteHookDetails(object): def test_loader_is_package_false_for_module(self, testdir):