Add a simple (hopefully) cross-python marshaller
Will rewrite the tests soon... --HG-- branch : trunk
This commit is contained in:
parent
b3ca12d435
commit
1e71a5c392
|
@ -0,0 +1,291 @@
|
||||||
|
"""
|
||||||
|
Simple marshal format (based on pickle) designed to work across Python versions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import struct
|
||||||
|
|
||||||
|
import py
|
||||||
|
|
||||||
|
_INPY3 = _REALLY_PY3 = sys.version_info > (3, 0)
|
||||||
|
|
||||||
|
class SerializeError(Exception):
    """Common base class for the serialization and unserialization errors
    defined in this module."""
|
||||||
|
|
||||||
|
class SerializationError(SerializeError):
    """Raised when an object cannot be written out by the serializer."""
|
||||||
|
|
||||||
|
class UnserializableType(SerializationError):
    """Raised when an object's type has no registered serializer."""
|
||||||
|
|
||||||
|
class UnserializationError(SerializeError):
    """Raised when reading an object back from a stream fails."""
|
||||||
|
|
||||||
|
class VersionMismatch(UnserializationError):
    """Raised when the data was written with an earlier or later format
    version than the one this module reads."""
|
||||||
|
|
||||||
|
class Corruption(UnserializationError):
    """Raised when the serialized data does not look well formed."""
|
||||||
|
|
||||||
|
# Compatibility shims: whichever branch runs, the module ends up with the
# same four names -- b, _b, bytes and _unicode.
if not _INPY3:
    class bytes(str):
        """str subclass used as a distinct "bytes" type."""

    # Native strings are already byte strings here, so b() is just str.
    b = str
    _b = bytes
    _unicode = unicode
else:
    def b(s):
        """Return the text string *s* encoded as ASCII bytes."""
        return s.encode("ascii")

    _b = b

    class _unicode(str):
        """str subclass used as a distinct "unicode" type."""

    # Re-export the builtin as a module attribute (e.g. serializer.bytes).
    bytes = bytes
|
||||||
|
|
||||||
|
FOUR_BYTE_INT_MAX = 2147483647
|
||||||
|
|
||||||
|
_int4_format = struct.Struct("!i")
|
||||||
|
|
||||||
|
# Protocol constants
#
# Every stream starts with the one-byte VERSION marker and ends with STOP;
# in between, each one-byte opcode below introduces or combines values.
VERSION_NUMBER = 1
VERSION = b(chr(VERSION_NUMBER))

# String-like payloads.
PY2STRING = b('s')   # native str written while not in Python 3 mode
PY3STRING = b('t')   # native str written in Python 3 mode
UNICODE = b('u')
BYTES = b('b')

# Containers.
NEWLIST = b('l')
BUILDTUPLE = b('T')
SETITEM = b('m')
NEWDICT = b('d')

# Scalars and stream control.
INT = b('i')
STOP = b('S')
|
||||||
|
|
||||||
|
class CrossVersionOptions(object):
    """Empty placeholder class; it defines no attributes or methods."""
|
||||||
|
|
||||||
|
class Serializer(object):
    """Write builtin Python objects to a byte stream.

    save() emits the VERSION byte, the opcodes and payloads describing the
    object, and a final STOP opcode.  Objects are dispatched on their exact
    type through the class-level ``dispatch`` table, which is attached to
    the class by _setup_dispatch(); any other type raises
    UnserializableType.
    """

    def __init__(self, stream):
        """Keep *stream*; only its write() method is used."""
        self.stream = stream

    def save(self, obj):
        """Serialize *obj* as one complete, versioned message."""
        self.stream.write(VERSION)
        self._save(obj)
        self.stream.write(STOP)

    def _save(self, obj):
        """Write *obj* (without version/stop framing) via the dispatch table.

        Raises UnserializableType when type(obj) has no registered saver.
        """
        tp = type(obj)
        try:
            dispatch = self.dispatch[tp]
        except KeyError:
            raise UnserializableType("can't serialize %s" % (tp,))
        dispatch(self, obj)

    def save_bytes(self, bytes_):
        """Write a BYTES opcode followed by the length-prefixed payload."""
        self.stream.write(BYTES)
        self._write_byte_sequence(bytes_)

    def save_unicode(self, s):
        """Write a UNICODE opcode followed by *s* encoded as UTF-8."""
        self.stream.write(UNICODE)
        self._write_unicode_string(s)

    def save_string(self, s):
        """Write a native ``str``.

        In Python 3 mode the string goes out as PY3STRING with a UTF-8
        payload; otherwise it goes out as PY2STRING with the raw bytes.
        """
        if _INPY3:
            self.stream.write(PY3STRING)
            self._write_unicode_string(s)
        else:
            # Case for tests: Python 2 mode simulated on a real Python 3
            # interpreter, where a str must first be turned into bytes.
            if _REALLY_PY3 and isinstance(s, str):
                s = s.encode("latin-1")
            self.stream.write(PY2STRING)
            self._write_byte_sequence(s)

    def _write_unicode_string(self, s):
        """Encode *s* as UTF-8 and write it length-prefixed.

        Raises SerializationError if the string cannot be encoded.
        """
        try:
            as_bytes = s.encode("utf-8")
        except UnicodeEncodeError:
            raise SerializationError("strings must be utf-8 encodable")
        self._write_byte_sequence(as_bytes)

    def _write_byte_sequence(self, bytes_):
        """Write a four-byte length followed by the bytes themselves."""
        self._write_int4(len(bytes_), "string is too long")
        self.stream.write(bytes_)

    def save_int(self, i):
        """Write an INT opcode followed by *i* as a four-byte integer."""
        self.stream.write(INT)
        self._write_int4(i)

    def _write_int4(self, i, error="int must be less than %i" %
                    (FOUR_BYTE_INT_MAX,)):
        """Write *i* as a signed four-byte big-endian integer.

        *error* is the message used when *i* is too large.  Raises
        SerializationError when *i* does not fit in four bytes.
        """
        if i > FOUR_BYTE_INT_MAX:
            raise SerializationError(error)
        # Also reject values below the signed 32-bit minimum; otherwise
        # struct would raise its own struct.error, which callers catching
        # SerializationError would not see.
        smallest = -FOUR_BYTE_INT_MAX - 1
        if i < smallest:
            raise SerializationError("int must be at least %i" % (smallest,))
        self.stream.write(_int4_format.pack(i))

    def save_list(self, L):
        """Write NEWLIST with the list length, then one SETITEM per index."""
        self.stream.write(NEWLIST)
        self._write_int4(len(L), "list is too long")
        for i, item in enumerate(L):
            self._write_setitem(i, item)

    def _write_setitem(self, key, value):
        """Write *key*, then *value*, then the SETITEM opcode."""
        self._save(key)
        self._save(value)
        self.stream.write(SETITEM)

    def save_dict(self, d):
        """Write NEWDICT, then one SETITEM per key/value pair."""
        self.stream.write(NEWDICT)
        for key, value in d.items():
            self._write_setitem(key, value)

    def save_tuple(self, tup):
        """Write every item, then BUILDTUPLE followed by the item count."""
        for item in tup:
            self._save(item)
        self.stream.write(BUILDTUPLE)
        self._write_int4(len(tup), "tuple is too long")
|
||||||
|
|
||||||
|
|
||||||
|
class _UnserializationOptions(object):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class _Py2UnserializationOptions(_UnserializationOptions):
    """Options used when unserializing in Python 2 mode."""

    def __init__(self, py3_strings_as_str=False):
        """Store the flag that Unserializer.load_py3string consults to
        decide whether a PY3STRING payload is kept as raw bytes instead of
        being decoded from UTF-8."""
        self.py3_strings_as_str = py3_strings_as_str
|
||||||
|
|
||||||
|
class _Py3UnserializationOptions(_UnserializationOptions):
    """Options used when unserializing in Python 3 mode."""

    def __init__(self, py2_strings_as_str=False):
        """Store the flag that Unserializer.load_py2string consults to
        decide whether a PY2STRING payload is decoded as latin-1 text
        instead of being returned as raw bytes."""
        self.py2_strings_as_str = py2_strings_as_str
|
||||||
|
|
||||||
|
|
||||||
|
# Savers for the types that are handled the same way in either version
# mode; each is looked up by name as Serializer.save_<typename>.  Building
# the table from a generator leaves no helper names behind in the module.
_unchanging_dispatch = dict(
    (kind, getattr(Serializer, "save_%s" % (kind.__name__,)))
    for kind in (dict, list, tuple, int)
)
|
||||||
|
|
||||||
|
def _setup_dispatch():
    """Build the type -> saver table and attach it as Serializer.dispatch."""
    table = _unchanging_dispatch.copy()
    # This is subtle.  The order below is deliberate: whenever two of these
    # names are bound to the same type, the later entry replaces the
    # earlier one.  In particular `unicode` is bound to `str` in Python 3
    # mode, so the save_unicode entry is overwritten with save_string.
    # NOTE(review): the original note also says `bytes` is aliased to `str`
    # on 2.6 so that dispatch[bytes] is overwritten as well -- confirm
    # against how this module defines `bytes`.
    overrides = (
        (bytes, Serializer.save_bytes),
        (unicode, Serializer.save_unicode),
        (str, Serializer.save_string),
    )
    for kind, saver in overrides:
        table[kind] = saver
    Serializer.dispatch = table
|
||||||
|
|
||||||
|
def _setup_version_dependent_constants(leave_unicode_alone=False):
    """Rebind the module globals that depend on the _INPY3 flag.

    Sets ``unicode`` and ``UnserializationOptions`` for the version mode
    currently selected by _INPY3 and then rebuilds Serializer.dispatch.
    *leave_unicode_alone* is accepted but not used by this function.
    """
    global unicode, UnserializationOptions
    if not _INPY3:
        UnserializationOptions = _Py2UnserializationOptions
        unicode = _unicode
    else:
        unicode = str
        UnserializationOptions = _Py3UnserializationOptions
    _setup_dispatch()


# Configure the module for the interpreter it was imported under.
_setup_version_dependent_constants()
|
||||||
|
|
||||||
|
|
||||||
|
class _Stop(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class Unserializer(object):
    """Rebuild an object from a byte stream produced by Serializer.

    load() works as a small stack machine: every opcode read from the
    stream is looked up in the class-level ``opcodes`` table and the
    matching loader pushes to or combines values on ``self.stack`` until
    the STOP opcode is reached.
    """

    def __init__(self, stream, options=None):
        """Keep *stream* (only read() is used) and the options object.

        When *options* is None a default UnserializationOptions() is made.
        """
        self.stream = stream
        if options is None:
            options = UnserializationOptions()
        self.options = options

    def load(self):
        """Read one complete message from the stream and return the object.

        Raises EOFError if the stream ends early, VersionMismatch for an
        unexpected version byte, Corruption for malformed data and
        UnserializationError if STOP does not leave exactly one value.
        """
        self.stack = []
        version_byte = self.stream.read(1)
        if not version_byte:
            # ord() of an empty read would fail with an unrelated TypeError.
            raise EOFError
        version = ord(version_byte)
        if version != VERSION_NUMBER:
            raise VersionMismatch("%i != %i" % (version, VERSION_NUMBER))
        # The loop below can only be left through an exception; _Stop is
        # the normal way out, raised by load_stop().
        try:
            while True:
                opcode = self.stream.read(1)
                if not opcode:
                    raise EOFError
                try:
                    loader = self.opcodes[opcode]
                except KeyError:
                    raise Corruption("unkown opcode %s" % (opcode,))
                loader(self)
        except _Stop:
            if len(self.stack) != 1:
                raise UnserializationError("internal unserialization error")
            return self.stack[0]

    # Maps each one-byte opcode to the loader function defined below.
    opcodes = {}

    def load_int(self):
        """Push the four-byte integer that follows the INT opcode."""
        i = self._read_int4()
        self.stack.append(i)
    opcodes[INT] = load_int

    def _read_exact(self, count):
        """Read exactly *count* bytes; raise EOFError if fewer are left."""
        data = self.stream.read(count)
        if len(data) != count:
            raise EOFError("expected %i bytes, got %i" % (count, len(data)))
        return data

    def _read_int4(self):
        """Read a signed four-byte big-endian integer."""
        return _int4_format.unpack(self._read_exact(4))[0]

    def _read_byte_string(self):
        """Read a four-byte length and then that many bytes.

        Raises Corruption for a negative length; read() would otherwise
        treat it as "read everything that is left".
        """
        length = self._read_int4()
        if length < 0:
            raise Corruption("negative string length")
        return self._read_exact(length)

    def load_py3string(self):
        """Push a PY3STRING payload, decoded from UTF-8 unless the options
        ask for the raw bytes while genuinely running in Python 2 mode."""
        as_bytes = self._read_byte_string()
        if (not _INPY3 and self.options.py3_strings_as_str) and not _REALLY_PY3:
            # XXX Should we try to decode into latin-1?
            self.stack.append(as_bytes)
        else:
            self.stack.append(as_bytes.decode("utf-8"))
    opcodes[PY3STRING] = load_py3string

    def load_py2string(self):
        """Push a PY2STRING payload, either as raw bytes or decoded as
        latin-1 depending on the version mode and the options."""
        as_bytes = self._read_byte_string()
        if (_INPY3 and self.options.py2_strings_as_str) or \
                (_REALLY_PY3 and not _INPY3):
            s = as_bytes.decode("latin-1")
        else:
            s = as_bytes
        self.stack.append(s)
    opcodes[PY2STRING] = load_py2string

    def load_bytes(self):
        """Push a BYTES payload wrapped in the module's ``bytes`` type."""
        s = bytes(self._read_byte_string())
        self.stack.append(s)
    opcodes[BYTES] = load_bytes

    def load_unicode(self):
        """Push a UNICODE payload decoded from UTF-8."""
        self.stack.append(self._read_byte_string().decode("utf-8"))
    opcodes[UNICODE] = load_unicode

    def load_newlist(self):
        """Push a list of the announced length, pre-filled with None so
        the following SETITEM opcodes can assign by index."""
        length = self._read_int4()
        if length < 0:
            raise Corruption("negative list length")
        self.stack.append([None] * length)
    opcodes[NEWLIST] = load_newlist

    def load_setitem(self):
        """Pop a value and a key and store them in the container that is
        then on top of the stack."""
        if len(self.stack) < 3:
            raise Corruption("not enough items for setitem")
        value = self.stack.pop()
        key = self.stack.pop()
        self.stack[-1][key] = value
    opcodes[SETITEM] = load_setitem

    def load_newdict(self):
        """Push an empty dict to be filled by SETITEM opcodes."""
        self.stack.append({})
    opcodes[NEWDICT] = load_newdict

    def load_buildtuple(self):
        """Replace the top *length* stack items with one tuple of them."""
        length = self._read_int4()
        if length < 0 or length > len(self.stack):
            raise Corruption("not enough items for tuple")
        # Slice from an explicit start index: with length == 0 the
        # expression stack[-length:] is stack[0:], i.e. the whole stack,
        # which would swallow every pending value into the "empty" tuple.
        start = len(self.stack) - length
        tup = tuple(self.stack[start:])
        del self.stack[start:]
        self.stack.append(tup)
    opcodes[BUILDTUPLE] = load_buildtuple

    def load_stop(self):
        """Signal load() that the message is complete."""
        raise _Stop
    opcodes[STOP] = load_stop
|
|
@ -0,0 +1,127 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
import shutil
|
||||||
|
import py
|
||||||
|
from py.__.execnet import serializer
|
||||||
|
|
||||||
|
def setup_module(mod):
    """Remember the serializer's real _INPY3 flag on the test module so it
    can be put back after a test has overridden it."""
    setattr(mod, "_save_python3", serializer._INPY3)
|
||||||
|
|
||||||
|
def teardown_module(mod):
    """Rebuild the serializer's version-dependent globals once the tests in
    this module have finished."""
    serializer._setup_version_dependent_constants()
|
||||||
|
|
||||||
|
def _dump(obj):
    """Serialize *obj* into an in-memory buffer and return the bytes."""
    sink = py.io.BytesIO()
    serializer.Serializer(sink).save(obj)
    return sink.getvalue()
|
||||||
|
|
||||||
|
def _load(serialized, str_coerion):
    """Unserialize the bytes *serialized* and return the resulting object.

    *str_coerion* is passed as the single argument of the currently active
    serializer.UnserializationOptions class.
    """
    options = serializer.UnserializationOptions(str_coerion)
    source = py.io.BytesIO(serialized)
    return serializer.Unserializer(source, options).load()
|
||||||
|
|
||||||
|
def _run_in_version(is_py3, func, *args):
    """Call ``func(*args)`` with the serializer switched to a version mode.

    *is_py3* becomes serializer._INPY3 for the duration of the call and the
    version-dependent globals are rebuilt to match.  Afterwards the flag
    saved by setup_module() is restored, whether or not *func* raised.
    Returns whatever *func* returns.
    """
    serializer._INPY3 = is_py3
    serializer._setup_version_dependent_constants()
    try:
        return func(*args)
    finally:
        serializer._INPY3 = _save_python3
        # Rebuild the derived globals as well; restoring only the flag
        # would leave `unicode`, `UnserializationOptions` and the dispatch
        # table configured for the simulated version.
        serializer._setup_version_dependent_constants()
|
||||||
|
|
||||||
|
def dump_py2(obj):
    """Serialize *obj* with the serializer in Python 2 mode."""
    as_py3 = False
    return _run_in_version(as_py3, _dump, obj)
|
||||||
|
|
||||||
|
def dump_py3(obj):
    """Serialize *obj* with the serializer in Python 3 mode."""
    as_py3 = True
    return _run_in_version(as_py3, _dump, obj)
|
||||||
|
|
||||||
|
def load_py2(serialized, str_coercion=False):
    """Unserialize *serialized* with the serializer in Python 2 mode.

    *str_coercion* is handed to the unserialization options object.
    """
    as_py3 = False
    return _run_in_version(as_py3, _load, serialized, str_coercion)
|
||||||
|
|
||||||
|
def load_py3(serialized, str_coercion=False):
    """Unserialize *serialized* with the serializer in Python 3 mode.

    *str_coercion* is handed to the unserialization options object.
    """
    as_py3 = True
    return _run_in_version(as_py3, _load, serialized, str_coercion)
|
||||||
|
|
||||||
|
# Interpreters without a builtin ``bytes`` get ``str`` under that name.
try:
    bytes
except NameError:
    bytes = str
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_funcarg__py2(request):
    """Provide the ``py2`` funcarg; *request* is not used."""
    # NOTE(review): `_py2_wrapper` is not defined anywhere in the visible
    # source -- confirm it exists before a test requests this funcarg.
    return _py2_wrapper
|
||||||
|
|
||||||
|
def pytest_funcarg__py3(request):
    """Provide the ``py3`` funcarg; *request* is not used."""
    # NOTE(review): `_py3_wrapper` is not defined anywhere in the visible
    # source -- confirm it exists before a test requests this funcarg.
    return _py3_wrapper
|
||||||
|
|
||||||
|
class TestSerializer:
    """Round-trip tests: dump in each version mode, load in each mode."""

    def test_int(self):
        for dump in dump_py2, dump_py3:
            # Use the loop's dumper; calling dump_py2 here unconditionally
            # meant the Python 3 dump path was never round-tripped.
            p = dump(4)
            for load in load_py2, load_py3:
                i = load(p)
                assert isinstance(i, int)
                assert i == 4
            # A value too large for four bytes must be rejected.
            py.test.raises(serializer.SerializationError, dump, 123456678900)

    def test_bytes(self):
        for dump in dump_py2, dump_py3:
            p = dump(serializer._b('hi'))
            for load in load_py2, load_py3:
                s = load(p)
                assert isinstance(s, serializer.bytes)
                assert s == serializer._b('hi')

    def check_sequence(self, seq):
        """Assert that *seq* survives every dump/load mode combination."""
        for dump in dump_py2, dump_py3:
            p = dump(seq)
            for load in load_py2, load_py3:
                l = load(p)
                assert l == seq

    def test_list(self):
        self.check_sequence([1, 2, 3])

    @py.test.mark.xfail
    # I'm not sure if we need the complexity.
    def test_recursive_list(self):
        l = [1, 2, 3]
        l.append(l)
        self.check_sequence(l)

    def test_tuple(self):
        self.check_sequence((1, 2, 3))

    def test_dict(self):
        for dump in dump_py2, dump_py3:
            p = dump({"hi" : 2, (1, 2, 3) : 32})
            for load in load_py2, load_py3:
                d = load(p, True)
                assert d == {"hi" : 2, (1, 2, 3) : 32}

    def test_string(self):
        # Skipped for now; the body below is kept for the planned rewrite.
        py.test.skip("will rewrite")
        p = dump_py2("xyz")
        s = load_py2(p)
        assert isinstance(s, str)
        assert s == "xyz"
        s = load_py3(p)
        assert isinstance(s, bytes)
        assert s == serializer.b("xyz")
        p = dump_py2("xyz")
        s = load_py3(p, True)
        assert isinstance(s, serializer._unicode)
        assert s == serializer.unicode("xyz")
        p = dump_py3("xyz")
        s = load_py2(p, True)
        assert isinstance(s, str)
        assert s == "xyz"

    def test_unicode(self):
        # Skipped for now; the body below is kept for the planned rewrite.
        py.test.skip("will rewrite")
        for dump, uni in (dump_py2, serializer._unicode), (dump_py3, str):
            p = dump(uni("xyz"))
            for load in load_py2, load_py3:
                s = load(p)
                assert isinstance(s, serializer._unicode)
                assert s == serializer._unicode("xyz")
|
Loading…
Reference in New Issue