diff --git a/_pytest/_code/source.py b/_pytest/_code/source.py index 2638c598b..322c72afb 100644 --- a/_pytest/_code/source.py +++ b/_pytest/_code/source.py @@ -1,19 +1,15 @@ from __future__ import absolute_import, division, generators, print_function +import ast +from ast import PyCF_ONLY_AST as _AST_FLAG from bisect import bisect_right import sys import six import inspect import tokenize import py -cpy_compile = compile -try: - import _ast - from _ast import PyCF_ONLY_AST as _AST_FLAG -except ImportError: - _AST_FLAG = 0 - _ast = None +cpy_compile = compile class Source(object): @@ -209,7 +205,7 @@ def compile_(source, filename=None, mode='exec', flags=generators.compiler_flag, retrieval of the source code for the code object and any recursively created code objects. """ - if _ast is not None and isinstance(source, _ast.AST): + if isinstance(source, ast.AST): # XXX should Source support having AST? return cpy_compile(source, filename, mode, flags, dont_inherit) _genframe = sys._getframe(1) # the caller @@ -322,7 +318,7 @@ def get_statement_startend2(lineno, node): # AST's line numbers start indexing at 1 values = [] for x in ast.walk(node): - if isinstance(x, _ast.stmt) or isinstance(x, _ast.ExceptHandler): + if isinstance(x, ast.stmt) or isinstance(x, ast.ExceptHandler): values.append(x.lineno - 1) for name in "finalbody", "orelse": val = getattr(x, name, None) diff --git a/_pytest/assertion/rewrite.py b/_pytest/assertion/rewrite.py index f64358f49..92deb6539 100644 --- a/_pytest/assertion/rewrite.py +++ b/_pytest/assertion/rewrite.py @@ -1,7 +1,6 @@ """Rewrite assertion AST to produce nice error messages""" from __future__ import absolute_import, division, print_function import ast -import _ast import errno import itertools import imp @@ -914,7 +913,7 @@ class AssertionRewriter(ast.NodeVisitor): def visit_Compare(self, comp): self.push_format_context() left_res, left_expl = self.visit(comp.left) - if isinstance(comp.left, (_ast.Compare, _ast.BoolOp)): + if isinstance(comp.left, (ast.Compare, ast.BoolOp)): left_expl = "({0})".format(left_expl) res_variables = [self.variable() for i in range(len(comp.ops))] load_names = [ast.Name(v, ast.Load()) for v in res_variables] @@ -925,7 +924,7 @@ class AssertionRewriter(ast.NodeVisitor): results = [left_res] for i, op, next_operand in it: next_res, next_expl = self.visit(next_operand) - if isinstance(next_operand, (_ast.Compare, _ast.BoolOp)): + if isinstance(next_operand, (ast.Compare, ast.BoolOp)): next_expl = "({0})".format(next_expl) results.append(next_res) sym = binop_map[op.__class__] diff --git a/_pytest/assertion/util.py b/_pytest/assertion/util.py index 511d98ef1..5a380ae09 100644 --- a/_pytest/assertion/util.py +++ b/_pytest/assertion/util.py @@ -5,11 +5,7 @@ import pprint import _pytest._code import py import six -try: - from collections import Sequence -except ImportError: - Sequence = list - +from collections import Sequence u = six.text_type @@ -113,7 +109,7 @@ def assertrepr_compare(config, op, left, right): summary = u('%s %s %s') % (ecu(left_repr), op, ecu(right_repr)) def issequence(x): - return (isinstance(x, (list, tuple, Sequence)) and not isinstance(x, basestring)) + return isinstance(x, Sequence) and not isinstance(x, basestring) def istext(x): return isinstance(x, basestring) diff --git a/changelog/3018.trivial b/changelog/3018.trivial new file mode 100644 index 000000000..8b4b4176b --- /dev/null +++ b/changelog/3018.trivial @@ -0,0 +1 @@ +Clean up code by replacing imports and references of `_ast` to `ast`. diff --git a/testing/code/test_source.py b/testing/code/test_source.py index 8eda68a6e..fcce3fa96 100644 --- a/testing/code/test_source.py +++ b/testing/code/test_source.py @@ -8,13 +8,10 @@ import _pytest._code import py import pytest from _pytest._code import Source -from _pytest._code.source import _ast +from _pytest._code.source import ast -if _ast is not None: - astonly = pytest.mark.nothing -else: - astonly = pytest.mark.xfail("True", reason="only works with AST-compile") +astonly = pytest.mark.nothing failsonjython = pytest.mark.xfail("sys.platform.startswith('java')")