diff --git a/AUTHORS b/AUTHORS index 2d59a1b0f..a65623820 100644 --- a/AUTHORS +++ b/AUTHORS @@ -93,6 +93,7 @@ Hui Wang (coldnight) Ian Bicking Ian Lesperance Ionuț Turturică +Iwan Briquemont Jaap Broekhuizen Jan Balster Janne Vanhala diff --git a/changelog/3539.bugfix.rst b/changelog/3539.bugfix.rst new file mode 100644 index 000000000..d0741cda9 --- /dev/null +++ b/changelog/3539.bugfix.rst @@ -0,0 +1 @@ +Fix reload on assertion rewritten modules. diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index be8c6dc4d..7a11c4ec1 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -269,17 +269,17 @@ class AssertionRewritingHook(object): ) def load_module(self, name): - # If there is an existing module object named 'fullname' in - # sys.modules, the loader must use that existing module. (Otherwise, - # the reload() builtin will not work correctly.) - if name in sys.modules: - return sys.modules[name] - co, pyc = self.modules.pop(name) - # I wish I could just call imp.load_compiled here, but __file__ has to - # be set properly. In Python 3.2+, this all would be handled correctly - # by load_compiled. - mod = sys.modules[name] = imp.new_module(name) + if name in sys.modules: + # If there is an existing module object named 'fullname' in + # sys.modules, the loader must use that existing module. (Otherwise, + # the reload() builtin will not work correctly.) + mod = sys.modules[name] + else: + # I wish I could just call imp.load_compiled here, but __file__ has to + # be set properly. In Python 3.2+, this all would be handled correctly + # by load_compiled. + mod = sys.modules[name] = imp.new_module(name) try: mod.__file__ = co.co_filename # Normally, this attribute is 3.2+. diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index a2cd8e81c..5153fc741 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1050,6 +1050,48 @@ class TestAssertionRewriteHookDetails(object): result = testdir.runpytest("-s") result.stdout.fnmatch_lines(["* 1 passed*"]) + def test_reload_reloads(self, testdir): + """Reloading a module after change picks up the change.""" + testdir.tmpdir.join("file.py").write( + textwrap.dedent( + """ + def reloaded(): + return False + + def rewrite_self(): + with open(__file__, 'w') as self: + self.write('def reloaded(): return True') + """ + ) + ) + testdir.tmpdir.join("pytest.ini").write( + textwrap.dedent( + """ + [pytest] + python_files = *.py + """ + ) + ) + + testdir.makepyfile( + test_fun=""" + import sys + try: + from imp import reload + except ImportError: + pass + + def test_loader(): + import file + assert not file.reloaded() + file.rewrite_self() + reload(file) + assert file.reloaded() + """ + ) + result = testdir.runpytest("-s") + result.stdout.fnmatch_lines(["* 1 passed*"]) + def test_get_data_support(self, testdir): """Implement optional PEP302 api (#808). """