diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 6bed6ed42..4273bd73f 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -143,6 +143,9 @@ time or change existing behaviors in order to make them less surprising/more use **Changes** +* Plugins now benefit from assertion rewriting. Thanks + `@sober7`_, `@nicoddemus`_ and `@flub`_ for the PR. + * Fixtures marked with ``@pytest.fixture`` can now use ``yield`` statements exactly like those marked with the ``@pytest.yield_fixture`` decorator. This change renders ``@pytest.yield_fixture`` deprecated and makes ``@pytest.fixture`` with ``yield`` statements diff --git a/_pytest/assertion/__init__.py b/_pytest/assertion/__init__.py index dd30e1471..746c810ee 100644 --- a/_pytest/assertion/__init__.py +++ b/_pytest/assertion/__init__.py @@ -5,8 +5,7 @@ import py import os import sys -from _pytest.config import hookimpl -from _pytest.monkeypatch import MonkeyPatch +from _pytest.monkeypatch import monkeypatch from _pytest.assertion import util @@ -35,10 +34,7 @@ class AssertionState: self.trace = config.trace.root.get("assertion") -@hookimpl(tryfirst=True) -def pytest_load_initial_conftests(early_config, parser, args): - ns, ns_unknown_args = parser.parse_known_and_unknown_args(args) - mode = ns.assertmode +def install_importhook(config, mode): if mode == "rewrite": try: import ast # noqa @@ -51,37 +47,37 @@ def pytest_load_initial_conftests(early_config, parser, args): sys.version_info[:3] == (2, 6, 0)): mode = "reinterp" - early_config._assertstate = AssertionState(early_config, mode) - warn_about_missing_assertion(mode, early_config.pluginmanager) + config._assertstate = AssertionState(config, mode) - if mode != "plain": - _load_modules(mode) - m = MonkeyPatch() - early_config._cleanup.append(m.undo) - m.setattr(py.builtin.builtins, 'AssertionError', - reinterpret.AssertionError) # noqa + _load_modules(mode) + m = monkeypatch() + config._cleanup.append(m.undo) + m.setattr(py.builtin.builtins, 'AssertionError', + reinterpret.AssertionError) # noqa hook = None if mode == "rewrite": - hook = rewrite.AssertionRewritingHook(early_config) # noqa + hook = rewrite.AssertionRewritingHook(config) # noqa sys.meta_path.insert(0, hook) - early_config._assertstate.hook = hook - early_config._assertstate.trace("configured with mode set to %r" % (mode,)) + config._assertstate.hook = hook + config._assertstate.trace("configured with mode set to %r" % (mode,)) def undo(): - hook = early_config._assertstate.hook + hook = config._assertstate.hook if hook is not None and hook in sys.meta_path: sys.meta_path.remove(hook) - early_config.add_cleanup(undo) + config.add_cleanup(undo) + return hook def pytest_collection(session): # this hook is only called when test modules are collected # so for example not in the master process of pytest-xdist # (which does not collect test modules) - hook = session.config._assertstate.hook - if hook is not None: - hook.set_session(session) + assertstate = getattr(session.config, '_assertstate', None) + if assertstate: + if assertstate.hook is not None: + assertstate.hook.set_session(session) def _running_on_ci(): @@ -138,9 +134,10 @@ def pytest_runtest_teardown(item): def pytest_sessionfinish(session): - hook = session.config._assertstate.hook - if hook is not None: - hook.session = None + assertstate = getattr(session.config, '_assertstate', None) + if assertstate: + if assertstate.hook is not None: + assertstate.hook.set_session(None) def _load_modules(mode): @@ -151,31 +148,5 @@ def _load_modules(mode): from _pytest.assertion import rewrite # noqa -def warn_about_missing_assertion(mode, pluginmanager): - try: - assert False - except AssertionError: - pass - else: - if mode == "rewrite": - specifically = ("assertions which are not in test modules " - "will be ignored") - else: - specifically = "failing tests may report as passing" - - # temporarily disable capture so we can print our warning - capman = pluginmanager.getplugin('capturemanager') - try: - out, err = capman.suspendcapture() - sys.stderr.write("WARNING: " + specifically + - " because assert statements are not executed " - "by the underlying Python interpreter " - "(are you using python -O?)\n") - finally: - capman.resumecapture() - sys.stdout.write(out) - sys.stderr.write(err) - - # Expose this plugin's implementation for the pytest_assertrepr_compare hook pytest_assertrepr_compare = util.assertrepr_compare diff --git a/_pytest/assertion/rewrite.py b/_pytest/assertion/rewrite.py index 06944b016..50d8062ae 100644 --- a/_pytest/assertion/rewrite.py +++ b/_pytest/assertion/rewrite.py @@ -51,6 +51,7 @@ class AssertionRewritingHook(object): self.session = None self.modules = {} self._register_with_pkg_resources() + self._must_rewrite = set() def set_session(self, session): self.session = session @@ -87,7 +88,7 @@ class AssertionRewritingHook(object): fn = os.path.join(pth, name.rpartition(".")[2] + ".py") fn_pypath = py.path.local(fn) - if not self._should_rewrite(fn_pypath, state): + if not self._should_rewrite(name, fn_pypath, state): return None # The requested module looks like a test file, so rewrite it. This is @@ -137,7 +138,7 @@ class AssertionRewritingHook(object): self.modules[name] = co, pyc return self - def _should_rewrite(self, fn_pypath, state): + def _should_rewrite(self, name, fn_pypath, state): # always rewrite conftest files fn = str(fn_pypath) if fn_pypath.basename == 'conftest.py': @@ -161,8 +162,29 @@ class AssertionRewritingHook(object): finally: self.session = session del session + else: + for marked in self._must_rewrite: + if marked.startswith(name): + return True return False + def mark_rewrite(self, *names): + """Mark import names as needing to be re-written. + + The named module or package as well as any nested modules will + be re-written on import. + """ + already_imported = set(names).intersection(set(sys.modules)) + if already_imported: + self._warn_already_imported(already_imported) + self._must_rewrite.update(names) + + def _warn_already_imported(self, names): + self.config.warn( + 'P1', + 'Modules are already imported so can not be re-written: %s' % + ','.join(names)) + def load_module(self, name): # If there is an existing module object named 'fullname' in # sys.modules, the loader must use that existing module. (Otherwise, diff --git a/_pytest/config.py b/_pytest/config.py index 8cb1e6e01..5ac120ab3 100644 --- a/_pytest/config.py +++ b/_pytest/config.py @@ -5,6 +5,7 @@ import traceback import types import warnings +import pkg_resources import py # DON't import pytest here because it causes import cycle troubles import sys, os @@ -918,14 +919,63 @@ class Config(object): self._parser.addini('addopts', 'extra command line options', 'args') self._parser.addini('minversion', 'minimally required pytest version') + def _consider_importhook(self, args, entrypoint_name): + """Install the PEP 302 import hook if using assertion re-writing. + + Needs to parse the --assert= option from the commandline + and find all the installed plugins to mark them for re-writing + by the importhook. + """ + import _pytest.assertion + ns, unknown_args = self._parser.parse_known_and_unknown_args(args) + mode = ns.assertmode + if ns.noassert or ns.nomagic: + mode = "plain" + self._warn_about_missing_assertion(mode) + if mode != 'plain': + hook = _pytest.assertion.install_importhook(self, mode) + if hook: + for entrypoint in pkg_resources.iter_entry_points('pytest11'): + for entry in entrypoint.dist._get_metadata('RECORD'): + fn = entry.split(',')[0] + is_simple_module = os.sep not in fn and fn.endswith('.py') + is_package = fn.count(os.sep) == 1 and fn.endswith('__init__.py') + if is_simple_module: + module_name, ext = os.path.splitext(fn) + hook.mark_rewrite(module_name) + elif is_package: + package_name = os.path.dirname(fn) + hook.mark_rewrite(package_name) + + def _warn_about_missing_assertion(self, mode): + try: + assert False + except AssertionError: + pass + else: + if mode == "rewrite": + specifically = ("assertions not in test modules or plugins" + "will be ignored") + else: + specifically = "failing tests may report as passing" + sys.stderr.write("WARNING: " + specifically + + " because assert statements are not executed " + "by the underlying Python interpreter " + "(are you using python -O?)\n") + def _preparse(self, args, addopts=True): self._initini(args) if addopts: args[:] = shlex.split(os.environ.get('PYTEST_ADDOPTS', '')) + args args[:] = self.getini("addopts") + args self._checkversion() + entrypoint_name = 'pytest11' + self._consider_importhook(args, entrypoint_name) self.pluginmanager.consider_preparse(args) - self.pluginmanager.load_setuptools_entrypoints("pytest11") + try: + self.pluginmanager.load_setuptools_entrypoints(entrypoint_name) + except ImportError as e: + self.warn("I2", "could not load setuptools entry import: %s" % (e,)) self.pluginmanager.consider_env() self.known_args_namespace = ns = self._parser.parse_known_args(args, namespace=self.option.copy()) if self.known_args_namespace.confcutdir is None and self.inifile: diff --git a/testing/test_assertion.py b/testing/test_assertion.py index 56cd73bd3..0346cb9a9 100644 --- a/testing/test_assertion.py +++ b/testing/test_assertion.py @@ -26,6 +26,116 @@ def mock_config(): def interpret(expr): return reinterpret.reinterpret(expr, _pytest._code.Frame(sys._getframe(1))) + +class TestImportHookInstallation: + + @pytest.mark.parametrize('initial_conftest', [True, False]) + @pytest.mark.parametrize('mode', ['plain', 'rewrite', 'reinterp']) + def test_conftest_assertion_rewrite(self, testdir, initial_conftest, mode): + """Test that conftest files are using assertion rewrite on import. + (#1619) + """ + testdir.tmpdir.join('foo/tests').ensure(dir=1) + conftest_path = 'conftest.py' if initial_conftest else 'foo/conftest.py' + contents = { + conftest_path: """ + import pytest + @pytest.fixture + def check_first(): + def check(values, value): + assert values.pop(0) == value + return check + """, + 'foo/tests/test_foo.py': """ + def test(check_first): + check_first([10, 30], 30) + """ + } + testdir.makepyfile(**contents) + result = testdir.runpytest_subprocess('--assert=%s' % mode) + if mode == 'plain': + expected = 'E AssertionError' + elif mode == 'rewrite': + expected = '*assert 10 == 30*' + elif mode == 'reinterp': + expected = '*AssertionError:*was re-run*' + else: + assert 0 + result.stdout.fnmatch_lines([expected]) + + @pytest.mark.parametrize('mode', ['plain', 'rewrite', 'reinterp']) + def test_installed_plugin_rewrite(self, testdir, mode): + # Make sure the hook is installed early enough so that plugins + # installed via setuptools are re-written. + ham = testdir.tmpdir.join('hampkg').ensure(dir=1) + ham.join('__init__.py').write(""" +import pytest + +@pytest.fixture +def check_first2(): + def check(values, value): + assert values.pop(0) == value + return check + """) + testdir.makepyfile( + spamplugin=""" + import pytest + from hampkg import check_first2 + + @pytest.fixture + def check_first(): + def check(values, value): + assert values.pop(0) == value + return check + """, + mainwrapper=""" + import pytest, pkg_resources + + class DummyDistInfo: + project_name = 'spam' + version = '1.0' + + def _get_metadata(self, name): + return ['spamplugin.py,sha256=abc,123', + 'hampkg/__init__.py,sha256=abc,123'] + + class DummyEntryPoint: + name = 'spam' + module_name = 'spam.py' + attrs = () + extras = None + dist = DummyDistInfo() + + def load(self, require=True, *args, **kwargs): + import spamplugin + return spamplugin + + def iter_entry_points(name): + yield DummyEntryPoint() + + pkg_resources.iter_entry_points = iter_entry_points + pytest.main() + """, + test_foo=""" + def test(check_first): + check_first([10, 30], 30) + + def test2(check_first2): + check_first([10, 30], 30) + """, + ) + result = testdir.run(sys.executable, 'mainwrapper.py', '-s', '--assert=%s' % mode) + if mode == 'plain': + expected = 'E AssertionError' + elif mode == 'rewrite': + expected = '*assert 10 == 30*' + elif mode == 'reinterp': + expected = '*AssertionError:*was re-run*' + else: + assert 0 + result.stdout.fnmatch_lines([expected]) + + class TestBinReprIntegration: def test_pytest_assertrepr_compare_called(self, testdir): diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 5f8127af9..496034c23 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -12,7 +12,7 @@ if sys.platform.startswith("java"): import _pytest._code from _pytest.assertion import util -from _pytest.assertion.rewrite import rewrite_asserts, PYTEST_TAG +from _pytest.assertion.rewrite import rewrite_asserts, PYTEST_TAG, AssertionRewritingHook from _pytest.main import EXIT_NOTESTSCOLLECTED @@ -524,6 +524,16 @@ def test_rewritten(): testdir.makepyfile("import a_package_without_init_py.module") assert testdir.runpytest().ret == EXIT_NOTESTSCOLLECTED + def test_rewrite_warning(self, pytestconfig, monkeypatch): + hook = AssertionRewritingHook(pytestconfig) + warnings = [] + def mywarn(code, msg): + warnings.append((code, msg)) + monkeypatch.setattr(hook.config, 'warn', mywarn) + hook.mark_rewrite('_pytest') + assert '_pytest' in warnings[0][1] + + class TestAssertionRewriteHookDetails(object): def test_loader_is_package_false_for_module(self, testdir): testdir.makepyfile(test_fun=""" @@ -704,40 +714,6 @@ class TestAssertionRewriteHookDetails(object): result = testdir.runpytest() result.stdout.fnmatch_lines('*1 passed*') - @pytest.mark.parametrize('initial_conftest', [True, False]) - @pytest.mark.parametrize('mode', ['plain', 'rewrite', 'reinterp']) - def test_conftest_assertion_rewrite(self, testdir, initial_conftest, mode): - """Test that conftest files are using assertion rewrite on import. - (#1619) - """ - testdir.tmpdir.join('foo/tests').ensure(dir=1) - conftest_path = 'conftest.py' if initial_conftest else 'foo/conftest.py' - contents = { - conftest_path: """ - import pytest - @pytest.fixture - def check_first(): - def check(values, value): - assert values.pop(0) == value - return check - """, - 'foo/tests/test_foo.py': """ - def test(check_first): - check_first([10, 30], 30) - """ - } - testdir.makepyfile(**contents) - result = testdir.runpytest_subprocess('--assert=%s' % mode) - if mode == 'plain': - expected = 'E AssertionError' - elif mode == 'rewrite': - expected = '*assert 10 == 30*' - elif mode == 'reinterp': - expected = '*AssertionError:*was re-run*' - else: - assert 0 - result.stdout.fnmatch_lines([expected]) - def test_issue731(testdir): testdir.makepyfile(""" diff --git a/testing/test_config.py b/testing/test_config.py index bb686c3b0..57c95cd50 100644 --- a/testing/test_config.py +++ b/testing/test_config.py @@ -373,10 +373,14 @@ def test_preparse_ordering_with_setuptools(testdir, monkeypatch): pkg_resources = pytest.importorskip("pkg_resources") def my_iter(name): assert name == "pytest11" + class Dist: + project_name = 'spam' + version = '1.0' + def _get_metadata(self, name): + return ['foo.txt,sha256=abc,123'] class EntryPoint: name = "mytestplugin" - class dist: - pass + dist = Dist() def load(self): class PseudoPlugin: x = 42 @@ -412,8 +416,14 @@ def test_plugin_preparse_prevents_setuptools_loading(testdir, monkeypatch): pkg_resources = pytest.importorskip("pkg_resources") def my_iter(name): assert name == "pytest11" + class Dist: + project_name = 'spam' + version = '1.0' + def _get_metadata(self, name): + return ['foo.txt,sha256=abc,123'] class EntryPoint: name = "mytestplugin" + dist = Dist() def load(self): assert 0, "should not arrive here" return iter([EntryPoint()]) @@ -505,7 +515,6 @@ def test_load_initial_conftest_last_ordering(testdir): expected = [ "_pytest.config", 'test_config', - '_pytest.assertion', '_pytest.capture', ] assert [x.function.__module__ for x in l] == expected