Some minor typing tweaks

This commit is contained in:
Ran Benita 2023-12-19 23:29:27 +02:00
parent 581762fcba
commit 75f292d9df
2 changed files with 15 additions and 7 deletions

View File

@ -11,15 +11,18 @@ from datetime import timezone
from io import StringIO from io import StringIO
from logging import LogRecord from logging import LogRecord
from pathlib import Path from pathlib import Path
from types import TracebackType
from typing import AbstractSet from typing import AbstractSet
from typing import Dict from typing import Dict
from typing import final from typing import final
from typing import Generator from typing import Generator
from typing import Generic
from typing import List from typing import List
from typing import Literal from typing import Literal
from typing import Mapping from typing import Mapping
from typing import Optional from typing import Optional
from typing import Tuple from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import TypeVar from typing import TypeVar
from typing import Union from typing import Union
@ -62,7 +65,7 @@ class DatetimeFormatter(logging.Formatter):
:func:`time.strftime` in case of microseconds in format string. :func:`time.strftime` in case of microseconds in format string.
""" """
def formatTime(self, record: LogRecord, datefmt=None) -> str: def formatTime(self, record: LogRecord, datefmt: Optional[str] = None) -> str:
if datefmt and "%f" in datefmt: if datefmt and "%f" in datefmt:
ct = self.converter(record.created) ct = self.converter(record.created)
tz = timezone(timedelta(seconds=ct.tm_gmtoff), ct.tm_zone) tz = timezone(timedelta(seconds=ct.tm_gmtoff), ct.tm_zone)
@ -331,7 +334,7 @@ _HandlerType = TypeVar("_HandlerType", bound=logging.Handler)
# Not using @contextmanager for performance reasons. # Not using @contextmanager for performance reasons.
class catching_logs: class catching_logs(Generic[_HandlerType]):
"""Context manager that prepares the whole logging machinery properly.""" """Context manager that prepares the whole logging machinery properly."""
__slots__ = ("handler", "level", "orig_level") __slots__ = ("handler", "level", "orig_level")
@ -340,7 +343,7 @@ class catching_logs:
self.handler = handler self.handler = handler
self.level = level self.level = level
def __enter__(self): def __enter__(self) -> _HandlerType:
root_logger = logging.getLogger() root_logger = logging.getLogger()
if self.level is not None: if self.level is not None:
self.handler.setLevel(self.level) self.handler.setLevel(self.level)
@ -350,7 +353,12 @@ class catching_logs:
root_logger.setLevel(min(self.orig_level, self.level)) root_logger.setLevel(min(self.orig_level, self.level))
return self.handler return self.handler
def __exit__(self, type, value, traceback): def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
root_logger = logging.getLogger() root_logger = logging.getLogger()
if self.level is not None: if self.level is not None:
root_logger.setLevel(self.orig_level) root_logger.setLevel(self.orig_level)
@ -421,7 +429,7 @@ class LogCaptureFixture:
return self._item.stash[caplog_handler_key] return self._item.stash[caplog_handler_key]
def get_records( def get_records(
self, when: "Literal['setup', 'call', 'teardown']" self, when: Literal["setup", "call", "teardown"]
) -> List[logging.LogRecord]: ) -> List[logging.LogRecord]:
"""Get the logging records for one of the possible test phases. """Get the logging records for one of the possible test phases.
@ -742,7 +750,7 @@ class LoggingPlugin:
if old_stream: if old_stream:
old_stream.close() old_stream.close()
def _log_cli_enabled(self): def _log_cli_enabled(self) -> bool:
"""Return whether live logging is enabled.""" """Return whether live logging is enabled."""
enabled = self._config.getoption( enabled = self._config.getoption(
"--log-cli-level" "--log-cli-level"

View File

@ -317,7 +317,7 @@ class CallInfo(Generic[TResult]):
@classmethod @classmethod
def from_call( def from_call(
cls, cls,
func: "Callable[[], TResult]", func: Callable[[], TResult],
when: Literal["collect", "setup", "call", "teardown"], when: Literal["collect", "setup", "call", "teardown"],
reraise: Optional[ reraise: Optional[
Union[Type[BaseException], Tuple[Type[BaseException], ...]] Union[Type[BaseException], Tuple[Type[BaseException], ...]]