From c629f6b18b076bd60fc8e28bec95c9e5f141687d Mon Sep 17 00:00:00 2001 From: Daniel Hahler Date: Wed, 4 Mar 2015 16:21:27 +0100 Subject: [PATCH] Fix `reload()` with modules handled via `python_files` If a module exists in `sys.modules` already, `load_module` has to return it. Fixes https://bitbucket.org/pytest-dev/pytest/issue/435 --HG-- branch : fix-reload --- _pytest/assertion/rewrite.py | 6 ++++++ testing/test_assertrewrite.py | 24 ++++++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/_pytest/assertion/rewrite.py b/_pytest/assertion/rewrite.py index 5e13a44c8..f965ba2c3 100644 --- a/_pytest/assertion/rewrite.py +++ b/_pytest/assertion/rewrite.py @@ -146,6 +146,12 @@ class AssertionRewritingHook(object): return self def load_module(self, name): + # If there is an existing module object named 'fullname' in + # sys.modules, the loader must use that existing module. (Otherwise, + # the reload() builtin will not work correctly.) + if name in sys.modules: + return sys.modules[name] + co, pyc = self.modules.pop(name) # I wish I could just call imp.load_compiled here, but __file__ has to # be set properly. In Python 3.2+, this all would be handled correctly diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 56b4d3164..820fcabff 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -641,3 +641,27 @@ class TestAssertionRewriteHookDetails(object): pyc.write(contents[:strip_bytes], mode='wb') assert _read_pyc(source, str(pyc)) is None # no error + + def test_reload_is_same(self, testdir): + # A file that will be picked up during collecting. + testdir.tmpdir.join("file.py").ensure() + testdir.tmpdir.join("pytest.ini").write(py.std.textwrap.dedent(""" + [pytest] + python_files = *.py + """)) + + testdir.makepyfile(test_fun=""" + import sys + try: + from imp import reload + except ImportError: + pass + + def test_loader(): + import file + assert sys.modules["file"] is reload(file) + """) + result = testdir.runpytest('-s') + result.stdout.fnmatch_lines([ + "* 1 passed*", + ])