Type annotate _pytest.logging

This commit is contained in:
Ran Benita 2020-05-01 14:40:16 +03:00
parent 90e58f8961
commit db52928684
1 changed files with 59 additions and 41 deletions

View File

@ -11,18 +11,24 @@ from typing import Generator
from typing import List
from typing import Mapping
from typing import Optional
from typing import Tuple
from typing import TypeVar
from typing import Union
import pytest
from _pytest import nodes
from _pytest._io import TerminalWriter
from _pytest.capture import CaptureManager
from _pytest.compat import nullcontext
from _pytest.config import _strtobool
from _pytest.config import Config
from _pytest.config import create_terminal_writer
from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureRequest
from _pytest.main import Session
from _pytest.pathlib import Path
from _pytest.store import StoreKey
from _pytest.terminal import TerminalReporter
DEFAULT_LOG_FORMAT = "%(levelname)-8s %(name)s:%(filename)s:%(lineno)d %(message)s"
@ -32,7 +38,7 @@ catch_log_handler_key = StoreKey["LogCaptureHandler"]()
catch_log_records_key = StoreKey[Dict[str, List[logging.LogRecord]]]()
def _remove_ansi_escape_sequences(text):
def _remove_ansi_escape_sequences(text: str) -> str:
return _ANSI_ESCAPE_SEQ.sub("", text)
@ -52,7 +58,7 @@ class ColoredLevelFormatter(logging.Formatter):
} # type: Mapping[int, AbstractSet[str]]
LEVELNAME_FMT_REGEX = re.compile(r"%\(levelname\)([+-.]?\d*s)")
def __init__(self, terminalwriter, *args, **kwargs) -> None:
def __init__(self, terminalwriter: TerminalWriter, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._original_fmt = self._style._fmt
self._level_to_fmt_mapping = {} # type: Dict[int, str]
@ -77,7 +83,7 @@ class ColoredLevelFormatter(logging.Formatter):
colorized_formatted_levelname, self._fmt
)
def format(self, record):
def format(self, record: logging.LogRecord) -> str:
fmt = self._level_to_fmt_mapping.get(record.levelno, self._original_fmt)
self._style._fmt = fmt
return super().format(record)
@ -90,18 +96,20 @@ class PercentStyleMultiline(logging.PercentStyle):
formats the message as if each line were logged separately.
"""
def __init__(self, fmt, auto_indent):
def __init__(self, fmt: str, auto_indent: Union[int, str, bool]) -> None:
super().__init__(fmt)
self._auto_indent = self._get_auto_indent(auto_indent)
@staticmethod
def _update_message(record_dict, message):
def _update_message(
record_dict: Dict[str, object], message: str
) -> Dict[str, object]:
tmp = record_dict.copy()
tmp["message"] = message
return tmp
@staticmethod
def _get_auto_indent(auto_indent_option) -> int:
def _get_auto_indent(auto_indent_option: Union[int, str, bool]) -> int:
"""Determines the current auto indentation setting
Specify auto indent behavior (on/off/fixed) by passing in
@ -149,11 +157,11 @@ class PercentStyleMultiline(logging.PercentStyle):
return 0
def format(self, record):
def format(self, record: logging.LogRecord) -> str:
if "\n" in record.message:
if hasattr(record, "auto_indent"):
# passed in from the "extra={}" kwarg on the call to logging.log()
auto_indent = self._get_auto_indent(record.auto_indent)
auto_indent = self._get_auto_indent(record.auto_indent) # type: ignore[attr-defined] # noqa: F821
else:
auto_indent = self._auto_indent
@ -173,7 +181,7 @@ class PercentStyleMultiline(logging.PercentStyle):
return self._fmt % record.__dict__
def get_option_ini(config, *names):
def get_option_ini(config: Config, *names: str):
for name in names:
ret = config.getoption(name) # 'default' arg won't work as expected
if ret is None:
@ -268,13 +276,16 @@ def pytest_addoption(parser: Parser) -> None:
)
_HandlerType = TypeVar("_HandlerType", bound=logging.Handler)
# Not using @contextmanager for performance reasons.
class catching_logs:
"""Context manager that prepares the whole logging machinery properly."""
__slots__ = ("handler", "level", "orig_level")
def __init__(self, handler, level=None):
def __init__(self, handler: _HandlerType, level: Optional[int] = None) -> None:
self.handler = handler
self.level = level
@ -330,7 +341,7 @@ class LogCaptureFixture:
"""Creates a new funcarg."""
self._item = item
# dict of log name -> log level
self._initial_log_levels = {} # type: Dict[str, int]
self._initial_log_levels = {} # type: Dict[Optional[str], int]
def _finalize(self) -> None:
"""Finalizes the fixture.
@ -364,17 +375,17 @@ class LogCaptureFixture:
return self._item._store[catch_log_records_key].get(when, [])
@property
def text(self):
def text(self) -> str:
"""Returns the formatted log text."""
return _remove_ansi_escape_sequences(self.handler.stream.getvalue())
@property
def records(self):
def records(self) -> List[logging.LogRecord]:
"""Returns the list of log records."""
return self.handler.records
@property
def record_tuples(self):
def record_tuples(self) -> List[Tuple[str, int, str]]:
"""Returns a list of a stripped down version of log records intended
for use in assertion comparison.
@ -385,7 +396,7 @@ class LogCaptureFixture:
return [(r.name, r.levelno, r.getMessage()) for r in self.records]
@property
def messages(self):
def messages(self) -> List[str]:
"""Returns a list of format-interpolated log messages.
Unlike 'records', which contains the format string and parameters for interpolation, log messages in this list
@ -400,11 +411,11 @@ class LogCaptureFixture:
"""
return [r.getMessage() for r in self.records]
def clear(self):
def clear(self) -> None:
"""Reset the list of log records and the captured log text."""
self.handler.reset()
def set_level(self, level, logger=None):
def set_level(self, level: Union[int, str], logger: Optional[str] = None) -> None:
"""Sets the level for capturing of logs. The level will be restored to its previous value at the end of
the test.
@ -415,31 +426,32 @@ class LogCaptureFixture:
The levels of the loggers changed by this function will be restored to their initial values at the
end of the test.
"""
logger_name = logger
logger = logging.getLogger(logger_name)
logger_obj = logging.getLogger(logger)
# save the original log-level to restore it during teardown
self._initial_log_levels.setdefault(logger_name, logger.level)
logger.setLevel(level)
self._initial_log_levels.setdefault(logger, logger_obj.level)
logger_obj.setLevel(level)
@contextmanager
def at_level(self, level, logger=None):
def at_level(
self, level: int, logger: Optional[str] = None
) -> Generator[None, None, None]:
"""Context manager that sets the level for capturing of logs. After the end of the 'with' statement the
level is restored to its original value.
:param int level: the logger to level.
:param str logger: the logger to update the level. If not given, the root logger level is updated.
"""
logger = logging.getLogger(logger)
orig_level = logger.level
logger.setLevel(level)
logger_obj = logging.getLogger(logger)
orig_level = logger_obj.level
logger_obj.setLevel(level)
try:
yield
finally:
logger.setLevel(orig_level)
logger_obj.setLevel(orig_level)
@pytest.fixture
def caplog(request):
def caplog(request: FixtureRequest) -> Generator[LogCaptureFixture, None, None]:
"""Access and control log capturing.
Captured logs are available through the following properties/methods::
@ -557,7 +569,7 @@ class LoggingPlugin:
return formatter
def set_log_path(self, fname):
def set_log_path(self, fname: str) -> None:
"""Public method, which can set filename parameter for
Logging.FileHandler(). Also creates parent directory if
it does not exist.
@ -565,15 +577,15 @@ class LoggingPlugin:
.. warning::
Please considered as an experimental API.
"""
fname = Path(fname)
fpath = Path(fname)
if not fname.is_absolute():
fname = Path(self._config.rootdir, fname)
if not fpath.is_absolute():
fpath = Path(self._config.rootdir, fpath)
if not fname.parent.exists():
fname.parent.mkdir(exist_ok=True, parents=True)
if not fpath.parent.exists():
fpath.parent.mkdir(exist_ok=True, parents=True)
stream = fname.open(mode="w", encoding="UTF-8")
stream = fpath.open(mode="w", encoding="UTF-8")
if sys.version_info >= (3, 7):
old_stream = self.log_file_handler.setStream(stream)
else:
@ -715,29 +727,35 @@ class _LiveLoggingStreamHandler(logging.StreamHandler):
and won't appear in the terminal.
"""
def __init__(self, terminal_reporter, capture_manager):
# Officially stream needs to be a IO[str], but TerminalReporter
# isn't. So force it.
stream = None # type: TerminalReporter # type: ignore
def __init__(
self, terminal_reporter: TerminalReporter, capture_manager: CaptureManager
) -> None:
"""
:param _pytest.terminal.TerminalReporter terminal_reporter:
:param _pytest.capture.CaptureManager capture_manager:
"""
logging.StreamHandler.__init__(self, stream=terminal_reporter)
logging.StreamHandler.__init__(self, stream=terminal_reporter) # type: ignore[arg-type] # noqa: F821
self.capture_manager = capture_manager
self.reset()
self.set_when(None)
self._test_outcome_written = False
def reset(self):
def reset(self) -> None:
"""Reset the handler; should be called before the start of each test"""
self._first_record_emitted = False
def set_when(self, when):
def set_when(self, when: Optional[str]) -> None:
"""Prepares for the given test phase (setup/call/teardown)"""
self._when = when
self._section_name_shown = False
if when == "start":
self._test_outcome_written = False
def emit(self, record):
def emit(self, record: logging.LogRecord) -> None:
ctx_manager = (
self.capture_manager.global_and_fixture_disabled()
if self.capture_manager
@ -764,10 +782,10 @@ class _LiveLoggingStreamHandler(logging.StreamHandler):
class _LiveLoggingNullHandler(logging.NullHandler):
"""A handler used when live logging is disabled."""
def reset(self):
def reset(self) -> None:
pass
def set_when(self, when):
def set_when(self, when: str) -> None:
pass
def handleError(self, record: logging.LogRecord) -> None: