Merge pull request #6129 from blueyed/typing

Typing around Node.location, reportinfo, repr_excinfo etc
This commit is contained in:
Daniel Hahler 2019-11-05 18:29:29 +01:00 committed by GitHub
commit 0794289689
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 100 additions and 52 deletions

View File

@ -549,7 +549,7 @@ class ExceptionInfo(Generic[_E]):
funcargs: bool = False, funcargs: bool = False,
truncate_locals: bool = True, truncate_locals: bool = True,
chain: bool = True, chain: bool = True,
): ) -> Union["ReprExceptionInfo", "ExceptionChainRepr"]:
""" """
Return str()able representation of this exception info. Return str()able representation of this exception info.
@ -818,19 +818,19 @@ class FormattedExcinfo:
return traceback, extraline return traceback, extraline
def repr_excinfo(self, excinfo): def repr_excinfo(self, excinfo: ExceptionInfo) -> "ExceptionChainRepr":
repr_chain = ( repr_chain = (
[] []
) # type: List[Tuple[ReprTraceback, Optional[ReprFileLocation], Optional[str]]] ) # type: List[Tuple[ReprTraceback, Optional[ReprFileLocation], Optional[str]]]
e = excinfo.value e = excinfo.value
excinfo_ = excinfo # type: Optional[ExceptionInfo]
descr = None descr = None
seen = set() # type: Set[int] seen = set() # type: Set[int]
while e is not None and id(e) not in seen: while e is not None and id(e) not in seen:
seen.add(id(e)) seen.add(id(e))
if excinfo: if excinfo_:
reprtraceback = self.repr_traceback(excinfo) reprtraceback = self.repr_traceback(excinfo_)
reprcrash = excinfo._getreprcrash() reprcrash = excinfo_._getreprcrash() # type: Optional[ReprFileLocation]
else: else:
# fallback to native repr if the exception doesn't have a traceback: # fallback to native repr if the exception doesn't have a traceback:
# ExceptionInfo objects require a full traceback to work # ExceptionInfo objects require a full traceback to work
@ -842,7 +842,7 @@ class FormattedExcinfo:
repr_chain += [(reprtraceback, reprcrash, descr)] repr_chain += [(reprtraceback, reprcrash, descr)]
if e.__cause__ is not None and self.chain: if e.__cause__ is not None and self.chain:
e = e.__cause__ e = e.__cause__
excinfo = ( excinfo_ = (
ExceptionInfo((type(e), e, e.__traceback__)) ExceptionInfo((type(e), e, e.__traceback__))
if e.__traceback__ if e.__traceback__
else None else None
@ -852,7 +852,7 @@ class FormattedExcinfo:
e.__context__ is not None and not e.__suppress_context__ and self.chain e.__context__ is not None and not e.__suppress_context__ and self.chain
): ):
e = e.__context__ e = e.__context__
excinfo = ( excinfo_ = (
ExceptionInfo((type(e), e, e.__traceback__)) ExceptionInfo((type(e), e, e.__traceback__))
if e.__traceback__ if e.__traceback__
else None else None
@ -876,6 +876,9 @@ class TerminalRepr:
def __repr__(self): def __repr__(self):
return "<{} instance at {:0x}>".format(self.__class__, id(self)) return "<{} instance at {:0x}>".format(self.__class__, id(self))
def toterminal(self, tw) -> None:
raise NotImplementedError()
class ExceptionRepr(TerminalRepr): class ExceptionRepr(TerminalRepr):
def __init__(self) -> None: def __init__(self) -> None:
@ -884,7 +887,7 @@ class ExceptionRepr(TerminalRepr):
def addsection(self, name, content, sep="-"): def addsection(self, name, content, sep="-"):
self.sections.append((name, content, sep)) self.sections.append((name, content, sep))
def toterminal(self, tw): def toterminal(self, tw) -> None:
for name, content, sep in self.sections: for name, content, sep in self.sections:
tw.sep(sep, name) tw.sep(sep, name)
tw.line(content) tw.line(content)
@ -899,7 +902,7 @@ class ExceptionChainRepr(ExceptionRepr):
self.reprtraceback = chain[-1][0] self.reprtraceback = chain[-1][0]
self.reprcrash = chain[-1][1] self.reprcrash = chain[-1][1]
def toterminal(self, tw): def toterminal(self, tw) -> None:
for element in self.chain: for element in self.chain:
element[0].toterminal(tw) element[0].toterminal(tw)
if element[2] is not None: if element[2] is not None:
@ -914,7 +917,7 @@ class ReprExceptionInfo(ExceptionRepr):
self.reprtraceback = reprtraceback self.reprtraceback = reprtraceback
self.reprcrash = reprcrash self.reprcrash = reprcrash
def toterminal(self, tw): def toterminal(self, tw) -> None:
self.reprtraceback.toterminal(tw) self.reprtraceback.toterminal(tw)
super().toterminal(tw) super().toterminal(tw)
@ -927,7 +930,7 @@ class ReprTraceback(TerminalRepr):
self.extraline = extraline self.extraline = extraline
self.style = style self.style = style
def toterminal(self, tw): def toterminal(self, tw) -> None:
# the entries might have different styles # the entries might have different styles
for i, entry in enumerate(self.reprentries): for i, entry in enumerate(self.reprentries):
if entry.style == "long": if entry.style == "long":
@ -959,7 +962,7 @@ class ReprEntryNative(TerminalRepr):
def __init__(self, tblines): def __init__(self, tblines):
self.lines = tblines self.lines = tblines
def toterminal(self, tw): def toterminal(self, tw) -> None:
tw.write("".join(self.lines)) tw.write("".join(self.lines))
@ -971,7 +974,7 @@ class ReprEntry(TerminalRepr):
self.reprfileloc = filelocrepr self.reprfileloc = filelocrepr
self.style = style self.style = style
def toterminal(self, tw): def toterminal(self, tw) -> None:
if self.style == "short": if self.style == "short":
self.reprfileloc.toterminal(tw) self.reprfileloc.toterminal(tw)
for line in self.lines: for line in self.lines:
@ -1003,7 +1006,7 @@ class ReprFileLocation(TerminalRepr):
self.lineno = lineno self.lineno = lineno
self.message = message self.message = message
def toterminal(self, tw): def toterminal(self, tw) -> None:
# filename and lineno output for each entry, # filename and lineno output for each entry,
# using an output format that most editors unterstand # using an output format that most editors unterstand
msg = self.message msg = self.message
@ -1018,7 +1021,7 @@ class ReprLocals(TerminalRepr):
def __init__(self, lines): def __init__(self, lines):
self.lines = lines self.lines = lines
def toterminal(self, tw): def toterminal(self, tw) -> None:
for line in self.lines: for line in self.lines:
tw.line(line) tw.line(line)
@ -1027,7 +1030,7 @@ class ReprFuncArgs(TerminalRepr):
def __init__(self, args): def __init__(self, args):
self.args = args self.args = args
def toterminal(self, tw): def toterminal(self, tw) -> None:
if self.args: if self.args:
linesofar = "" linesofar = ""
for name, value in self.args: for name, value in self.args:

View File

@ -305,7 +305,7 @@ class DoctestItem(pytest.Item):
else: else:
return super().repr_failure(excinfo) return super().repr_failure(excinfo)
def reportinfo(self): def reportinfo(self) -> Tuple[str, int, str]:
return self.fspath, self.dtest.lineno, "[doctest] %s" % self.name return self.fspath, self.dtest.lineno, "[doctest] %s" % self.name

View File

@ -7,13 +7,13 @@ from collections import defaultdict
from collections import deque from collections import deque
from collections import OrderedDict from collections import OrderedDict
from typing import Dict from typing import Dict
from typing import List
from typing import Tuple from typing import Tuple
import attr import attr
import py import py
import _pytest import _pytest
from _pytest import nodes
from _pytest._code.code import FormattedExcinfo from _pytest._code.code import FormattedExcinfo
from _pytest._code.code import TerminalRepr from _pytest._code.code import TerminalRepr
from _pytest.compat import _format_args from _pytest.compat import _format_args
@ -35,6 +35,8 @@ from _pytest.outcomes import TEST_OUTCOME
if False: # TYPE_CHECKING if False: # TYPE_CHECKING
from typing import Type from typing import Type
from _pytest import nodes
@attr.s(frozen=True) @attr.s(frozen=True)
class PseudoFixtureDef: class PseudoFixtureDef:
@ -689,8 +691,8 @@ class FixtureLookupError(LookupError):
self.fixturestack = request._get_fixturestack() self.fixturestack = request._get_fixturestack()
self.msg = msg self.msg = msg
def formatrepr(self): def formatrepr(self) -> "FixtureLookupErrorRepr":
tblines = [] tblines = [] # type: List[str]
addline = tblines.append addline = tblines.append
stack = [self.request._pyfuncitem.obj] stack = [self.request._pyfuncitem.obj]
stack.extend(map(lambda x: x.func, self.fixturestack)) stack.extend(map(lambda x: x.func, self.fixturestack))
@ -742,7 +744,7 @@ class FixtureLookupErrorRepr(TerminalRepr):
self.firstlineno = firstlineno self.firstlineno = firstlineno
self.argname = argname self.argname = argname
def toterminal(self, tw): def toterminal(self, tw) -> None:
# tw.line("FixtureLookupError: %s" %(self.argname), red=True) # tw.line("FixtureLookupError: %s" %(self.argname), red=True)
for tbline in self.tblines: for tbline in self.tblines:
tw.line(tbline.rstrip()) tw.line(tbline.rstrip())
@ -1283,6 +1285,8 @@ class FixtureManager:
except AttributeError: except AttributeError:
pass pass
else: else:
from _pytest import nodes
# construct the base nodeid which is later used to check # construct the base nodeid which is later used to check
# what fixtures are visible for particular tests (as denoted # what fixtures are visible for particular tests (as denoted
# by their test id) # by their test id)
@ -1459,6 +1463,8 @@ class FixtureManager:
return tuple(self._matchfactories(fixturedefs, nodeid)) return tuple(self._matchfactories(fixturedefs, nodeid))
def _matchfactories(self, fixturedefs, nodeid): def _matchfactories(self, fixturedefs, nodeid):
from _pytest import nodes
for fixturedef in fixturedefs: for fixturedef in fixturedefs:
if nodes.ischildnode(fixturedef.baseid, nodeid): if nodes.ischildnode(fixturedef.baseid, nodeid):
yield fixturedef yield fixturedef

View File

@ -5,6 +5,7 @@ import functools
import importlib import importlib
import os import os
import sys import sys
from typing import Dict
import attr import attr
import py import py
@ -16,6 +17,7 @@ from _pytest.config import hookimpl
from _pytest.config import UsageError from _pytest.config import UsageError
from _pytest.outcomes import exit from _pytest.outcomes import exit
from _pytest.runner import collect_one_node from _pytest.runner import collect_one_node
from _pytest.runner import SetupState
class ExitCode(enum.IntEnum): class ExitCode(enum.IntEnum):
@ -359,8 +361,8 @@ class Failed(Exception):
class _bestrelpath_cache(dict): class _bestrelpath_cache(dict):
path = attr.ib() path = attr.ib()
def __missing__(self, path): def __missing__(self, path: str) -> str:
r = self.path.bestrelpath(path) r = self.path.bestrelpath(path) # type: str
self[path] = r self[path] = r
return r return r
@ -368,6 +370,7 @@ class _bestrelpath_cache(dict):
class Session(nodes.FSCollector): class Session(nodes.FSCollector):
Interrupted = Interrupted Interrupted = Interrupted
Failed = Failed Failed = Failed
_setupstate = None # type: SetupState
def __init__(self, config): def __init__(self, config):
nodes.FSCollector.__init__( nodes.FSCollector.__init__(
@ -383,7 +386,9 @@ class Session(nodes.FSCollector):
self._initialpaths = frozenset() self._initialpaths = frozenset()
# Keep track of any collected nodes in here, so we don't duplicate fixtures # Keep track of any collected nodes in here, so we don't duplicate fixtures
self._node_cache = {} self._node_cache = {}
self._bestrelpathcache = _bestrelpath_cache(config.rootdir) self._bestrelpathcache = _bestrelpath_cache(
config.rootdir
) # type: Dict[str, str]
# Dirnames of pkgs with dunder-init files. # Dirnames of pkgs with dunder-init files.
self._pkg_roots = {} self._pkg_roots = {}
@ -398,7 +403,7 @@ class Session(nodes.FSCollector):
self.testscollected, self.testscollected,
) )
def _node_location_to_relpath(self, node_path): def _node_location_to_relpath(self, node_path: str) -> str:
# bestrelpath is a quite slow function # bestrelpath is a quite slow function
return self._bestrelpathcache[node_path] return self._bestrelpathcache[node_path]

View File

@ -4,6 +4,7 @@ from functools import lru_cache
from typing import Any from typing import Any
from typing import Dict from typing import Dict
from typing import List from typing import List
from typing import Optional
from typing import Set from typing import Set
from typing import Tuple from typing import Tuple
from typing import Union from typing import Union
@ -11,15 +12,21 @@ from typing import Union
import py import py
import _pytest._code import _pytest._code
from _pytest._code.code import ExceptionChainRepr
from _pytest._code.code import ExceptionInfo
from _pytest._code.code import ReprExceptionInfo
from _pytest.compat import getfslineno from _pytest.compat import getfslineno
from _pytest.fixtures import FixtureDef
from _pytest.fixtures import FixtureLookupError
from _pytest.fixtures import FixtureLookupErrorRepr
from _pytest.mark.structures import Mark from _pytest.mark.structures import Mark
from _pytest.mark.structures import MarkDecorator from _pytest.mark.structures import MarkDecorator
from _pytest.mark.structures import NodeKeywords from _pytest.mark.structures import NodeKeywords
from _pytest.outcomes import fail from _pytest.outcomes import Failed
if False: # TYPE_CHECKING if False: # TYPE_CHECKING
# Imported here due to circular import. # Imported here due to circular import.
from _pytest.fixtures import FixtureDef from _pytest.main import Session # noqa: F401
SEP = "/" SEP = "/"
@ -69,8 +76,14 @@ class Node:
Collector subclasses have children, Items are terminal nodes.""" Collector subclasses have children, Items are terminal nodes."""
def __init__( def __init__(
self, name, parent=None, config=None, session=None, fspath=None, nodeid=None self,
): name,
parent=None,
config=None,
session: Optional["Session"] = None,
fspath=None,
nodeid=None,
) -> None:
#: a unique name within the scope of the parent node #: a unique name within the scope of the parent node
self.name = name self.name = name
@ -81,7 +94,11 @@ class Node:
self.config = config or parent.config self.config = config or parent.config
#: the session this node is part of #: the session this node is part of
self.session = session or parent.session if session is None:
assert parent.session is not None
self.session = parent.session
else:
self.session = session
#: filesystem path where this node was collected from (can be None) #: filesystem path where this node was collected from (can be None)
self.fspath = fspath or getattr(parent, "fspath", None) self.fspath = fspath or getattr(parent, "fspath", None)
@ -254,13 +271,13 @@ class Node:
def _prunetraceback(self, excinfo): def _prunetraceback(self, excinfo):
pass pass
def _repr_failure_py(self, excinfo, style=None): def _repr_failure_py(
# Type ignored: see comment where fail.Exception is defined. self, excinfo: ExceptionInfo[Union[Failed, FixtureLookupError]], style=None
if excinfo.errisinstance(fail.Exception): # type: ignore ) -> Union[str, ReprExceptionInfo, ExceptionChainRepr, FixtureLookupErrorRepr]:
if isinstance(excinfo.value, Failed):
if not excinfo.value.pytrace: if not excinfo.value.pytrace:
return str(excinfo.value) return str(excinfo.value)
fm = self.session._fixturemanager if isinstance(excinfo.value, FixtureLookupError):
if excinfo.errisinstance(fm.FixtureLookupError):
return excinfo.value.formatrepr() return excinfo.value.formatrepr()
if self.config.getoption("fulltrace", False): if self.config.getoption("fulltrace", False):
style = "long" style = "long"
@ -298,7 +315,9 @@ class Node:
truncate_locals=truncate_locals, truncate_locals=truncate_locals,
) )
def repr_failure(self, excinfo, style=None): def repr_failure(
self, excinfo, style=None
) -> Union[str, ReprExceptionInfo, ExceptionChainRepr, FixtureLookupErrorRepr]:
return self._repr_failure_py(excinfo, style) return self._repr_failure_py(excinfo, style)
@ -425,16 +444,20 @@ class Item(Node):
if content: if content:
self._report_sections.append((when, key, content)) self._report_sections.append((when, key, content))
def reportinfo(self): def reportinfo(self) -> Tuple[str, Optional[int], str]:
return self.fspath, None, "" return self.fspath, None, ""
@property @property
def location(self): def location(self) -> Tuple[str, Optional[int], str]:
try: try:
return self._location return self._location
except AttributeError: except AttributeError:
location = self.reportinfo() location = self.reportinfo()
fspath = self.session._node_location_to_relpath(location[0]) fspath = self.session._node_location_to_relpath(location[0])
location = (fspath, location[1], str(location[2])) assert type(location[2]) is str
self._location = location self._location = (
return location fspath,
location[1],
location[2],
) # type: Tuple[str, Optional[int], str]
return self._location

View File

@ -9,6 +9,7 @@ from collections import Counter
from collections.abc import Sequence from collections.abc import Sequence
from functools import partial from functools import partial
from textwrap import dedent from textwrap import dedent
from typing import Tuple
import py import py
@ -288,7 +289,7 @@ class PyobjMixin(PyobjContext):
s = ".".join(parts) s = ".".join(parts)
return s.replace(".[", "[") return s.replace(".[", "[")
def reportinfo(self): def reportinfo(self) -> Tuple[str, int, str]:
# XXX caching? # XXX caching?
obj = self.obj obj = self.obj
compat_co_firstlineno = getattr(obj, "compat_co_firstlineno", None) compat_co_firstlineno = getattr(obj, "compat_co_firstlineno", None)

View File

@ -1,6 +1,8 @@
from io import StringIO from io import StringIO
from pprint import pprint from pprint import pprint
from typing import List
from typing import Optional from typing import Optional
from typing import Tuple
from typing import Union from typing import Union
import py import py
@ -15,6 +17,7 @@ from _pytest._code.code import ReprFuncArgs
from _pytest._code.code import ReprLocals from _pytest._code.code import ReprLocals
from _pytest._code.code import ReprTraceback from _pytest._code.code import ReprTraceback
from _pytest._code.code import TerminalRepr from _pytest._code.code import TerminalRepr
from _pytest.nodes import Node
from _pytest.outcomes import skip from _pytest.outcomes import skip
from _pytest.pathlib import Path from _pytest.pathlib import Path
@ -34,13 +37,16 @@ def getslaveinfoline(node):
class BaseReport: class BaseReport:
when = None # type: Optional[str] when = None # type: Optional[str]
location = None location = None
longrepr = None
sections = [] # type: List[Tuple[str, str]]
nodeid = None # type: str
def __init__(self, **kw): def __init__(self, **kw):
self.__dict__.update(kw) self.__dict__.update(kw)
def toterminal(self, out): def toterminal(self, out) -> None:
if hasattr(self, "node"): if hasattr(self, "node"):
out.line(getslaveinfoline(self.node)) out.line(getslaveinfoline(self.node)) # type: ignore
longrepr = self.longrepr longrepr = self.longrepr
if longrepr is None: if longrepr is None:
@ -300,7 +306,9 @@ class TestReport(BaseReport):
class CollectReport(BaseReport): class CollectReport(BaseReport):
when = "collect" when = "collect"
def __init__(self, nodeid, outcome, longrepr, result, sections=(), **extra): def __init__(
self, nodeid: str, outcome, longrepr, result: List[Node], sections=(), **extra
) -> None:
self.nodeid = nodeid self.nodeid = nodeid
self.outcome = outcome self.outcome = outcome
self.longrepr = longrepr self.longrepr = longrepr
@ -322,7 +330,7 @@ class CollectErrorRepr(TerminalRepr):
def __init__(self, msg): def __init__(self, msg):
self.longrepr = msg self.longrepr = msg
def toterminal(self, out): def toterminal(self, out) -> None:
out.line(self.longrepr, red=True) out.line(self.longrepr, red=True)
@ -472,7 +480,9 @@ def _report_kwargs_from_json(reportdict):
description, description,
) )
) )
exception_info = ExceptionChainRepr(chain) exception_info = ExceptionChainRepr(
chain
) # type: Union[ExceptionChainRepr,ReprExceptionInfo]
else: else:
exception_info = ReprExceptionInfo(reprtraceback, reprcrash) exception_info = ReprExceptionInfo(reprtraceback, reprcrash)

View File

@ -6,6 +6,7 @@ from time import time
from typing import Callable from typing import Callable
from typing import Dict from typing import Dict
from typing import List from typing import List
from typing import Optional
from typing import Tuple from typing import Tuple
import attr import attr
@ -207,8 +208,7 @@ class CallInfo:
""" Result/Exception info a function invocation. """ """ Result/Exception info a function invocation. """
_result = attr.ib() _result = attr.ib()
# Optional[ExceptionInfo] excinfo = attr.ib(type=Optional[ExceptionInfo])
excinfo = attr.ib()
start = attr.ib() start = attr.ib()
stop = attr.ib() stop = attr.ib()
when = attr.ib() when = attr.ib()
@ -220,7 +220,7 @@ class CallInfo:
return self._result return self._result
@classmethod @classmethod
def from_call(cls, func, when, reraise=None): def from_call(cls, func, when, reraise=None) -> "CallInfo":
#: context of invocation: one of "setup", "call", #: context of invocation: one of "setup", "call",
#: "teardown", "memocollect" #: "teardown", "memocollect"
start = time() start = time()

View File

@ -902,7 +902,7 @@ raise ValueError()
from _pytest._code.code import TerminalRepr from _pytest._code.code import TerminalRepr
class MyRepr(TerminalRepr): class MyRepr(TerminalRepr):
def toterminal(self, tw): def toterminal(self, tw) -> None:
tw.line("я") tw.line("я")
x = str(MyRepr()) x = str(MyRepr())