diff --git a/src/_pytest/logging.py b/src/_pytest/logging.py index 1c6bb923b..5426c3513 100644 --- a/src/_pytest/logging.py +++ b/src/_pytest/logging.py @@ -11,15 +11,18 @@ from datetime import timezone from io import StringIO from logging import LogRecord from pathlib import Path +from types import TracebackType from typing import AbstractSet from typing import Dict from typing import final from typing import Generator +from typing import Generic from typing import List from typing import Literal from typing import Mapping from typing import Optional from typing import Tuple +from typing import Type from typing import TYPE_CHECKING from typing import TypeVar from typing import Union @@ -62,7 +65,7 @@ class DatetimeFormatter(logging.Formatter): :func:`time.strftime` in case of microseconds in format string. """ - def formatTime(self, record: LogRecord, datefmt=None) -> str: + def formatTime(self, record: LogRecord, datefmt: Optional[str] = None) -> str: if datefmt and "%f" in datefmt: ct = self.converter(record.created) tz = timezone(timedelta(seconds=ct.tm_gmtoff), ct.tm_zone) @@ -331,7 +334,7 @@ _HandlerType = TypeVar("_HandlerType", bound=logging.Handler) # Not using @contextmanager for performance reasons. -class catching_logs: +class catching_logs(Generic[_HandlerType]): """Context manager that prepares the whole logging machinery properly.""" __slots__ = ("handler", "level", "orig_level") @@ -340,7 +343,7 @@ class catching_logs: self.handler = handler self.level = level - def __enter__(self): + def __enter__(self) -> _HandlerType: root_logger = logging.getLogger() if self.level is not None: self.handler.setLevel(self.level) @@ -350,7 +353,12 @@ class catching_logs: root_logger.setLevel(min(self.orig_level, self.level)) return self.handler - def __exit__(self, type, value, traceback): + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: root_logger = logging.getLogger() if self.level is not None: root_logger.setLevel(self.orig_level) @@ -421,7 +429,7 @@ class LogCaptureFixture: return self._item.stash[caplog_handler_key] def get_records( - self, when: "Literal['setup', 'call', 'teardown']" + self, when: Literal["setup", "call", "teardown"] ) -> List[logging.LogRecord]: """Get the logging records for one of the possible test phases. @@ -742,7 +750,7 @@ class LoggingPlugin: if old_stream: old_stream.close() - def _log_cli_enabled(self): + def _log_cli_enabled(self) -> bool: """Return whether live logging is enabled.""" enabled = self._config.getoption( "--log-cli-level" diff --git a/src/_pytest/runner.py b/src/_pytest/runner.py index 1b39f93cf..c03d707dc 100644 --- a/src/_pytest/runner.py +++ b/src/_pytest/runner.py @@ -317,7 +317,7 @@ class CallInfo(Generic[TResult]): @classmethod def from_call( cls, - func: "Callable[[], TResult]", + func: Callable[[], TResult], when: Literal["collect", "setup", "call", "teardown"], reraise: Optional[ Union[Type[BaseException], Tuple[Type[BaseException], ...]]