diff --git a/changelog/4066.bugfix.rst b/changelog/4066.bugfix.rst new file mode 100644 index 000000000..64980d6e8 --- /dev/null +++ b/changelog/4066.bugfix.rst @@ -0,0 +1 @@ +Fix source reindenting by using ``textwrap.dedent`` directly. diff --git a/src/_pytest/_code/source.py b/src/_pytest/_code/source.py index 3b037b7d4..f78d8bef0 100644 --- a/src/_pytest/_code/source.py +++ b/src/_pytest/_code/source.py @@ -7,6 +7,7 @@ import linecache import sys import six import inspect +import textwrap import tokenize import py @@ -23,7 +24,6 @@ class Source(object): def __init__(self, *parts, **kwargs): self.lines = lines = [] de = kwargs.get("deindent", True) - rstrip = kwargs.get("rstrip", True) for part in parts: if not part: partlines = [] @@ -33,11 +33,6 @@ class Source(object): partlines = [x.rstrip("\n") for x in part] elif isinstance(part, six.string_types): partlines = part.split("\n") - if rstrip: - while partlines: - if partlines[-1].strip(): - break - partlines.pop() else: partlines = getsource(part, deindent=de).lines if de: @@ -115,17 +110,10 @@ class Source(object): ast, start, end = getstatementrange_ast(lineno, self) return start, end - def deindent(self, offset=None): - """ return a new source object deindented by offset. - If offset is None then guess an indentation offset from - the first non-blank line. Subsequent lines which have a - lower indentation offset will be copied verbatim as - they are assumed to be part of multilines. - """ - # XXX maybe use the tokenizer to properly handle multiline - # strings etc.pp? + def deindent(self): + """return a new source object deindented.""" newsource = Source() - newsource.lines[:] = deindent(self.lines, offset) + newsource.lines[:] = deindent(self.lines) return newsource def isparseable(self, deindent=True): @@ -268,47 +256,8 @@ def getsource(obj, **kwargs): return Source(strsrc, **kwargs) -def deindent(lines, offset=None): - if offset is None: - for line in lines: - line = line.expandtabs() - s = line.lstrip() - if s: - offset = len(line) - len(s) - break - else: - offset = 0 - if offset == 0: - return list(lines) - newlines = [] - - def readline_generator(lines): - for line in lines: - yield line + "\n" - - it = readline_generator(lines) - - try: - for _, _, (sline, _), (eline, _), _ in tokenize.generate_tokens( - lambda: next(it) - ): - if sline > len(lines): - break # End of input reached - if sline > len(newlines): - line = lines[sline - 1].expandtabs() - if line.lstrip() and line[:offset].isspace(): - line = line[offset:] # Deindent - newlines.append(line) - - for i in range(sline, eline): - # Don't deindent continuing lines of - # multiline tokens (i.e. multiline strings) - newlines.append(lines[i]) - except (IndentationError, tokenize.TokenError): - pass - # Add any lines we didn't see. E.g. if an exception was raised. - newlines.extend(lines[len(newlines) :]) - return newlines +def deindent(lines): + return textwrap.dedent("\n".join(lines)).splitlines() def get_statement_startend2(lineno, node): diff --git a/testing/code/test_source.py b/testing/code/test_source.py index 66e880e05..f31748c0e 100644 --- a/testing/code/test_source.py +++ b/testing/code/test_source.py @@ -27,16 +27,7 @@ def test_source_str_function(): x = Source( """ 3 - """, - rstrip=False, - ) - assert str(x) == "\n3\n " - - x = Source( """ - 3 - """, - rstrip=True, ) assert str(x) == "\n3" @@ -400,10 +391,13 @@ def test_getfuncsource_with_multine_string(): pass """ - assert ( - str(_pytest._code.Source(f)).strip() - == 'def f():\n c = """while True:\n pass\n"""' - ) + expected = '''\ + def f(): + c = """while True: + pass +""" +''' + assert str(_pytest._code.Source(f)) == expected.rstrip() def test_deindent(): @@ -411,21 +405,13 @@ def test_deindent(): assert deindent(["\tfoo", "\tbar"]) == ["foo", "bar"] - def f(): - c = """while True: - pass -""" - - lines = deindent(inspect.getsource(f).splitlines()) - assert lines == ["def f():", ' c = """while True:', " pass", '"""'] - - source = """ + source = """\ def f(): def g(): pass """ lines = deindent(source.splitlines()) - assert lines == ["", "def f():", " def g():", " pass", " "] + assert lines == ["def f():", " def g():", " pass"] def test_source_of_class_at_eof_without_newline(tmpdir):