"""Tests for pytest's AST-based assertion rewriting.

The first test class rewrites small functions with ``rewrite_asserts`` and
checks the text of the resulting ``AssertionError`` messages.  The second
class runs whole pytest sessions (via the ``testdir`` fixture) to check the
behaviour of rewriting at import time.
"""
# NOTE(review): this file was re-flowed from a whitespace-collapsed copy.
# Line breaks inside multi-line string literals were reconstructed from
# syntax, and empty string literals (the ``compile`` filename, the
# ``__repr__`` return value) were kept exactly as found -- confirm against
# the upstream file.
import os
import sys
import zipfile
import py
import pytest

ast = pytest.importorskip("ast")
if sys.platform.startswith("java"):
    # XXX should be xfail
    pytest.skip("assert rewrite does currently not work on jython")

from _pytest.assertion import util
from _pytest.assertion.rewrite import rewrite_asserts, PYTEST_TAG


def setup_module(mod):
    """Remember the active ``util._reprcompare`` hook and disable it.

    The hook is stashed on the module object so that ``teardown_module``
    can put it back.
    """
    mod._old_reprcompare = util._reprcompare
    # Reset the same attribute that is saved above and restored in
    # teardown_module; previously a different object's attribute
    # (py.code._reprcompare) was cleared and never restored.
    util._reprcompare = None


def teardown_module(mod):
    """Restore the ``util._reprcompare`` hook saved by ``setup_module``."""
    util._reprcompare = mod._old_reprcompare
    del mod._old_reprcompare


def rewrite(src):
    """Parse ``src`` and return its AST after ``rewrite_asserts`` ran on it."""
    tree = ast.parse(src)
    rewrite_asserts(tree)
    return tree


def getmsg(f, extra_ns=None, must_pass=False):
    """Rewrite the assertions in f, run it, and get the failure message.

    ``extra_ns`` is an optional dict of names made available as globals to
    the rewritten function.  With ``must_pass=True`` the function is
    expected not to raise; the test fails if it does.  Otherwise the test
    fails if no ``AssertionError`` is raised.

    Returns the message string on failure (prefixed with
    ``"AssertionError: "`` when it does not already start with ``"assert"``)
    and ``None`` when the function passed as expected.
    """
    src = '\n'.join(py.code.Code(f).source().lines)
    mod = rewrite(src)
    code = compile(mod, "", "exec")
    ns = {}
    if extra_ns is not None:
        ns.update(extra_ns)
    py.builtin.exec_(code, ns)
    func = ns[f.__name__]
    try:
        func()
    except AssertionError:
        if must_pass:
            pytest.fail("shouldn't have raised")
        # sys.exc_info() instead of "except ... as" keeps the syntax the
        # file already uses.
        s = str(sys.exc_info()[1])
        if not s.startswith("assert"):
            return "AssertionError: " + s
        return s
    else:
        if not must_pass:
            pytest.fail("function didn't raise at all")


class TestAssertionRewrite:
    """Checks on the rewritten AST and on the generated failure messages."""

    def test_place_initial_imports(self):
        """The two inserted imports follow docstrings and __future__ imports."""
        s = """'Doc string'\nother = stuff"""
        m = rewrite(s)
        assert isinstance(m.body[0], ast.Expr)
        assert isinstance(m.body[0].value, ast.Str)
        for imp in m.body[1:3]:
            assert isinstance(imp, ast.Import)
            assert imp.lineno == 2
            assert imp.col_offset == 0
        assert isinstance(m.body[3], ast.Assign)

        s = """from __future__ import with_statement\nother_stuff"""
        m = rewrite(s)
        assert isinstance(m.body[0], ast.ImportFrom)
        for imp in m.body[1:3]:
            assert isinstance(imp, ast.Import)
            assert imp.lineno == 2
            assert imp.col_offset == 0
        assert isinstance(m.body[3], ast.Expr)

        s = """'doc string'\nfrom __future__ import with_statement\nother"""
        m = rewrite(s)
        assert isinstance(m.body[0], ast.Expr)
        assert isinstance(m.body[0].value, ast.Str)
        assert isinstance(m.body[1], ast.ImportFrom)
        for imp in m.body[2:4]:
            assert isinstance(imp, ast.Import)
            assert imp.lineno == 3
            assert imp.col_offset == 0
        assert isinstance(m.body[4], ast.Expr)

        # A relative import is not a __future__ import, so the inserted
        # imports come first.
        s = """from . import relative\nother_stuff"""
        m = rewrite(s)
        for imp in m.body[0:2]:
            assert isinstance(imp, ast.Import)
            assert imp.lineno == 1
            assert imp.col_offset == 0
        assert isinstance(m.body[3], ast.Expr)

    def test_dont_rewrite(self):
        """A module whose docstring carries the opt-out marker is untouched."""
        s = """'PYTEST_DONT_REWRITE'\nassert 14"""
        m = rewrite(s)
        assert len(m.body) == 2
        assert isinstance(m.body[0].value, ast.Str)
        assert isinstance(m.body[1], ast.Assert)
        assert m.body[1].msg is None

    def test_name(self):
        """Messages for asserts on plain local and global names."""
        def f():
            assert False
        assert getmsg(f) == "assert False"

        def f():
            f = False
            assert f
        assert getmsg(f) == "assert False"

        def f():
            assert a_global
        assert getmsg(f, {"a_global" : False}) == "assert False"

        def f():
            assert sys == 42
        assert getmsg(f, {"sys" : sys}) == "assert sys == 42"

        def f():
            assert cls == 42
        class X(object):
            pass
        assert getmsg(f, {"cls" : X}) == "assert cls == 42"

    def test_assert_already_has_message(self):
        """An explicit assertion message is reported as-is."""
        def f():
            assert False, "something bad!"
        assert getmsg(f) == "AssertionError: something bad!"

    def test_boolop(self):
        """Messages for ``and`` / ``or`` expressions."""
        def f():
            f = g = False
            assert f and g
        assert getmsg(f) == "assert (False)"

        def f():
            f = True
            g = False
            assert f and g
        assert getmsg(f) == "assert (True and False)"

        def f():
            f = False
            g = True
            assert f and g
        assert getmsg(f) == "assert (False)"

        def f():
            f = g = False
            assert f or g
        assert getmsg(f) == "assert (False or False)"

        def f():
            f = g = False
            assert not f and not g
        getmsg(f, must_pass=True)

        def x():
            return False

        def f():
            assert x() and x()
        assert getmsg(f, {"x" : x}) == "assert (x())"

        def f():
            assert False or x()
        assert getmsg(f, {"x" : x}) == "assert (False or x())"

        def f():
            assert 1 in {} and 2 in {}
        assert getmsg(f) == "assert (1 in {})"

        def f():
            x = 1
            y = 2
            assert x in {1 : None} and y in {}
        assert getmsg(f) == "assert (1 in {1: None} and 2 in {})"

        def f():
            f = True
            g = False
            assert f or g
        getmsg(f, must_pass=True)

        def f():
            f = g = h = lambda: True
            assert f() and g() and h()
        getmsg(f, must_pass=True)

    def test_short_circut_evaluation(self):
        """Operands after a short-circuiting one must not be evaluated."""
        def f():
            # "explode" is undefined; evaluating it would raise NameError.
            assert True or explode
        getmsg(f, must_pass=True)

        def f():
            x = 1
            assert x == 1 or x == 2
        getmsg(f, must_pass=True)

    def test_unary_op(self):
        """Messages for ``not``, ``~``, unary ``-`` and unary ``+``."""
        def f():
            x = True
            assert not x
        assert getmsg(f) == "assert not True"

        def f():
            x = 0
            assert ~x + 1
        assert getmsg(f) == "assert (~0 + 1)"

        def f():
            x = 3
            assert -x + x
        assert getmsg(f) == "assert (-3 + 3)"

        def f():
            x = 0
            assert +x + x
        assert getmsg(f) == "assert (+0 + 0)"

    def test_binary_op(self):
        """Messages for binary arithmetic operators."""
        def f():
            x = 1
            y = -1
            assert x + y
        assert getmsg(f) == "assert (1 + -1)"

        def f():
            x = range(10)
            assert not 5 % 4
        assert getmsg(f) == "assert not (5 % 4)"

    def test_call(self):
        """Messages for calls with positional, keyword and star arguments."""
        def g(a=42, *args, **kwargs):
            return False
        ns = {"g" : g}

        def f():
            assert g()
        assert getmsg(f, ns) == """assert g()"""

        def f():
            assert g(1)
        assert getmsg(f, ns) == """assert g(1)"""

        def f():
            assert g(1, 2)
        assert getmsg(f, ns) == """assert g(1, 2)"""

        def f():
            assert g(1, g=42)
        assert getmsg(f, ns) == """assert g(1, g=42)"""

        def f():
            assert g(1, 3, g=23)
        assert getmsg(f, ns) == """assert g(1, 3, g=23)"""

        def f():
            seq = [1, 2, 3]
            assert g(*seq)
        assert getmsg(f, ns) == """assert g(*[1, 2, 3])"""

        def f():
            x = "a"
            assert g(**{x : 2})
        assert getmsg(f, ns) == """assert g(**{'a': 2})"""

    def test_attribute(self):
        """Messages for attribute access on a global object."""
        class X(object):
            g = 3
        ns = {"x" : X}

        def f():
            assert not x.g
        # NOTE(review): the line break and the spacing before "where" were
        # lost in the collapsed copy and are reconstructed here -- confirm
        # against the explanation formatter's real output.
        assert getmsg(f, ns) == "assert not 3\n +  where 3 = x.g"

        def f():
            x.a = False
            assert x.a
        assert getmsg(f, ns) == """assert x.a"""

    def test_comparisons(self):
        """Messages for simple and chained comparisons."""
        def f():
            a, b = range(2)
            assert b < a
        assert getmsg(f) == """assert 1 < 0"""

        def f():
            a, b, c = range(3)
            assert a > b > c
        assert getmsg(f) == """assert 0 > 1"""

        def f():
            a, b, c = range(3)
            assert a < b > c
        assert getmsg(f) == """assert 1 > 2"""

        def f():
            a, b, c = range(3)
            assert a < b <= c
        getmsg(f, must_pass=True)

        def f():
            a, b, c = range(3)
            assert a < b
            assert b < c
        getmsg(f, must_pass=True)

    def test_len(self):
        """The message for a failing ``len()`` comparison shows the value."""
        def f():
            l = list(range(10))
            assert len(l) == 11
        # NOTE(review): line break / spacing before "where" reconstructed
        # from a collapsed copy -- confirm against the formatter's output.
        assert getmsg(f).startswith(
            "assert 10 == 11\n +  where 10 = len([")

    def test_custom_reprcompare(self, monkeypatch):
        """A custom ``util._reprcompare`` hook supplies the comparison text."""
        def my_reprcompare(op, left, right):
            return "42"
        monkeypatch.setattr(util, "_reprcompare", my_reprcompare)

        def f():
            assert 42 < 3
        assert getmsg(f) == "assert 42"

        def my_reprcompare(op, left, right):
            return "%s %s %s" % (left, op, right)
        monkeypatch.setattr(util, "_reprcompare", my_reprcompare)

        def f():
            assert 1 < 3 < 5 <= 4 < 7
        assert getmsg(f) == "assert 5 <= 4"

    def test_assert_raising_nonzero_in_comparison(self):
        """A comparison result whose ``__nonzero__`` raises is still reported."""
        def f():
            class A(object):
                def __nonzero__(self):
                    raise ValueError(42)
                def __lt__(self, other):
                    return A()
                def __repr__(self):
                    return ""
            def myany(x):
                return False
            assert myany(A() < 0)
        assert " < 0" in getmsg(f)

    def test_formatchar(self):
        """A literal ``%`` in a compared string must not break formatting."""
        def f():
            assert "%test" == "test"
        assert getmsg(f).startswith("assert '%test' == 'test'")


class TestRewriteOnImport:
    """Full pytest runs exercising rewriting at import time."""

    def test_pycache_is_a_file(self, testdir):
        """Rewriting works when ``__pycache__`` is a regular file."""
        testdir.tmpdir.join("__pycache__").write("Hello")
        testdir.makepyfile("""
            def test_rewritten():
                assert "@py_builtins" in globals()""")
        assert testdir.runpytest().ret == 0

    def test_zipfile(self, testdir):
        """Importing a package from a zip archive on sys.path does not fail."""
        z = testdir.tmpdir.join("myzip.zip")
        z_fn = str(z)
        f = zipfile.ZipFile(z_fn, "w")
        try:
            f.writestr("test_gum/__init__.py", "")
            f.writestr("test_gum/test_lizard.py", "")
        finally:
            f.close()
        # 256 == 0o400: owner read-only (decimal keeps old-Python syntax).
        z.chmod(256)
        testdir.makepyfile("""
            import sys
            sys.path.append(%r)
            import test_gum.test_lizard""" % (z_fn,))
        assert testdir.runpytest().ret == 0

    def test_readonly(self, testdir):
        """Rewriting works for a test file in a directory without write bits."""
        sub = testdir.mkdir("testing")
        sub.join("test_readonly.py").write(
            py.builtin._totext("""
def test_rewritten():
    assert "@py_builtins" in globals()
            """).encode("utf-8"), "wb")
        # 320 == 0o500: owner may read and traverse but not write.
        sub.chmod(320)
        assert testdir.runpytest().ret == 0

    def test_dont_write_bytecode(self, testdir, monkeypatch):
        """With PYTHONDONTWRITEBYTECODE set, no cache file or dir is created."""
        testdir.makepyfile("""
            import os
            def test_no_bytecode():
                assert "__pycache__" in __cached__
                assert not os.path.exists(__cached__)
                assert not os.path.exists(os.path.dirname(__cached__))""")
        monkeypatch.setenv("PYTHONDONTWRITEBYTECODE", "1")
        assert testdir.runpytest().ret == 0

    @pytest.mark.skipif('"__pypy__" in sys.modules')
    def test_pyc_vs_pyo(self, testdir, monkeypatch):
        """Optimized and non-optimized runs use separate cache files.

        The generated test only passes when docstrings are stripped, so it
        must succeed under ``PYTHONOPTIMIZE=2`` (writing a ``.pyo``) and
        fail in a normal run (writing a ``.pyc``).
        """
        testdir.makepyfile("""
            import pytest
            def test_optimized():
                "hello"
                assert test_optimized.__doc__ is None"""
        )
        p = py.path.local.make_numbered_dir(prefix="runpytest-", keep=None,
                                            rootdir=testdir.tmpdir)
        tmp = "--basetemp=%s" % p
        monkeypatch.setenv("PYTHONOPTIMIZE", "2")
        monkeypatch.delenv("PYTHONDONTWRITEBYTECODE", raising=False)
        assert testdir.runpybin("py.test", tmp).ret == 0
        tagged = "test_pyc_vs_pyo." + PYTEST_TAG
        assert tagged + ".pyo" in os.listdir("__pycache__")
        # Drop the PYTHONOPTIMIZE setting before the second run.
        monkeypatch.undo()
        monkeypatch.delenv("PYTHONDONTWRITEBYTECODE", raising=False)
        assert testdir.runpybin("py.test", tmp).ret == 1
        assert tagged + ".pyc" in os.listdir("__pycache__")

    def test_package(self, testdir):
        """Test modules inside a package are rewritten."""
        pkg = testdir.tmpdir.join("pkg")
        pkg.mkdir()
        pkg.join("__init__.py").ensure()
        pkg.join("test_blah.py").write("""
def test_rewritten():
    assert "@py_builtins" in globals()""")
        assert testdir.runpytest().ret == 0

    def test_translate_newlines(self, testdir):
        """A source file with CRLF line endings is rewritten correctly."""
        content = "def test_rewritten():\r\n assert '@py_builtins' in globals()"
        b = content.encode("utf-8")
        testdir.tmpdir.join("test_newlines.py").write(b, "wb")
        assert testdir.runpytest().ret == 0