Merge pull request #8014 from bluetech/pyc-pep552

assertion/rewrite: write pyc's according to PEP-552 on Python>=3.7
This commit is contained in:
Ran Benita 2020-11-14 23:38:45 +02:00 committed by GitHub
commit 825b81ba52
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 77 additions and 10 deletions

View File

@ -0,0 +1,2 @@
`.pyc` files created by pytest's assertion rewriting now conform to the newer PEP-552 format on Python>=3.7.
(These files are internal and only interpreted by pytest itself.)

View File

@ -281,12 +281,16 @@ def _write_pyc_fp(
) -> None: ) -> None:
# Technically, we don't have to have the same pyc format as # Technically, we don't have to have the same pyc format as
# (C)Python, since these "pycs" should never be seen by builtin # (C)Python, since these "pycs" should never be seen by builtin
# import. However, there's little reason deviate. # import. However, there's little reason to deviate.
fp.write(importlib.util.MAGIC_NUMBER) fp.write(importlib.util.MAGIC_NUMBER)
# https://www.python.org/dev/peps/pep-0552/
if sys.version_info >= (3, 7):
flags = b"\x00\x00\x00\x00"
fp.write(flags)
# as of now, bytecode header expects 32-bit numbers for size and mtime (#4903) # as of now, bytecode header expects 32-bit numbers for size and mtime (#4903)
mtime = int(source_stat.st_mtime) & 0xFFFFFFFF mtime = int(source_stat.st_mtime) & 0xFFFFFFFF
size = source_stat.st_size & 0xFFFFFFFF size = source_stat.st_size & 0xFFFFFFFF
# "<LL" stands for 2 unsigned longs, little-ending # "<LL" stands for 2 unsigned longs, little-endian.
fp.write(struct.pack("<LL", mtime, size)) fp.write(struct.pack("<LL", mtime, size))
fp.write(marshal.dumps(co)) fp.write(marshal.dumps(co))
@ -365,21 +369,33 @@ def _read_pyc(
except OSError: except OSError:
return None return None
with fp: with fp:
# https://www.python.org/dev/peps/pep-0552/
has_flags = sys.version_info >= (3, 7)
try: try:
stat_result = os.stat(os.fspath(source)) stat_result = os.stat(os.fspath(source))
mtime = int(stat_result.st_mtime) mtime = int(stat_result.st_mtime)
size = stat_result.st_size size = stat_result.st_size
data = fp.read(12) data = fp.read(16 if has_flags else 12)
except OSError as e: except OSError as e:
trace(f"_read_pyc({source}): OSError {e}") trace(f"_read_pyc({source}): OSError {e}")
return None return None
# Check for invalid or out of date pyc file. # Check for invalid or out of date pyc file.
if ( if len(data) != (16 if has_flags else 12):
len(data) != 12 trace("_read_pyc(%s): invalid pyc (too short)" % source)
or data[:4] != importlib.util.MAGIC_NUMBER return None
or struct.unpack("<LL", data[4:]) != (mtime & 0xFFFFFFFF, size & 0xFFFFFFFF) if data[:4] != importlib.util.MAGIC_NUMBER:
): trace("_read_pyc(%s): invalid pyc (bad magic number)" % source)
trace("_read_pyc(%s): invalid or out of date pyc" % source) return None
if has_flags and data[4:8] != b"\x00\x00\x00\x00":
trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source)
return None
mtime_data = data[8 if has_flags else 4 : 12 if has_flags else 8]
if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF:
trace("_read_pyc(%s): out of date" % source)
return None
size_data = data[12 if has_flags else 8 : 16 if has_flags else 12]
if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF:
trace("_read_pyc(%s): invalid pyc (incorrect size)" % source)
return None return None
try: try:
co = marshal.load(fp) co = marshal.load(fp)

View File

@ -2,6 +2,7 @@ import ast
import errno import errno
import glob import glob
import importlib import importlib
import marshal
import os import os
import py_compile import py_compile
import stat import stat
@ -1063,12 +1064,60 @@ class TestAssertionRewriteHookDetails:
py_compile.compile(str(source), str(pyc)) py_compile.compile(str(source), str(pyc))
contents = pyc.read_bytes() contents = pyc.read_bytes()
strip_bytes = 20 # header is around 8 bytes, strip a little more strip_bytes = 20 # header is around 16 bytes, strip a little more
assert len(contents) > strip_bytes assert len(contents) > strip_bytes
pyc.write_bytes(contents[:strip_bytes]) pyc.write_bytes(contents[:strip_bytes])
assert _read_pyc(source, pyc) is None # no error assert _read_pyc(source, pyc) is None # no error
@pytest.mark.skipif(
sys.version_info < (3, 7), reason="Only the Python 3.7 format for simplicity"
)
def test_read_pyc_more_invalid(self, tmp_path: Path) -> None:
from _pytest.assertion.rewrite import _read_pyc
source = tmp_path / "source.py"
pyc = tmp_path / "source.pyc"
source_bytes = b"def test(): pass\n"
source.write_bytes(source_bytes)
magic = importlib.util.MAGIC_NUMBER
flags = b"\x00\x00\x00\x00"
mtime = b"\x58\x3c\xb0\x5f"
mtime_int = int.from_bytes(mtime, "little")
os.utime(source, (mtime_int, mtime_int))
size = len(source_bytes).to_bytes(4, "little")
code = marshal.dumps(compile(source_bytes, str(source), "exec"))
# Good header.
pyc.write_bytes(magic + flags + mtime + size + code)
assert _read_pyc(source, pyc, print) is not None
# Too short.
pyc.write_bytes(magic + flags + mtime)
assert _read_pyc(source, pyc, print) is None
# Bad magic.
pyc.write_bytes(b"\x12\x34\x56\x78" + flags + mtime + size + code)
assert _read_pyc(source, pyc, print) is None
# Unsupported flags.
pyc.write_bytes(magic + b"\x00\xff\x00\x00" + mtime + size + code)
assert _read_pyc(source, pyc, print) is None
# Bad mtime.
pyc.write_bytes(magic + flags + b"\x58\x3d\xb0\x5f" + size + code)
assert _read_pyc(source, pyc, print) is None
# Bad size.
pyc.write_bytes(magic + flags + mtime + b"\x99\x00\x00\x00" + code)
assert _read_pyc(source, pyc, print) is None
def test_reload_is_same_and_reloads(self, pytester: Pytester) -> None: def test_reload_is_same_and_reloads(self, pytester: Pytester) -> None:
"""Reloading a (collected) module after change picks up the change.""" """Reloading a (collected) module after change picks up the change."""
pytester.makeini( pytester.makeini(