Merge pull request #8540 from hauntsaninja/assert310

This commit is contained in:
Bruno Oliveira 2021-04-15 08:55:42 -03:00 committed by GitHub
commit af31c60db1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 22 additions and 6 deletions

View File

@ -277,6 +277,7 @@ Sankt Petersbug
Segev Finer Segev Finer
Serhii Mozghovyi Serhii Mozghovyi
Seth Junot Seth Junot
Shantanu Jain
Shubham Adep Shubham Adep
Simon Gomizelj Simon Gomizelj
Simon Kerr Simon Kerr

View File

@ -0,0 +1 @@
Fixed assertion rewriting on Python 3.10.

View File

@ -684,12 +684,9 @@ class AssertionRewriter(ast.NodeVisitor):
if not mod.body: if not mod.body:
# Nothing to do. # Nothing to do.
return return
# Insert some special imports at the top of the module but after any
# docstrings and __future__ imports. # We'll insert some special imports at the top of the module, but after any
aliases = [ # docstrings and __future__ imports, so first figure out where that is.
ast.alias("builtins", "@py_builtins"),
ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
]
doc = getattr(mod, "docstring", None) doc = getattr(mod, "docstring", None)
expect_docstring = doc is None expect_docstring = doc is None
if doc is not None and self.is_rewrite_disabled(doc): if doc is not None and self.is_rewrite_disabled(doc):
@ -721,10 +718,27 @@ class AssertionRewriter(ast.NodeVisitor):
lineno = item.decorator_list[0].lineno lineno = item.decorator_list[0].lineno
else: else:
lineno = item.lineno lineno = item.lineno
# Now actually insert the special imports.
if sys.version_info >= (3, 10):
aliases = [
ast.alias("builtins", "@py_builtins", lineno=lineno, col_offset=0),
ast.alias(
"_pytest.assertion.rewrite",
"@pytest_ar",
lineno=lineno,
col_offset=0,
),
]
else:
aliases = [
ast.alias("builtins", "@py_builtins"),
ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
]
imports = [ imports = [
ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases
] ]
mod.body[pos:pos] = imports mod.body[pos:pos] = imports
# Collect asserts. # Collect asserts.
nodes: List[ast.AST] = [mod] nodes: List[ast.AST] = [mod]
while nodes: while nodes: