diff --git a/_pytest/monkeypatch.py b/_pytest/monkeypatch.py index 0df224cd6..beaba37cc 100644 --- a/_pytest/monkeypatch.py +++ b/_pytest/monkeypatch.py @@ -1,8 +1,14 @@ """ monkeypatching and mocking functionality. """ import os, sys +import re + from py.builtin import _basestring + +RE_IMPORT_ERROR_NAME = re.compile("^No module named (.*)$") + + def pytest_funcarg__monkeypatch(request): """The returned ``monkeypatch`` funcarg provides these helper methods to modify objects, dictionaries or os.environ:: @@ -34,14 +40,28 @@ def derive_importpath(import_path, raising): (import_path,)) rest = [] target = import_path + target_parts = set(target.split(".")) while target: try: obj = __import__(target, None, None, "__doc__") - except ImportError: + except ImportError as ex: + if hasattr(ex, 'name'): + # Python >= 3.3 + failed_name = ex.name + else: + match = RE_IMPORT_ERROR_NAME.match(ex.args[0]) + assert match + failed_name = match.group(1) + if "." not in target: __tracebackhide__ = True pytest.fail("could not import any sub part: %s" % import_path) + elif failed_name != target \ + and not any(p == failed_name for p in target_parts): + # target is importable but causes ImportError itself + __tracebackhide__ = True + pytest.fail("import error in %s: %s" % (target, ex.args[0])) target, name = target.rsplit(".", 1) rest.append(name) else: diff --git a/testing/test_monkeypatch.py b/testing/test_monkeypatch.py index 690aee556..49db0bada 100644 --- a/testing/test_monkeypatch.py +++ b/testing/test_monkeypatch.py @@ -1,4 +1,6 @@ import os, sys +import textwrap + import pytest from _pytest.monkeypatch import monkeypatch as MonkeyPatch @@ -245,6 +247,21 @@ def test_issue185_time_breaks(testdir): *1 passed* """) +def test_importerror(testdir): + p = testdir.mkpydir("package") + p.join("a.py").write(textwrap.dedent("""\ + import doesnotexist + + x = 1 + """)) + testdir.tmpdir.join("test_importerror.py").write(textwrap.dedent("""\ + def test_importerror(monkeypatch): + monkeypatch.setattr('package.a.x', 2) + """)) + result = testdir.runpytest() + result.stdout.fnmatch_lines(""" + *import error in package.a.x: No module named {0}doesnotexist{0}* + """.format("'" if sys.version_info > (3, 0) else "")) class SampleNew(object):