2011-05-19 04:31:10 +08:00
|
|
|
"""Rewrite assertion AST to produce nice error messages"""
|
|
|
|
|
|
|
|
import ast
|
|
|
|
import collections
|
|
|
|
import itertools
|
2011-06-29 10:13:12 +08:00
|
|
|
import imp
|
|
|
|
import marshal
|
|
|
|
import os
|
|
|
|
import struct
|
2011-05-27 10:15:40 +08:00
|
|
|
import sys
|
2011-07-07 12:24:04 +08:00
|
|
|
import types
|
2011-05-19 04:31:10 +08:00
|
|
|
|
|
|
|
import py
|
2011-05-27 01:01:34 +08:00
|
|
|
from _pytest.assertion import util
|
2011-05-19 04:31:10 +08:00
|
|
|
|
|
|
|
|
2011-06-29 10:13:12 +08:00
|
|
|
# py.test caches rewritten pycs in __pycache__. PYTEST_TAG encodes the
# running interpreter (implementation and version) and becomes part of the
# cached file names built in AssertionRewritingHook.find_module.
if not hasattr(imp, "get_tag"):
    # imp.get_tag() is unavailable: build a tag such as "cpython-27" by hand.
    if hasattr(sys, "pypy_version_info"):
        impl = "pypy"
    elif sys.platform == "java":
        impl = "jython"
    else:
        impl = "cpython"
    PYTEST_TAG = "%s-%s%s-PYTEST" % ((impl,) + tuple(sys.version_info[:2]))
    # Keep the module namespace free of the helper name.
    del impl
else:
    PYTEST_TAG = imp.get_tag() + "-PYTEST"
|
2011-06-29 10:13:12 +08:00
|
|
|
|
|
|
|
class AssertionRewritingHook(object):
    """Import hook which rewrites asserts.

    Implements the PEP 302 finder/loader protocol: ``find_module`` decides
    whether a module is a test file and prepares a code object with its
    asserts rewritten, and ``load_module`` executes that code object.
    """

    def __init__(self):
        # The pytest session; None until set_session() is called. It is also
        # temporarily reset to None inside find_module to avoid import cycles.
        self.session = None
        # Module name -> (code object, pyc path), filled by find_module and
        # consumed by load_module.
        self.modules = {}

    def set_session(self, session):
        """Start rewriting test modules for *session*.

        The test file patterns are read from the ``python_files`` ini value.
        """
        self.fnpats = session.config.getini("python_files")
        self.session = session

    def find_module(self, name, path=None):
        """Return self if module *name* is a test file we can rewrite.

        *path* is the search path handed in by the import machinery (None
        for top-level modules). Returns None to let the normal import
        proceed when no session is active, the module is not a Python
        source file, it does not match the test file patterns, or it could
        not be rewritten. On success the rewritten code is stored in
        ``self.modules`` for ``load_module``.
        """
        if self.session is None:
            return None
        sess = self.session
        state = sess.config._assertstate
        # Only the last component of a dotted module name names a file.
        lastname = name.rsplit(".", 1)[-1]
        pth = None
        if path is not None and len(path) == 1:
            pth = path[0]
        if pth is None:
            try:
                fd, fn, desc = imp.find_module(lastname, path)
            except ImportError:
                return None
            if fd is not None:
                fd.close()
            tp = desc[2]
            if tp == imp.PY_COMPILED:
                if hasattr(imp, "source_from_cache"):
                    fn = imp.source_from_cache(fn)
                else:
                    fn = fn[:-1]
            elif tp != imp.PY_SOURCE:
                # Don't know what this is.
                return None
        else:
            # Join only the last name component: using the full dotted name
            # would look for e.g. "pkg/pkg.test_foo.py" for "pkg.test_foo".
            fn = os.path.join(pth, lastname + ".py")
        fn_pypath = py.path.local(fn)
        # Is this a test file?
        if not sess.isinitpath(fn):
            # We have to be very careful here because imports in this code can
            # trigger a cycle.
            self.session = None
            try:
                for pat in self.fnpats:
                    if fn_pypath.fnmatch(pat):
                        break
                else:
                    return None
            finally:
                self.session = sess
        # The requested module looks like a test file, so rewrite it. This is
        # the most magical part of the process: load the source, rewrite the
        # asserts, and load the rewritten source. We also cache the rewritten
        # module code in a special pyc. We must be aware of the possibility of
        # concurrent py.test processes rewriting and loading pycs. To avoid
        # tricky race conditions, we maintain the following invariant: The
        # cached pyc is always a complete, valid pyc. Operations on it must be
        # atomic. POSIX's atomic rename comes in handy.
        cache_dir = os.path.join(fn_pypath.dirname, "__pycache__")
        # Honour "python -B" / PYTHONDONTWRITEBYTECODE: never write pycs then.
        write = not getattr(sys, "dont_write_bytecode", False)
        if write:
            try:
                py.path.local(cache_dir).ensure(dir=True)
            except py.error.EACCES:
                state.trace("read only directory: %r" % (fn_pypath.dirname,))
                write = False
            except EnvironmentError:
                # py.error.* exceptions subclass EnvironmentError. Other
                # failures (e.g. read-only file system, "__pycache__" blocked
                # by a file, a path inside an archive) only disable caching;
                # they must not break the import itself.
                state.trace("cannot create cache directory: %r" % (cache_dir,))
                write = False
        cache_name = fn_pypath.basename[:-3] + "." + PYTEST_TAG + ".pyc"
        pyc = os.path.join(cache_dir, cache_name)
        # Notice that even if we're in a read-only directory, I'm going to check
        # for a cached pyc. This may not be optimal...
        co = _read_pyc(fn_pypath, pyc)
        if co is None:
            state.trace("rewriting %r" % (fn,))
            co = _rewrite_test(state, fn_pypath)
            if co is None:
                # Probably a SyntaxError in the test.
                return None
            if write:
                _make_rewritten_pyc(state, fn_pypath, pyc, co)
        else:
            state.trace("found cached rewritten pyc for %r" % (fn,))
        self.modules[name] = co, pyc
        return self

    def load_module(self, name):
        """Execute the rewritten code prepared by find_module for *name*.

        The new module is put in ``sys.modules`` before its code runs and is
        removed again if executing the code raises.
        """
        co, pyc = self.modules.pop(name)
        # I wish I could just call imp.load_compiled here, but __file__ has to
        # be set properly. In Python 3.2+, this all would be handled correctly
        # by load_compiled.
        mod = sys.modules[name] = imp.new_module(name)
        try:
            mod.__file__ = co.co_filename
            # Normally, this attribute is 3.2+.
            mod.__cached__ = pyc
            py.builtin.exec_(co, mod.__dict__)
        except:
            # Deliberately bare: undo the registration for any exception
            # (including KeyboardInterrupt), then re-raise it unchanged.
            del sys.modules[name]
            raise
        return sys.modules[name]
|
2011-06-29 10:13:12 +08:00
|
|
|
|
|
|
|
def _write_pyc(co, source_path, pyc):
|
2011-07-07 12:24:04 +08:00
|
|
|
# Technically, we don't have to have the same pyc format as (C)Python, since
|
|
|
|
# these "pycs" should never be seen by builtin import. However, there's
|
|
|
|
# little reason deviate, and I hope sometime to be able to use
|
|
|
|
# imp.load_compiled to load them. (See the comment in load_module above.)
|
2011-06-29 10:13:12 +08:00
|
|
|
mtime = int(source_path.mtime())
|
2011-07-07 12:24:04 +08:00
|
|
|
fp = open(pyc, "wb")
|
2011-06-29 10:13:12 +08:00
|
|
|
try:
|
|
|
|
fp.write(imp.get_magic())
|
|
|
|
fp.write(struct.pack("<l", mtime))
|
|
|
|
marshal.dump(co, fp)
|
|
|
|
finally:
|
|
|
|
fp.close()
|
|
|
|
|
2011-07-13 06:09:14 +08:00
|
|
|
def _rewrite_test(state, fn):
|
|
|
|
"""Try to read and rewrite *fn* and return the code object."""
|
2011-06-29 10:13:12 +08:00
|
|
|
try:
|
|
|
|
source = fn.read("rb")
|
|
|
|
except EnvironmentError:
|
|
|
|
return None
|
|
|
|
try:
|
|
|
|
tree = ast.parse(source)
|
|
|
|
except SyntaxError:
|
|
|
|
# Let this pop up again in the real import.
|
|
|
|
state.trace("failed to parse: %r" % (fn,))
|
|
|
|
return None
|
|
|
|
rewrite_asserts(tree)
|
|
|
|
try:
|
|
|
|
co = compile(tree, fn.strpath, "exec")
|
|
|
|
except SyntaxError:
|
|
|
|
# It's possible that this error is from some bug in the
|
|
|
|
# assertion rewriting, but I don't know of a fast way to tell.
|
|
|
|
state.trace("failed to compile: %r" % (fn,))
|
|
|
|
return None
|
2011-07-13 06:09:14 +08:00
|
|
|
return co
|
|
|
|
|
|
|
|
def _make_rewritten_pyc(state, fn, pyc, co):
    """Try to dump rewritten code *co* for source *fn* to *pyc*.

    Caching is best effort. If the file cannot be written or renamed (for
    example an existing ``__pycache__`` that is not writable, or a full
    disk), the failure is traced, any partial file is removed, and the
    import goes on with *co*.

    Returns *co*.
    """
    if sys.platform.startswith("win"):
        # Windows grants exclusive access to open files and doesn't have atomic
        # rename, so just write into the final file.
        target = pyc
    else:
        # When not on windows, assume rename is atomic. Dump the code object
        # into a file specific to this process and atomically replace it.
        target = pyc + "." + str(os.getpid())
    try:
        _write_pyc(co, fn, target)
        if target != pyc:
            os.rename(target, pyc)
    except EnvironmentError:
        state.trace("error writing pyc file at %s" % (pyc,))
        # Do not leave a partially written or orphaned file behind. Cleanup
        # is itself best effort (the file may never have been created).
        try:
            os.remove(target)
        except EnvironmentError:
            pass
    return co
|
|
|
|
|
|
|
|
def _read_pyc(source, pyc):
|
|
|
|
"""Possibly read a py.test pyc containing rewritten code.
|
|
|
|
|
|
|
|
Return rewritten code if successful or None if not.
|
|
|
|
"""
|
|
|
|
try:
|
|
|
|
fp = open(pyc, "rb")
|
|
|
|
except IOError:
|
|
|
|
return None
|
2011-06-29 10:13:12 +08:00
|
|
|
try:
|
|
|
|
try:
|
2011-07-07 12:24:04 +08:00
|
|
|
mtime = int(source.mtime())
|
2011-06-29 10:13:12 +08:00
|
|
|
data = fp.read(8)
|
2011-07-07 12:24:04 +08:00
|
|
|
except EnvironmentError:
|
|
|
|
return None
|
|
|
|
# Check for invalid or out of date pyc file.
|
|
|
|
if (len(data) != 8 or
|
|
|
|
data[:4] != imp.get_magic() or
|
|
|
|
struct.unpack("<l", data[4:])[0] != mtime):
|
|
|
|
return None
|
|
|
|
co = marshal.load(fp)
|
|
|
|
if not isinstance(co, types.CodeType):
|
|
|
|
# That's interesting....
|
|
|
|
return None
|
|
|
|
return co
|
|
|
|
finally:
|
|
|
|
fp.close()
|
2011-07-06 01:02:53 +08:00
|
|
|
|
2011-06-29 10:13:12 +08:00
|
|
|
|
2011-05-19 04:31:10 +08:00
|
|
|
def rewrite_asserts(mod):
    """Rewrite, in place, the assert statements of the module AST *mod*."""
    rewriter = AssertionRewriter()
    rewriter.run(mod)
|
|
|
|
|
|
|
|
|
|
|
|
# Runtime helper for rewritten modules: AssertionRewriter.display() emits calls
# to ``@pytest_ar._saferepr``, where ``@pytest_ar`` is this module imported
# under that alias, so the name must remain a module-level attribute.
_saferepr = py.io.saferepr
|
2011-05-27 01:01:34 +08:00
|
|
|
from _pytest.assertion.util import format_explanation as _format_explanation
|
2011-05-19 04:31:10 +08:00
|
|
|
|
2011-06-29 09:21:22 +08:00
|
|
|
def _format_boolop(explanations, is_or):
|
|
|
|
return "(" + (is_or and " or " or " and ").join(explanations) + ")"
|
2011-05-19 04:31:10 +08:00
|
|
|
|
|
|
|
def _call_reprcompare(ops, results, expls, each_obj):
    """Return the explanation for a (possibly chained) comparison.

    Scans the comparisons in order and stops at the first one whose result
    is false or cannot be converted to bool. If ``util._reprcompare`` is set
    and returns something other than None for that comparison (operator
    symbol and its two operands from *each_obj*), that custom explanation
    wins. Otherwise the matching entry of *expls* is returned.
    """
    for i, (_, res, expl) in enumerate(zip(ops, results, expls)):
        try:
            failed = not res
        except Exception:
            # bool() on the result itself blew up; report this comparison.
            failed = True
        if failed:
            break
    hook = util._reprcompare
    if hook is not None:
        custom = hook(ops[i], each_obj[i], each_obj[i + 1])
        if custom is not None:
            return custom
    return expl
|
|
|
|
|
|
|
|
|
|
|
|
# Explanation patterns for unary operators; "%s" receives the operand's
# explanation (see AssertionRewriter.visit_UnaryOp).
unary_map = {
    ast.Not: "not %s",
    ast.Invert: "~%s",
    ast.USub: "-%s",
    ast.UAdd: "+%s",
}
|
|
|
|
|
|
|
|
# Source spelling of binary and comparison operators, spliced into the
# explanation templates built by AssertionRewriter. Those templates are
# later %-formatted, so a literal "%" has to be written as "%%". With a
# bare "%", an assert such as ``a % b == c`` produced a broken failure
# message.
binop_map = {
    ast.BitOr: "|",
    ast.BitXor: "^",
    ast.BitAnd: "&",
    ast.LShift: "<<",
    ast.RShift: ">>",
    ast.Add: "+",
    ast.Sub: "-",
    ast.Mult: "*",
    ast.Div: "/",
    ast.FloorDiv: "//",
    ast.Mod: "%%",  # escaped for the later % formatting of templates
    ast.Eq: "==",
    ast.NotEq: "!=",
    ast.Lt: "<",
    ast.LtE: "<=",
    ast.Gt: ">",
    ast.GtE: ">=",
    ast.Pow: "**",
    ast.Is: "is",
    ast.IsNot: "is not",
    ast.In: "in",
    ast.NotIn: "not in"
}
|
|
|
|
|
|
|
|
|
2011-05-20 10:49:37 +08:00
|
|
|
def set_location(node, lineno, col_offset):
    """Stamp *lineno* and *col_offset* on *node* and all nodes below it.

    A field is only set on nodes whose ``_attributes`` declare it. Returns
    *node* for convenience.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        declared = current._attributes
        if "lineno" in declared:
            current.lineno = lineno
        if "col_offset" in declared:
            current.col_offset = col_offset
        pending.extend(ast.iter_child_nodes(current))
    return node
|
|
|
|
|
|
|
|
|
2011-05-19 04:31:10 +08:00
|
|
|
class AssertionRewriter(ast.NodeVisitor):
    """Rewrite assert statements into code producing detailed failures.

    ``run`` walks a module AST and replaces every ``assert`` without a
    message by the statements built in ``visit_Assert``:

    - sub-expressions are evaluated into temporary ``@py_assert*`` names,
      which are not valid identifiers and so cannot clash with user code;
    - an ``AssertionError`` with a formatted explanation is raised when the
      test is false;
    - the temporaries are deleted afterwards.

    Explanations are %-format templates. The values substituted into them
    are computed only in the failure statements (``self.on_failure``).
    """

    def run(self, mod):
        """Find all assert statements in *mod* and rewrite them."""
        if not mod.body:
            # Nothing to do.
            return
        # Insert some special imports at the top of the module but after any
        # docstrings and __future__ imports.
        aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"),
                   ast.alias("_pytest.assertion.rewrite", "@pytest_ar")]
        expect_docstring = True
        pos = 0
        lineno = 0
        for item in mod.body:
            if (expect_docstring and isinstance(item, ast.Expr) and
                    isinstance(item.value, ast.Str)):
                doc = item.value.s
                if "PYTEST_DONT_REWRITE" in doc:
                    # The module has disabled assertion rewriting.
                    return
                expect_docstring = False
            elif (not isinstance(item, ast.ImportFrom) or item.level > 0 or
                  item.module != "__future__"):
                lineno = item.lineno
                break
            # The imports go after this header item. Remember its line in
            # case no ordinary statement follows. A character count of the
            # docstring is not a line number.
            lineno = item.lineno
            pos += 1
        imports = [ast.Import([alias], lineno=lineno, col_offset=0)
                   for alias in aliases]
        mod.body[pos:pos] = imports
        # Collect asserts.
        nodes = [mod]
        while nodes:
            node = nodes.pop()
            for name, field in ast.iter_fields(node):
                if isinstance(field, list):
                    new = []
                    for child in field:
                        if isinstance(child, ast.Assert):
                            # Transform assert.
                            new.extend(self.visit(child))
                        else:
                            new.append(child)
                            if isinstance(child, ast.AST):
                                nodes.append(child)
                    setattr(node, name, new)
                elif (isinstance(field, ast.AST) and
                      # Don't recurse into expressions as they can't contain
                      # asserts.
                      not isinstance(field, ast.expr)):
                    nodes.append(field)

    def variable(self):
        """Get a new variable, registered under the current condition chain."""
        # Use a character invalid in python identifiers to avoid clashing.
        name = "@py_assert" + str(next(self.variable_counter))
        self.variables[self.cond_chain].add(name)
        return name

    def assign(self, expr):
        """Give *expr* a name; return a Name node loading it."""
        name = self.variable()
        self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
        return ast.Name(name, ast.Load())

    def display(self, expr):
        """Call py.io.saferepr on the expression."""
        return self.helper("saferepr", expr)

    def helper(self, name, *args):
        """Call a helper in this module."""
        py_name = ast.Name("@pytest_ar", ast.Load())
        attr = ast.Attribute(py_name, "_" + name, ast.Load())
        return ast.Call(attr, list(args), [], None, None)

    def builtin(self, name):
        """Return the builtin called *name*."""
        builtin_name = ast.Name("@py_builtins", ast.Load())
        return ast.Attribute(builtin_name, name, ast.Load())

    def explanation_param(self, expr):
        """Register *expr* in the current format context.

        Returns the ``%(pyN)s`` placeholder to embed in a template.
        """
        specifier = "py" + str(next(self.variable_counter))
        self.explanation_specifiers[specifier] = expr
        return "%(" + specifier + ")s"

    def enter_cond(self, cond, body):
        """Append ``if cond: body`` and push *cond* on the condition chain."""
        self.statements.append(ast.If(cond, body, []))
        self.cond_chain += cond,

    def leave_cond(self, n=1):
        """Pop the last *n* conditions from the condition chain."""
        self.cond_chain = self.cond_chain[:-n]

    def push_format_context(self):
        """Start a new mapping of explanation specifiers."""
        self.explanation_specifiers = {}
        self.stack.append(self.explanation_specifiers)

    def pop_format_context(self, expl_expr):
        """Close the current format context.

        Appends ``@py_formatN = expl_expr % {specifiers}`` to the failure
        statements and returns a Name node loading ``@py_formatN``.
        """
        current = self.stack.pop()
        if self.stack:
            self.explanation_specifiers = self.stack[-1]
        keys = [ast.Str(key) for key in current.keys()]
        format_dict = ast.Dict(keys, list(current.values()))
        form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
        name = "@py_format" + str(next(self.variable_counter))
        self.on_failure.append(ast.Assign([ast.Name(name, ast.Store())], form))
        return ast.Name(name, ast.Load())

    def generic_visit(self, node):
        """Handle expressions we don't have custom code for."""
        assert isinstance(node, ast.expr)
        res = self.assign(node)
        return res, self.explanation_param(self.display(res))

    def visit_Assert(self, assert_):
        """Return the list of statements replacing *assert_*."""
        if assert_.msg:
            # There's already a message. Don't mess with it.
            return [assert_]
        self.statements = []
        self.cond_chain = ()
        self.variables = collections.defaultdict(set)
        self.variable_counter = itertools.count()
        self.stack = []
        self.on_failure = []
        self.push_format_context()
        # Rewrite assert into a bunch of statements.
        top_condition, explanation = self.visit(assert_.test)
        # Create failure message.
        body = self.on_failure
        negation = ast.UnaryOp(ast.Not(), top_condition)
        self.statements.append(ast.If(negation, body, []))
        explanation = "assert " + explanation
        template = ast.Str(explanation)
        msg = self.pop_format_context(template)
        fmt = self.helper("format_explanation", msg)
        err_name = ast.Name("AssertionError", ast.Load())
        exc = ast.Call(err_name, [fmt], [], None, None)
        if sys.version_info[0] >= 3:
            raise_ = ast.Raise(exc, None)
        else:
            raise_ = ast.Raise(exc, None, None)
        body.append(raise_)
        # Delete temporary variables. This requires a bit cleverness about the
        # order, so we don't delete variables that are themselves conditions for
        # later variables.
        for chain in sorted(self.variables, key=len, reverse=True):
            if chain:
                where = []
                if len(chain) > 1:
                    cond = ast.BoolOp(ast.And(), list(chain))
                else:
                    cond = chain[0]
                self.statements.append(ast.If(cond, where, []))
            else:
                where = self.statements
            # Sorted so the generated code does not depend on set ordering.
            names = [ast.Name(name, ast.Del())
                     for name in sorted(self.variables[chain])]
            where.append(ast.Delete(names))
        # Fix line numbers.
        for stmt in self.statements:
            set_location(stmt, assert_.lineno, assert_.col_offset)
        return self.statements

    def visit_Name(self, name):
        """Explain a name by its repr if it is a function local, else by name."""
        # Check if the name is local or not.
        locs = ast.Call(self.builtin("locals"), [], [], None, None)
        globs = ast.Call(self.builtin("globals"), [], [], None, None)
        ops = [ast.In(), ast.IsNot()]
        test = ast.Compare(ast.Str(name.id), ops, [locs, globs])
        expr = ast.IfExp(test, self.display(name), ast.Str(name.id))
        return name, self.explanation_param(expr)

    def visit_BoolOp(self, boolop):
        """Rewrite ``and``/``or`` while keeping short-circuit evaluation.

        Each operand after the first is evaluated inside ``if`` statements
        testing the previous operand's result (negated for ``or``). The
        explanation of each operand is built in the failure statements under
        the same conditions. Operands that were never evaluated, and whose
        temporaries were therefore never assigned, are never referenced.
        """
        res_var = self.variable()
        expl_list = self.assign(ast.List([], ast.Load()))
        app = ast.Attribute(expl_list, "append", ast.Load())
        is_or = isinstance(boolop.op, ast.Or)
        body = save = self.statements
        fail_save = self.on_failure
        levels = len(boolop.values) - 1
        cond = None
        # Process each operand, short-circuiting if needed.
        for i, v in enumerate(boolop.values):
            if i:
                # Explain this operand only if it was actually evaluated.
                fail_inner = []
                self.on_failure.append(ast.If(cond, fail_inner, []))
                self.on_failure = fail_inner
            self.push_format_context()
            res, expl = self.visit(v)
            body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
            expl_format = self.pop_format_context(ast.Str(expl))
            call = ast.Call(app, [expl_format], [], None, None)
            self.on_failure.append(ast.Expr(call))
            if i < levels:
                cond = res
                if is_or:
                    cond = ast.UnaryOp(ast.Not(), cond)
                inner = []
                self.enter_cond(cond, inner)
                self.statements = body = inner
        # Leave all conditions.
        self.leave_cond(levels)
        self.statements = save
        self.on_failure = fail_save
        # The collected operand explanations are already formatted, so store
        # the combined text as is. Formatting it again with % would
        # misinterpret any "%" in the displayed values.
        expl_template = self.helper("format_boolop", expl_list, ast.Num(is_or))
        name = "@py_format" + str(next(self.variable_counter))
        self.on_failure.append(
            ast.Assign([ast.Name(name, ast.Store())], expl_template))
        expl_name = ast.Name(name, ast.Load())
        return ast.Name(res_var, ast.Load()), self.explanation_param(expl_name)

    def visit_UnaryOp(self, unary):
        pattern = unary_map[unary.op.__class__]
        operand_res, operand_expl = self.visit(unary.operand)
        res = self.assign(ast.UnaryOp(unary.op, operand_res))
        return res, pattern % (operand_expl,)

    def visit_BinOp(self, binop):
        symbol = binop_map[binop.op.__class__]
        left_expr, left_expl = self.visit(binop.left)
        right_expr, right_expl = self.visit(binop.right)
        explanation = "(%s %s %s)" % (left_expl, symbol, right_expl)
        res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
        return res, explanation

    def visit_Call(self, call):
        """Rewrite a call, explaining the function and every argument."""
        new_func, func_expl = self.visit(call.func)
        arg_expls = []
        new_args = []
        new_kwargs = []
        new_star = new_kwarg = None
        for arg in call.args:
            res, expl = self.visit(arg)
            new_args.append(res)
            arg_expls.append(expl)
        for keyword in call.keywords:
            res, expl = self.visit(keyword.value)
            new_kwargs.append(ast.keyword(keyword.arg, res))
            arg_expls.append(keyword.arg + "=" + expl)
        if call.starargs:
            new_star, expl = self.visit(call.starargs)
            arg_expls.append("*" + expl)
        if call.kwargs:
            # The Call node field is "kwargs" (there is no "kwarg" attribute).
            new_kwarg, expl = self.visit(call.kwargs)
            arg_expls.append("**" + expl)
        expl = "%s(%s)" % (func_expl, ', '.join(arg_expls))
        new_call = ast.Call(new_func, new_args, new_kwargs, new_star, new_kwarg)
        res = self.assign(new_call)
        res_expl = self.explanation_param(self.display(res))
        outer_expl = "%s\n{%s = %s\n}" % (res_expl, res_expl, expl)
        return res, outer_expl

    def visit_Attribute(self, attr):
        if not isinstance(attr.ctx, ast.Load):
            return self.generic_visit(attr)
        value, value_expl = self.visit(attr.value)
        res = self.assign(ast.Attribute(value, attr.attr, ast.Load()))
        res_expl = self.explanation_param(self.display(res))
        pat = "%s\n{%s = %s.%s\n}"
        expl = pat % (res_expl, res_expl, value_expl, attr.attr)
        return res, expl

    def visit_Compare(self, comp):
        """Rewrite a (possibly chained) comparison.

        Each individual comparison result is stored in its own temporary so
        ``_call_reprcompare`` can explain the first one that failed.
        """
        self.push_format_context()
        left_res, left_expl = self.visit(comp.left)
        res_variables = [self.variable() for i in range(len(comp.ops))]
        load_names = [ast.Name(v, ast.Load()) for v in res_variables]
        store_names = [ast.Name(v, ast.Store()) for v in res_variables]
        it = zip(range(len(comp.ops)), comp.ops, comp.comparators)
        expls = []
        syms = []
        results = [left_res]
        for i, op, next_operand in it:
            next_res, next_expl = self.visit(next_operand)
            results.append(next_res)
            sym = binop_map[op.__class__]
            syms.append(ast.Str(sym))
            expl = "%s %s %s" % (left_expl, sym, next_expl)
            expls.append(ast.Str(expl))
            res_expr = ast.Compare(left_res, [op], [next_res])
            self.statements.append(ast.Assign([store_names[i]], res_expr))
            left_res, left_expl = next_res, next_expl
        # Use py.code._reprcompare if that's available.
        expl_call = self.helper("call_reprcompare",
                                ast.Tuple(syms, ast.Load()),
                                ast.Tuple(load_names, ast.Load()),
                                ast.Tuple(expls, ast.Load()),
                                ast.Tuple(results, ast.Load()))
        if len(comp.ops) > 1:
            res = ast.BoolOp(ast.And(), load_names)
        else:
            res = load_names[0]
        return res, self.explanation_param(self.pop_format_context(expl_call))
|