Introduce pytest.register_assert_rewrite()

Plugins can now explicitly mark modules to be re-written.  By default
only the modules containing the plugin entrypoint are re-written.
This commit is contained in:
Floris Bruynooghe 2016-06-25 18:26:45 +02:00
parent 944da5b98a
commit 743f59afb2
4 changed files with 127 additions and 21 deletions

View File

@ -7,6 +7,7 @@ import sys
from _pytest.monkeypatch import monkeypatch
from _pytest.assertion import util
from _pytest.assertion import rewrite
def pytest_addoption(parser):
@ -26,6 +27,34 @@ def pytest_addoption(parser):
provide assert expression information. """)
def pytest_namespace():
return {'register_assert_rewrite': register_assert_rewrite}
def register_assert_rewrite(*names):
"""Register a module name to be rewritten on import.
This function will make sure that the module will get it's assert
statements rewritten when it is imported. Thus you should make
sure to call this before the module is actually imported, usually
in your __init__.py if you are a plugin using a package.
"""
for hook in sys.meta_path:
if isinstance(hook, rewrite.AssertionRewritingHook):
importhook = hook
break
else:
importhook = DummyRewriteHook()
importhook.mark_rewrite(*names)
class DummyRewriteHook(object):
"""A no-op import hook for when rewriting is disabled."""
def mark_rewrite(self, *names):
pass
class AssertionState:
"""State for the assertion plugin."""

View File

@ -163,9 +163,9 @@ class AssertionRewritingHook(object):
self.session = session
del session
else:
toplevel_name = name.split('.', 1)[0]
if toplevel_name in self._must_rewrite:
return True
for marked in self._must_rewrite:
if marked.startswith(name):
return True
return False
def mark_rewrite(self, *names):

View File

@ -11,6 +11,7 @@ import py
import sys, os
import _pytest._code
import _pytest.hookspec # the extension point definitions
import _pytest.assertion
from _pytest._pluggy import PluginManager, HookimplMarker, HookspecMarker
hookimpl = HookimplMarker("pytest")
@ -154,6 +155,9 @@ class PytestPluginManager(PluginManager):
self.trace.root.setwriter(err.write)
self.enable_tracing()
# Config._consider_importhook will set a real object if required.
self.rewrite_hook = _pytest.assertion.DummyRewriteHook()
def addhooks(self, module_or_class):
"""
.. deprecated:: 2.8
@ -362,7 +366,9 @@ class PytestPluginManager(PluginManager):
self._import_plugin_specs(os.environ.get("PYTEST_PLUGINS"))
def consider_module(self, mod):
self._import_plugin_specs(getattr(mod, "pytest_plugins", None))
plugins = getattr(mod, 'pytest_plugins', [])
self.rewrite_hook.mark_rewrite(*plugins)
self._import_plugin_specs(plugins)
def _import_plugin_specs(self, spec):
if spec:
@ -926,15 +932,13 @@ class Config(object):
and find all the installed plugins to mark them for re-writing
by the importhook.
"""
import _pytest.assertion
ns, unknown_args = self._parser.parse_known_and_unknown_args(args)
mode = ns.assertmode
if ns.noassert or ns.nomagic:
mode = "plain"
self._warn_about_missing_assertion(mode)
if mode != 'plain':
hook = _pytest.assertion.install_importhook(self, mode)
if hook:
self.pluginmanager.rewrite_hook = hook
for entrypoint in pkg_resources.iter_entry_points('pytest11'):
for entry in entrypoint.dist._get_metadata('RECORD'):
fn = entry.split(',')[0]

View File

@ -63,22 +63,53 @@ class TestImportHookInstallation:
assert 0
result.stdout.fnmatch_lines([expected])
@pytest.mark.parametrize('mode', ['plain', 'rewrite', 'reinterp'])
def test_pytest_plugins_rewrite(self, testdir, mode):
contents = {
'conftest.py': """
pytest_plugins = ['ham']
""",
'ham.py': """
import pytest
@pytest.fixture
def check_first():
def check(values, value):
assert values.pop(0) == value
return check
""",
'test_foo.py': """
def test_foo(check_first):
check_first([10, 30], 30)
""",
}
testdir.makepyfile(**contents)
result = testdir.runpytest_subprocess('--assert=%s' % mode)
if mode == 'plain':
expected = 'E AssertionError'
elif mode == 'rewrite':
expected = '*assert 10 == 30*'
elif mode == 'reinterp':
expected = '*AssertionError:*was re-run*'
else:
assert 0
result.stdout.fnmatch_lines([expected])
@pytest.mark.parametrize('mode', ['plain', 'rewrite', 'reinterp'])
def test_installed_plugin_rewrite(self, testdir, mode):
# Make sure the hook is installed early enough so that plugins
# installed via setuptools are re-written.
ham = testdir.tmpdir.join('hampkg').ensure(dir=1)
ham.join('__init__.py').write("""
import pytest
testdir.tmpdir.join('hampkg').ensure(dir=1)
contents = {
'hampkg/__init__.py': """
import pytest
@pytest.fixture
def check_first2():
def check(values, value):
assert values.pop(0) == value
return check
""")
testdir.makepyfile(
spamplugin="""
@pytest.fixture
def check_first2():
def check(values, value):
assert values.pop(0) == value
return check
""",
'spamplugin.py': """
import pytest
from hampkg import check_first2
@ -88,7 +119,7 @@ def check_first2():
assert values.pop(0) == value
return check
""",
mainwrapper="""
'mainwrapper.py': """
import pytest, pkg_resources
class DummyDistInfo:
@ -116,14 +147,15 @@ def check_first2():
pkg_resources.iter_entry_points = iter_entry_points
pytest.main()
""",
test_foo="""
'test_foo.py': """
def test(check_first):
check_first([10, 30], 30)
def test2(check_first2):
check_first([10, 30], 30)
""",
)
}
testdir.makepyfile(**contents)
result = testdir.run(sys.executable, 'mainwrapper.py', '-s', '--assert=%s' % mode)
if mode == 'plain':
expected = 'E AssertionError'
@ -135,6 +167,47 @@ def check_first2():
assert 0
result.stdout.fnmatch_lines([expected])
def test_rewrite_ast(self, testdir):
testdir.tmpdir.join('pkg').ensure(dir=1)
contents = {
'pkg/__init__.py': """
import pytest
pytest.register_assert_rewrite('pkg.helper')
""",
'pkg/helper.py': """
def tool():
a, b = 2, 3
assert a == b
""",
'pkg/plugin.py': """
import pytest, pkg.helper
@pytest.fixture
def tool():
return pkg.helper.tool
""",
'pkg/other.py': """
l = [3, 2]
def tool():
assert l.pop() == 3
""",
'conftest.py': """
pytest_plugins = ['pkg.plugin']
""",
'test_pkg.py': """
import pkg.other
def test_tool(tool):
tool()
def test_other():
pkg.other.tool()
""",
}
testdir.makepyfile(**contents)
result = testdir.runpytest_subprocess('--assert=rewrite')
result.stdout.fnmatch_lines(['>*assert a == b*',
'E*assert 2 == 3*',
'>*assert l.pop() == 3*',
'E*AssertionError*re-run*'])
class TestBinReprIntegration: