From 3f8f3952107998f03a7fd1826427a9262a267f6c Mon Sep 17 00:00:00 2001 From: Daniel Hahler Date: Sat, 25 Jan 2020 18:11:38 +0100 Subject: [PATCH] typing: EncodedFile --- src/_pytest/capture.py | 15 +++++++-------- testing/test_capture.py | 3 ++- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/_pytest/capture.py b/src/_pytest/capture.py index e51fe2b67..33d2243b3 100644 --- a/src/_pytest/capture.py +++ b/src/_pytest/capture.py @@ -9,6 +9,7 @@ import os import sys from io import UnsupportedOperation from tempfile import TemporaryFile +from typing import BinaryIO from typing import List import pytest @@ -414,29 +415,27 @@ def safe_text_dupfile(f, mode, default_encoding="UTF8"): class EncodedFile: errors = "strict" # possibly needed by py3 code (issue555) - def __init__(self, buffer, encoding): + def __init__(self, buffer: BinaryIO, encoding: str) -> None: self.buffer = buffer self.encoding = encoding - def write(self, obj): - if isinstance(obj, str): - obj = obj.encode(self.encoding, "replace") - else: + def write(self, obj: str) -> int: + if not isinstance(obj, str): raise TypeError( "write() argument must be str, not {}".format(type(obj).__name__) ) - return self.buffer.write(obj) + return self.buffer.write(obj.encode(self.encoding, "replace")) def writelines(self, linelist: List[str]) -> None: self.buffer.writelines([x.encode(self.encoding, "replace") for x in linelist]) @property - def name(self): + def name(self) -> str: """Ensure that file.name is a string.""" return repr(self.buffer) @property - def mode(self): + def mode(self) -> str: return self.buffer.mode.replace("b", "") def __getattr__(self, name): diff --git a/testing/test_capture.py b/testing/test_capture.py index ebe30703b..e6862f313 100644 --- a/testing/test_capture.py +++ b/testing/test_capture.py @@ -7,6 +7,7 @@ import sys import textwrap from io import StringIO from io import UnsupportedOperation +from typing import BinaryIO from typing import List from typing import TextIO @@ -1499,7 +1500,7 @@ def test_stderr_write_returns_len(capsys): assert sys.stderr.write("Foo") == 3 -def test_encodedfile_writelines(tmpfile) -> None: +def test_encodedfile_writelines(tmpfile: BinaryIO) -> None: ef = capture.EncodedFile(tmpfile, "utf-8") with pytest.raises(AttributeError): ef.writelines([b"line1", b"line2"]) # type: ignore[list-item] # noqa: F821