Type annotate _pytest.logging

This commit is contained in:
Ran Benita 2020-05-01 14:40:16 +03:00
parent 90e58f8961
commit db52928684
1 changed files with 59 additions and 41 deletions

View File

@ -11,18 +11,24 @@ from typing import Generator
from typing import List from typing import List
from typing import Mapping from typing import Mapping
from typing import Optional from typing import Optional
from typing import Tuple
from typing import TypeVar
from typing import Union from typing import Union
import pytest import pytest
from _pytest import nodes from _pytest import nodes
from _pytest._io import TerminalWriter
from _pytest.capture import CaptureManager
from _pytest.compat import nullcontext from _pytest.compat import nullcontext
from _pytest.config import _strtobool from _pytest.config import _strtobool
from _pytest.config import Config from _pytest.config import Config
from _pytest.config import create_terminal_writer from _pytest.config import create_terminal_writer
from _pytest.config.argparsing import Parser from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureRequest
from _pytest.main import Session from _pytest.main import Session
from _pytest.pathlib import Path from _pytest.pathlib import Path
from _pytest.store import StoreKey from _pytest.store import StoreKey
from _pytest.terminal import TerminalReporter
DEFAULT_LOG_FORMAT = "%(levelname)-8s %(name)s:%(filename)s:%(lineno)d %(message)s" DEFAULT_LOG_FORMAT = "%(levelname)-8s %(name)s:%(filename)s:%(lineno)d %(message)s"
@ -32,7 +38,7 @@ catch_log_handler_key = StoreKey["LogCaptureHandler"]()
catch_log_records_key = StoreKey[Dict[str, List[logging.LogRecord]]]() catch_log_records_key = StoreKey[Dict[str, List[logging.LogRecord]]]()
def _remove_ansi_escape_sequences(text): def _remove_ansi_escape_sequences(text: str) -> str:
return _ANSI_ESCAPE_SEQ.sub("", text) return _ANSI_ESCAPE_SEQ.sub("", text)
@ -52,7 +58,7 @@ class ColoredLevelFormatter(logging.Formatter):
} # type: Mapping[int, AbstractSet[str]] } # type: Mapping[int, AbstractSet[str]]
LEVELNAME_FMT_REGEX = re.compile(r"%\(levelname\)([+-.]?\d*s)") LEVELNAME_FMT_REGEX = re.compile(r"%\(levelname\)([+-.]?\d*s)")
def __init__(self, terminalwriter, *args, **kwargs) -> None: def __init__(self, terminalwriter: TerminalWriter, *args, **kwargs) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._original_fmt = self._style._fmt self._original_fmt = self._style._fmt
self._level_to_fmt_mapping = {} # type: Dict[int, str] self._level_to_fmt_mapping = {} # type: Dict[int, str]
@ -77,7 +83,7 @@ class ColoredLevelFormatter(logging.Formatter):
colorized_formatted_levelname, self._fmt colorized_formatted_levelname, self._fmt
) )
def format(self, record): def format(self, record: logging.LogRecord) -> str:
fmt = self._level_to_fmt_mapping.get(record.levelno, self._original_fmt) fmt = self._level_to_fmt_mapping.get(record.levelno, self._original_fmt)
self._style._fmt = fmt self._style._fmt = fmt
return super().format(record) return super().format(record)
@ -90,18 +96,20 @@ class PercentStyleMultiline(logging.PercentStyle):
formats the message as if each line were logged separately. formats the message as if each line were logged separately.
""" """
def __init__(self, fmt, auto_indent): def __init__(self, fmt: str, auto_indent: Union[int, str, bool]) -> None:
super().__init__(fmt) super().__init__(fmt)
self._auto_indent = self._get_auto_indent(auto_indent) self._auto_indent = self._get_auto_indent(auto_indent)
@staticmethod @staticmethod
def _update_message(record_dict, message): def _update_message(
record_dict: Dict[str, object], message: str
) -> Dict[str, object]:
tmp = record_dict.copy() tmp = record_dict.copy()
tmp["message"] = message tmp["message"] = message
return tmp return tmp
@staticmethod @staticmethod
def _get_auto_indent(auto_indent_option) -> int: def _get_auto_indent(auto_indent_option: Union[int, str, bool]) -> int:
"""Determines the current auto indentation setting """Determines the current auto indentation setting
Specify auto indent behavior (on/off/fixed) by passing in Specify auto indent behavior (on/off/fixed) by passing in
@ -149,11 +157,11 @@ class PercentStyleMultiline(logging.PercentStyle):
return 0 return 0
def format(self, record): def format(self, record: logging.LogRecord) -> str:
if "\n" in record.message: if "\n" in record.message:
if hasattr(record, "auto_indent"): if hasattr(record, "auto_indent"):
# passed in from the "extra={}" kwarg on the call to logging.log() # passed in from the "extra={}" kwarg on the call to logging.log()
auto_indent = self._get_auto_indent(record.auto_indent) auto_indent = self._get_auto_indent(record.auto_indent) # type: ignore[attr-defined] # noqa: F821
else: else:
auto_indent = self._auto_indent auto_indent = self._auto_indent
@ -173,7 +181,7 @@ class PercentStyleMultiline(logging.PercentStyle):
return self._fmt % record.__dict__ return self._fmt % record.__dict__
def get_option_ini(config, *names): def get_option_ini(config: Config, *names: str):
for name in names: for name in names:
ret = config.getoption(name) # 'default' arg won't work as expected ret = config.getoption(name) # 'default' arg won't work as expected
if ret is None: if ret is None:
@ -268,13 +276,16 @@ def pytest_addoption(parser: Parser) -> None:
) )
_HandlerType = TypeVar("_HandlerType", bound=logging.Handler)
# Not using @contextmanager for performance reasons. # Not using @contextmanager for performance reasons.
class catching_logs: class catching_logs:
"""Context manager that prepares the whole logging machinery properly.""" """Context manager that prepares the whole logging machinery properly."""
__slots__ = ("handler", "level", "orig_level") __slots__ = ("handler", "level", "orig_level")
def __init__(self, handler, level=None): def __init__(self, handler: _HandlerType, level: Optional[int] = None) -> None:
self.handler = handler self.handler = handler
self.level = level self.level = level
@ -330,7 +341,7 @@ class LogCaptureFixture:
"""Creates a new funcarg.""" """Creates a new funcarg."""
self._item = item self._item = item
# dict of log name -> log level # dict of log name -> log level
self._initial_log_levels = {} # type: Dict[str, int] self._initial_log_levels = {} # type: Dict[Optional[str], int]
def _finalize(self) -> None: def _finalize(self) -> None:
"""Finalizes the fixture. """Finalizes the fixture.
@ -364,17 +375,17 @@ class LogCaptureFixture:
return self._item._store[catch_log_records_key].get(when, []) return self._item._store[catch_log_records_key].get(when, [])
@property @property
def text(self): def text(self) -> str:
"""Returns the formatted log text.""" """Returns the formatted log text."""
return _remove_ansi_escape_sequences(self.handler.stream.getvalue()) return _remove_ansi_escape_sequences(self.handler.stream.getvalue())
@property @property
def records(self): def records(self) -> List[logging.LogRecord]:
"""Returns the list of log records.""" """Returns the list of log records."""
return self.handler.records return self.handler.records
@property @property
def record_tuples(self): def record_tuples(self) -> List[Tuple[str, int, str]]:
"""Returns a list of a stripped down version of log records intended """Returns a list of a stripped down version of log records intended
for use in assertion comparison. for use in assertion comparison.
@ -385,7 +396,7 @@ class LogCaptureFixture:
return [(r.name, r.levelno, r.getMessage()) for r in self.records] return [(r.name, r.levelno, r.getMessage()) for r in self.records]
@property @property
def messages(self): def messages(self) -> List[str]:
"""Returns a list of format-interpolated log messages. """Returns a list of format-interpolated log messages.
Unlike 'records', which contains the format string and parameters for interpolation, log messages in this list Unlike 'records', which contains the format string and parameters for interpolation, log messages in this list
@ -400,11 +411,11 @@ class LogCaptureFixture:
""" """
return [r.getMessage() for r in self.records] return [r.getMessage() for r in self.records]
def clear(self): def clear(self) -> None:
"""Reset the list of log records and the captured log text.""" """Reset the list of log records and the captured log text."""
self.handler.reset() self.handler.reset()
def set_level(self, level, logger=None): def set_level(self, level: Union[int, str], logger: Optional[str] = None) -> None:
"""Sets the level for capturing of logs. The level will be restored to its previous value at the end of """Sets the level for capturing of logs. The level will be restored to its previous value at the end of
the test. the test.
@ -415,31 +426,32 @@ class LogCaptureFixture:
The levels of the loggers changed by this function will be restored to their initial values at the The levels of the loggers changed by this function will be restored to their initial values at the
end of the test. end of the test.
""" """
logger_name = logger logger_obj = logging.getLogger(logger)
logger = logging.getLogger(logger_name)
# save the original log-level to restore it during teardown # save the original log-level to restore it during teardown
self._initial_log_levels.setdefault(logger_name, logger.level) self._initial_log_levels.setdefault(logger, logger_obj.level)
logger.setLevel(level) logger_obj.setLevel(level)
@contextmanager @contextmanager
def at_level(self, level, logger=None): def at_level(
self, level: int, logger: Optional[str] = None
) -> Generator[None, None, None]:
"""Context manager that sets the level for capturing of logs. After the end of the 'with' statement the """Context manager that sets the level for capturing of logs. After the end of the 'with' statement the
level is restored to its original value. level is restored to its original value.
:param int level: the logger to level. :param int level: the logger to level.
:param str logger: the logger to update the level. If not given, the root logger level is updated. :param str logger: the logger to update the level. If not given, the root logger level is updated.
""" """
logger = logging.getLogger(logger) logger_obj = logging.getLogger(logger)
orig_level = logger.level orig_level = logger_obj.level
logger.setLevel(level) logger_obj.setLevel(level)
try: try:
yield yield
finally: finally:
logger.setLevel(orig_level) logger_obj.setLevel(orig_level)
@pytest.fixture @pytest.fixture
def caplog(request): def caplog(request: FixtureRequest) -> Generator[LogCaptureFixture, None, None]:
"""Access and control log capturing. """Access and control log capturing.
Captured logs are available through the following properties/methods:: Captured logs are available through the following properties/methods::
@ -557,7 +569,7 @@ class LoggingPlugin:
return formatter return formatter
def set_log_path(self, fname): def set_log_path(self, fname: str) -> None:
"""Public method, which can set filename parameter for """Public method, which can set filename parameter for
Logging.FileHandler(). Also creates parent directory if Logging.FileHandler(). Also creates parent directory if
it does not exist. it does not exist.
@ -565,15 +577,15 @@ class LoggingPlugin:
.. warning:: .. warning::
Please considered as an experimental API. Please considered as an experimental API.
""" """
fname = Path(fname) fpath = Path(fname)
if not fname.is_absolute(): if not fpath.is_absolute():
fname = Path(self._config.rootdir, fname) fpath = Path(self._config.rootdir, fpath)
if not fname.parent.exists(): if not fpath.parent.exists():
fname.parent.mkdir(exist_ok=True, parents=True) fpath.parent.mkdir(exist_ok=True, parents=True)
stream = fname.open(mode="w", encoding="UTF-8") stream = fpath.open(mode="w", encoding="UTF-8")
if sys.version_info >= (3, 7): if sys.version_info >= (3, 7):
old_stream = self.log_file_handler.setStream(stream) old_stream = self.log_file_handler.setStream(stream)
else: else:
@ -715,29 +727,35 @@ class _LiveLoggingStreamHandler(logging.StreamHandler):
and won't appear in the terminal. and won't appear in the terminal.
""" """
def __init__(self, terminal_reporter, capture_manager): # Officially stream needs to be a IO[str], but TerminalReporter
# isn't. So force it.
stream = None # type: TerminalReporter # type: ignore
def __init__(
self, terminal_reporter: TerminalReporter, capture_manager: CaptureManager
) -> None:
""" """
:param _pytest.terminal.TerminalReporter terminal_reporter: :param _pytest.terminal.TerminalReporter terminal_reporter:
:param _pytest.capture.CaptureManager capture_manager: :param _pytest.capture.CaptureManager capture_manager:
""" """
logging.StreamHandler.__init__(self, stream=terminal_reporter) logging.StreamHandler.__init__(self, stream=terminal_reporter) # type: ignore[arg-type] # noqa: F821
self.capture_manager = capture_manager self.capture_manager = capture_manager
self.reset() self.reset()
self.set_when(None) self.set_when(None)
self._test_outcome_written = False self._test_outcome_written = False
def reset(self): def reset(self) -> None:
"""Reset the handler; should be called before the start of each test""" """Reset the handler; should be called before the start of each test"""
self._first_record_emitted = False self._first_record_emitted = False
def set_when(self, when): def set_when(self, when: Optional[str]) -> None:
"""Prepares for the given test phase (setup/call/teardown)""" """Prepares for the given test phase (setup/call/teardown)"""
self._when = when self._when = when
self._section_name_shown = False self._section_name_shown = False
if when == "start": if when == "start":
self._test_outcome_written = False self._test_outcome_written = False
def emit(self, record): def emit(self, record: logging.LogRecord) -> None:
ctx_manager = ( ctx_manager = (
self.capture_manager.global_and_fixture_disabled() self.capture_manager.global_and_fixture_disabled()
if self.capture_manager if self.capture_manager
@ -764,10 +782,10 @@ class _LiveLoggingStreamHandler(logging.StreamHandler):
class _LiveLoggingNullHandler(logging.NullHandler): class _LiveLoggingNullHandler(logging.NullHandler):
"""A handler used when live logging is disabled.""" """A handler used when live logging is disabled."""
def reset(self): def reset(self) -> None:
pass pass
def set_when(self, when): def set_when(self, when: str) -> None:
pass pass
def handleError(self, record: logging.LogRecord) -> None: def handleError(self, record: logging.LogRecord) -> None: