typing: EncodedFile

This commit is contained in:
Daniel Hahler 2020-01-25 18:11:38 +01:00
parent 039d582b52
commit 3f8f395210
2 changed files with 9 additions and 9 deletions

View File

@ -9,6 +9,7 @@ import os
import sys import sys
from io import UnsupportedOperation from io import UnsupportedOperation
from tempfile import TemporaryFile from tempfile import TemporaryFile
from typing import BinaryIO
from typing import List from typing import List
import pytest import pytest
@ -414,29 +415,27 @@ def safe_text_dupfile(f, mode, default_encoding="UTF8"):
class EncodedFile: class EncodedFile:
errors = "strict" # possibly needed by py3 code (issue555) errors = "strict" # possibly needed by py3 code (issue555)
def __init__(self, buffer, encoding): def __init__(self, buffer: BinaryIO, encoding: str) -> None:
self.buffer = buffer self.buffer = buffer
self.encoding = encoding self.encoding = encoding
def write(self, obj): def write(self, obj: str) -> int:
if isinstance(obj, str): if not isinstance(obj, str):
obj = obj.encode(self.encoding, "replace")
else:
raise TypeError( raise TypeError(
"write() argument must be str, not {}".format(type(obj).__name__) "write() argument must be str, not {}".format(type(obj).__name__)
) )
return self.buffer.write(obj) return self.buffer.write(obj.encode(self.encoding, "replace"))
def writelines(self, linelist: List[str]) -> None: def writelines(self, linelist: List[str]) -> None:
self.buffer.writelines([x.encode(self.encoding, "replace") for x in linelist]) self.buffer.writelines([x.encode(self.encoding, "replace") for x in linelist])
@property @property
def name(self): def name(self) -> str:
"""Ensure that file.name is a string.""" """Ensure that file.name is a string."""
return repr(self.buffer) return repr(self.buffer)
@property @property
def mode(self): def mode(self) -> str:
return self.buffer.mode.replace("b", "") return self.buffer.mode.replace("b", "")
def __getattr__(self, name): def __getattr__(self, name):

View File

@ -7,6 +7,7 @@ import sys
import textwrap import textwrap
from io import StringIO from io import StringIO
from io import UnsupportedOperation from io import UnsupportedOperation
from typing import BinaryIO
from typing import List from typing import List
from typing import TextIO from typing import TextIO
@ -1499,7 +1500,7 @@ def test_stderr_write_returns_len(capsys):
assert sys.stderr.write("Foo") == 3 assert sys.stderr.write("Foo") == 3
def test_encodedfile_writelines(tmpfile) -> None: def test_encodedfile_writelines(tmpfile: BinaryIO) -> None:
ef = capture.EncodedFile(tmpfile, "utf-8") ef = capture.EncodedFile(tmpfile, "utf-8")
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
ef.writelines([b"line1", b"line2"]) # type: ignore[list-item] # noqa: F821 ef.writelines([b"line1", b"line2"]) # type: ignore[list-item] # noqa: F821