Have pytest.raises match against exception `__notes__` (#11227)

The doctest is skipped because add_note is only available in 3.11,

Closes #11223
This commit is contained in:
Isaac Virshup 2023-07-18 13:39:39 +02:00 committed by GitHub
parent 7c30f674c5
commit 1de0923e83
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 71 additions and 8 deletions

View File

@ -168,6 +168,7 @@ Ian Bicking
Ian Lesperance Ian Lesperance
Ilya Konstantinov Ilya Konstantinov
Ionuț Turturică Ionuț Turturică
Isaac Virshup
Itxaso Aizpurua Itxaso Aizpurua
Iwan Briquemont Iwan Briquemont
Jaap Broekhuizen Jaap Broekhuizen

View File

@ -0,0 +1 @@
Allow :func:`pytest.raises` ``match`` argument to match against `PEP-678 <https://peps.python.org/pep-0678/>` ``__notes__``.

View File

@ -704,7 +704,12 @@ class ExceptionInfo(Generic[E]):
If it matches `True` is returned, otherwise an `AssertionError` is raised. If it matches `True` is returned, otherwise an `AssertionError` is raised.
""" """
__tracebackhide__ = True __tracebackhide__ = True
value = str(self.value) value = "\n".join(
[
str(self.value),
*getattr(self.value, "__notes__", []),
]
)
msg = f"Regex pattern did not match.\n Regex: {regexp!r}\n Input: {value!r}" msg = f"Regex pattern did not match.\n Regex: {regexp!r}\n Input: {value!r}"
if regexp == value: if regexp == value:
msg += "\n Did you mean to `re.escape()` the regex?" msg += "\n Did you mean to `re.escape()` the regex?"

View File

@ -843,6 +843,14 @@ def raises( # noqa: F811
>>> with pytest.raises(ValueError, match=r'must be \d+$'): >>> with pytest.raises(ValueError, match=r'must be \d+$'):
... raise ValueError("value must be 42") ... raise ValueError("value must be 42")
The ``match`` argument searches the formatted exception string, which includes any
`PEP-678 <https://peps.python.org/pep-0678/>` ``__notes__``:
>>> with pytest.raises(ValueError, match=r'had a note added'): # doctest: +SKIP
... e = ValueError("value must be 42")
... e.add_note("had a note added")
... raise e
The context manager produces an :class:`ExceptionInfo` object which can be used to inspect the The context manager produces an :class:`ExceptionInfo` object which can be used to inspect the
details of the captured exception:: details of the captured exception::

View File

@ -1,15 +1,15 @@
from __future__ import annotations
import importlib import importlib
import io import io
import operator import operator
import queue import queue
import re
import sys import sys
import textwrap import textwrap
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from typing import Dict
from typing import Tuple
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Union
import _pytest._code import _pytest._code
import pytest import pytest
@ -801,7 +801,7 @@ raise ValueError()
) )
excinfo = pytest.raises(ValueError, mod.entry) excinfo = pytest.raises(ValueError, mod.entry)
styles: Tuple[_TracebackStyle, ...] = ("long", "short") styles: tuple[_TracebackStyle, ...] = ("long", "short")
for style in styles: for style in styles:
p = FormattedExcinfo(style=style) p = FormattedExcinfo(style=style)
reprtb = p.repr_traceback(excinfo) reprtb = p.repr_traceback(excinfo)
@ -928,7 +928,7 @@ raise ValueError()
) )
excinfo = pytest.raises(ValueError, mod.entry) excinfo = pytest.raises(ValueError, mod.entry)
styles: Tuple[_TracebackStyle, ...] = ("short", "long", "no") styles: tuple[_TracebackStyle, ...] = ("short", "long", "no")
for style in styles: for style in styles:
for showlocals in (True, False): for showlocals in (True, False):
repr = excinfo.getrepr(style=style, showlocals=showlocals) repr = excinfo.getrepr(style=style, showlocals=showlocals)
@ -1090,7 +1090,7 @@ raise ValueError()
for funcargs in (True, False) for funcargs in (True, False)
], ],
) )
def test_format_excinfo(self, reproptions: Dict[str, Any]) -> None: def test_format_excinfo(self, reproptions: dict[str, Any]) -> None:
def bar(): def bar():
assert False, "some error" assert False, "some error"
@ -1398,7 +1398,7 @@ raise ValueError()
@pytest.mark.parametrize("encoding", [None, "utf8", "utf16"]) @pytest.mark.parametrize("encoding", [None, "utf8", "utf16"])
def test_repr_traceback_with_unicode(style, encoding): def test_repr_traceback_with_unicode(style, encoding):
if encoding is None: if encoding is None:
msg: Union[str, bytes] = "" msg: str | bytes = ""
else: else:
msg = "".encode(encoding) msg = "".encode(encoding)
try: try:
@ -1648,3 +1648,51 @@ def test_hidden_entries_of_chained_exceptions_are_not_shown(pytester: Pytester)
], ],
consecutive=True, consecutive=True,
) )
def add_note(err: BaseException, msg: str) -> None:
"""Adds a note to an exception inplace."""
if sys.version_info < (3, 11):
err.__notes__ = getattr(err, "__notes__", []) + [msg] # type: ignore[attr-defined]
else:
err.add_note(msg)
@pytest.mark.parametrize(
"error,notes,match",
[
(Exception("test"), [], "test"),
(AssertionError("foo"), ["bar"], "bar"),
(AssertionError("foo"), ["bar", "baz"], "bar"),
(AssertionError("foo"), ["bar", "baz"], "baz"),
(ValueError("foo"), ["bar", "baz"], re.compile(r"bar\nbaz", re.MULTILINE)),
(ValueError("foo"), ["bar", "baz"], re.compile(r"BAZ", re.IGNORECASE)),
],
)
def test_check_error_notes_success(
error: Exception, notes: list[str], match: str
) -> None:
for note in notes:
add_note(error, note)
with pytest.raises(Exception, match=match):
raise error
@pytest.mark.parametrize(
"error, notes, match",
[
(Exception("test"), [], "foo"),
(AssertionError("foo"), ["bar"], "baz"),
(AssertionError("foo"), ["bar"], "foo\nbaz"),
],
)
def test_check_error_notes_failure(
error: Exception, notes: list[str], match: str
) -> None:
for note in notes:
add_note(error, note)
with pytest.raises(AssertionError):
with pytest.raises(type(error), match=match):
raise error