Type annotate _pytest.logging
This commit is contained in:
parent
90e58f8961
commit
db52928684
|
@ -11,18 +11,24 @@ from typing import Generator
|
|||
from typing import List
|
||||
from typing import Mapping
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import TypeVar
|
||||
from typing import Union
|
||||
|
||||
import pytest
|
||||
from _pytest import nodes
|
||||
from _pytest._io import TerminalWriter
|
||||
from _pytest.capture import CaptureManager
|
||||
from _pytest.compat import nullcontext
|
||||
from _pytest.config import _strtobool
|
||||
from _pytest.config import Config
|
||||
from _pytest.config import create_terminal_writer
|
||||
from _pytest.config.argparsing import Parser
|
||||
from _pytest.fixtures import FixtureRequest
|
||||
from _pytest.main import Session
|
||||
from _pytest.pathlib import Path
|
||||
from _pytest.store import StoreKey
|
||||
from _pytest.terminal import TerminalReporter
|
||||
|
||||
|
||||
DEFAULT_LOG_FORMAT = "%(levelname)-8s %(name)s:%(filename)s:%(lineno)d %(message)s"
|
||||
|
@ -32,7 +38,7 @@ catch_log_handler_key = StoreKey["LogCaptureHandler"]()
|
|||
catch_log_records_key = StoreKey[Dict[str, List[logging.LogRecord]]]()
|
||||
|
||||
|
||||
def _remove_ansi_escape_sequences(text):
|
||||
def _remove_ansi_escape_sequences(text: str) -> str:
|
||||
return _ANSI_ESCAPE_SEQ.sub("", text)
|
||||
|
||||
|
||||
|
@ -52,7 +58,7 @@ class ColoredLevelFormatter(logging.Formatter):
|
|||
} # type: Mapping[int, AbstractSet[str]]
|
||||
LEVELNAME_FMT_REGEX = re.compile(r"%\(levelname\)([+-.]?\d*s)")
|
||||
|
||||
def __init__(self, terminalwriter, *args, **kwargs) -> None:
|
||||
def __init__(self, terminalwriter: TerminalWriter, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self._original_fmt = self._style._fmt
|
||||
self._level_to_fmt_mapping = {} # type: Dict[int, str]
|
||||
|
@ -77,7 +83,7 @@ class ColoredLevelFormatter(logging.Formatter):
|
|||
colorized_formatted_levelname, self._fmt
|
||||
)
|
||||
|
||||
def format(self, record):
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
fmt = self._level_to_fmt_mapping.get(record.levelno, self._original_fmt)
|
||||
self._style._fmt = fmt
|
||||
return super().format(record)
|
||||
|
@ -90,18 +96,20 @@ class PercentStyleMultiline(logging.PercentStyle):
|
|||
formats the message as if each line were logged separately.
|
||||
"""
|
||||
|
||||
def __init__(self, fmt, auto_indent):
|
||||
def __init__(self, fmt: str, auto_indent: Union[int, str, bool]) -> None:
|
||||
super().__init__(fmt)
|
||||
self._auto_indent = self._get_auto_indent(auto_indent)
|
||||
|
||||
@staticmethod
|
||||
def _update_message(record_dict, message):
|
||||
def _update_message(
|
||||
record_dict: Dict[str, object], message: str
|
||||
) -> Dict[str, object]:
|
||||
tmp = record_dict.copy()
|
||||
tmp["message"] = message
|
||||
return tmp
|
||||
|
||||
@staticmethod
|
||||
def _get_auto_indent(auto_indent_option) -> int:
|
||||
def _get_auto_indent(auto_indent_option: Union[int, str, bool]) -> int:
|
||||
"""Determines the current auto indentation setting
|
||||
|
||||
Specify auto indent behavior (on/off/fixed) by passing in
|
||||
|
@ -149,11 +157,11 @@ class PercentStyleMultiline(logging.PercentStyle):
|
|||
|
||||
return 0
|
||||
|
||||
def format(self, record):
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
if "\n" in record.message:
|
||||
if hasattr(record, "auto_indent"):
|
||||
# passed in from the "extra={}" kwarg on the call to logging.log()
|
||||
auto_indent = self._get_auto_indent(record.auto_indent)
|
||||
auto_indent = self._get_auto_indent(record.auto_indent) # type: ignore[attr-defined] # noqa: F821
|
||||
else:
|
||||
auto_indent = self._auto_indent
|
||||
|
||||
|
@ -173,7 +181,7 @@ class PercentStyleMultiline(logging.PercentStyle):
|
|||
return self._fmt % record.__dict__
|
||||
|
||||
|
||||
def get_option_ini(config, *names):
|
||||
def get_option_ini(config: Config, *names: str):
|
||||
for name in names:
|
||||
ret = config.getoption(name) # 'default' arg won't work as expected
|
||||
if ret is None:
|
||||
|
@ -268,13 +276,16 @@ def pytest_addoption(parser: Parser) -> None:
|
|||
)
|
||||
|
||||
|
||||
_HandlerType = TypeVar("_HandlerType", bound=logging.Handler)
|
||||
|
||||
|
||||
# Not using @contextmanager for performance reasons.
|
||||
class catching_logs:
|
||||
"""Context manager that prepares the whole logging machinery properly."""
|
||||
|
||||
__slots__ = ("handler", "level", "orig_level")
|
||||
|
||||
def __init__(self, handler, level=None):
|
||||
def __init__(self, handler: _HandlerType, level: Optional[int] = None) -> None:
|
||||
self.handler = handler
|
||||
self.level = level
|
||||
|
||||
|
@ -330,7 +341,7 @@ class LogCaptureFixture:
|
|||
"""Creates a new funcarg."""
|
||||
self._item = item
|
||||
# dict of log name -> log level
|
||||
self._initial_log_levels = {} # type: Dict[str, int]
|
||||
self._initial_log_levels = {} # type: Dict[Optional[str], int]
|
||||
|
||||
def _finalize(self) -> None:
|
||||
"""Finalizes the fixture.
|
||||
|
@ -364,17 +375,17 @@ class LogCaptureFixture:
|
|||
return self._item._store[catch_log_records_key].get(when, [])
|
||||
|
||||
@property
|
||||
def text(self):
|
||||
def text(self) -> str:
|
||||
"""Returns the formatted log text."""
|
||||
return _remove_ansi_escape_sequences(self.handler.stream.getvalue())
|
||||
|
||||
@property
|
||||
def records(self):
|
||||
def records(self) -> List[logging.LogRecord]:
|
||||
"""Returns the list of log records."""
|
||||
return self.handler.records
|
||||
|
||||
@property
|
||||
def record_tuples(self):
|
||||
def record_tuples(self) -> List[Tuple[str, int, str]]:
|
||||
"""Returns a list of a stripped down version of log records intended
|
||||
for use in assertion comparison.
|
||||
|
||||
|
@ -385,7 +396,7 @@ class LogCaptureFixture:
|
|||
return [(r.name, r.levelno, r.getMessage()) for r in self.records]
|
||||
|
||||
@property
|
||||
def messages(self):
|
||||
def messages(self) -> List[str]:
|
||||
"""Returns a list of format-interpolated log messages.
|
||||
|
||||
Unlike 'records', which contains the format string and parameters for interpolation, log messages in this list
|
||||
|
@ -400,11 +411,11 @@ class LogCaptureFixture:
|
|||
"""
|
||||
return [r.getMessage() for r in self.records]
|
||||
|
||||
def clear(self):
|
||||
def clear(self) -> None:
|
||||
"""Reset the list of log records and the captured log text."""
|
||||
self.handler.reset()
|
||||
|
||||
def set_level(self, level, logger=None):
|
||||
def set_level(self, level: Union[int, str], logger: Optional[str] = None) -> None:
|
||||
"""Sets the level for capturing of logs. The level will be restored to its previous value at the end of
|
||||
the test.
|
||||
|
||||
|
@ -415,31 +426,32 @@ class LogCaptureFixture:
|
|||
The levels of the loggers changed by this function will be restored to their initial values at the
|
||||
end of the test.
|
||||
"""
|
||||
logger_name = logger
|
||||
logger = logging.getLogger(logger_name)
|
||||
logger_obj = logging.getLogger(logger)
|
||||
# save the original log-level to restore it during teardown
|
||||
self._initial_log_levels.setdefault(logger_name, logger.level)
|
||||
logger.setLevel(level)
|
||||
self._initial_log_levels.setdefault(logger, logger_obj.level)
|
||||
logger_obj.setLevel(level)
|
||||
|
||||
@contextmanager
|
||||
def at_level(self, level, logger=None):
|
||||
def at_level(
|
||||
self, level: int, logger: Optional[str] = None
|
||||
) -> Generator[None, None, None]:
|
||||
"""Context manager that sets the level for capturing of logs. After the end of the 'with' statement the
|
||||
level is restored to its original value.
|
||||
|
||||
:param int level: the logger to level.
|
||||
:param str logger: the logger to update the level. If not given, the root logger level is updated.
|
||||
"""
|
||||
logger = logging.getLogger(logger)
|
||||
orig_level = logger.level
|
||||
logger.setLevel(level)
|
||||
logger_obj = logging.getLogger(logger)
|
||||
orig_level = logger_obj.level
|
||||
logger_obj.setLevel(level)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
logger.setLevel(orig_level)
|
||||
logger_obj.setLevel(orig_level)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def caplog(request):
|
||||
def caplog(request: FixtureRequest) -> Generator[LogCaptureFixture, None, None]:
|
||||
"""Access and control log capturing.
|
||||
|
||||
Captured logs are available through the following properties/methods::
|
||||
|
@ -557,7 +569,7 @@ class LoggingPlugin:
|
|||
|
||||
return formatter
|
||||
|
||||
def set_log_path(self, fname):
|
||||
def set_log_path(self, fname: str) -> None:
|
||||
"""Public method, which can set filename parameter for
|
||||
Logging.FileHandler(). Also creates parent directory if
|
||||
it does not exist.
|
||||
|
@ -565,15 +577,15 @@ class LoggingPlugin:
|
|||
.. warning::
|
||||
Please considered as an experimental API.
|
||||
"""
|
||||
fname = Path(fname)
|
||||
fpath = Path(fname)
|
||||
|
||||
if not fname.is_absolute():
|
||||
fname = Path(self._config.rootdir, fname)
|
||||
if not fpath.is_absolute():
|
||||
fpath = Path(self._config.rootdir, fpath)
|
||||
|
||||
if not fname.parent.exists():
|
||||
fname.parent.mkdir(exist_ok=True, parents=True)
|
||||
if not fpath.parent.exists():
|
||||
fpath.parent.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
stream = fname.open(mode="w", encoding="UTF-8")
|
||||
stream = fpath.open(mode="w", encoding="UTF-8")
|
||||
if sys.version_info >= (3, 7):
|
||||
old_stream = self.log_file_handler.setStream(stream)
|
||||
else:
|
||||
|
@ -715,29 +727,35 @@ class _LiveLoggingStreamHandler(logging.StreamHandler):
|
|||
and won't appear in the terminal.
|
||||
"""
|
||||
|
||||
def __init__(self, terminal_reporter, capture_manager):
|
||||
# Officially stream needs to be a IO[str], but TerminalReporter
|
||||
# isn't. So force it.
|
||||
stream = None # type: TerminalReporter # type: ignore
|
||||
|
||||
def __init__(
|
||||
self, terminal_reporter: TerminalReporter, capture_manager: CaptureManager
|
||||
) -> None:
|
||||
"""
|
||||
:param _pytest.terminal.TerminalReporter terminal_reporter:
|
||||
:param _pytest.capture.CaptureManager capture_manager:
|
||||
"""
|
||||
logging.StreamHandler.__init__(self, stream=terminal_reporter)
|
||||
logging.StreamHandler.__init__(self, stream=terminal_reporter) # type: ignore[arg-type] # noqa: F821
|
||||
self.capture_manager = capture_manager
|
||||
self.reset()
|
||||
self.set_when(None)
|
||||
self._test_outcome_written = False
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
"""Reset the handler; should be called before the start of each test"""
|
||||
self._first_record_emitted = False
|
||||
|
||||
def set_when(self, when):
|
||||
def set_when(self, when: Optional[str]) -> None:
|
||||
"""Prepares for the given test phase (setup/call/teardown)"""
|
||||
self._when = when
|
||||
self._section_name_shown = False
|
||||
if when == "start":
|
||||
self._test_outcome_written = False
|
||||
|
||||
def emit(self, record):
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
ctx_manager = (
|
||||
self.capture_manager.global_and_fixture_disabled()
|
||||
if self.capture_manager
|
||||
|
@ -764,10 +782,10 @@ class _LiveLoggingStreamHandler(logging.StreamHandler):
|
|||
class _LiveLoggingNullHandler(logging.NullHandler):
|
||||
"""A handler used when live logging is disabled."""
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def set_when(self, when):
|
||||
def set_when(self, when: str) -> None:
|
||||
pass
|
||||
|
||||
def handleError(self, record: logging.LogRecord) -> None:
|
||||
|
|
Loading…
Reference in New Issue