test_ok1/py/execnet/serializer.py

273 lines
7.3 KiB
Python
Executable File

"""
Simple marshal format (based on pickle) designed to work across Python versions.
"""
import sys
import struct
_INPY3 = _REALLY_PY3 = sys.version_info > (3, 0)
class SerializeError(Exception):
pass
class SerializationError(SerializeError):
"""Error while serializing an object."""
class UnserializableType(SerializationError):
"""Can't serialize a type."""
class UnserializationError(SerializeError):
"""Error while unserializing an object."""
class VersionMismatch(UnserializationError):
"""Data from a previous or later format."""
class Corruption(UnserializationError):
"""The pickle format appears to have been corrupted."""
if _INPY3:
def b(s):
return s.encode("ascii")
else:
b = str
FOUR_BYTE_INT_MAX = 2147483647
_int4_format = struct.Struct("!i")
_float_format = struct.Struct("!d")
# Protocol constants
VERSION_NUMBER = 1
VERSION = b(chr(VERSION_NUMBER))
PY2STRING = b('s')
PY3STRING = b('t')
UNICODE = b('u')
BYTES = b('b')
NEWLIST = b('l')
BUILDTUPLE = b('T')
SETITEM = b('m')
NEWDICT = b('d')
INT = b('i')
FLOAT = b('f')
STOP = b('S')
class CrossVersionOptions(object):
pass
class Serializer(object):
def __init__(self, stream):
self.stream = stream
def save(self, obj):
self.stream.write(VERSION)
self._save(obj)
self.stream.write(STOP)
def _save(self, obj):
tp = type(obj)
try:
dispatch = self.dispatch[tp]
except KeyError:
raise UnserializableType("can't serialize %s" % (tp,))
dispatch(self, obj)
dispatch = {}
def save_bytes(self, bytes_):
self.stream.write(BYTES)
self._write_byte_sequence(bytes_)
dispatch[bytes] = save_bytes
if _INPY3:
def save_string(self, s):
self.stream.write(PY3STRING)
self._write_unicode_string(s)
else:
def save_string(self, s):
self.stream.write(PY2STRING)
self._write_byte_sequence(s)
def save_unicode(self, s):
self.stream.write(UNICODE)
self._write_unicode_string(s)
dispatch[unicode] = save_unicode
dispatch[str] = save_string
def _write_unicode_string(self, s):
try:
as_bytes = s.encode("utf-8")
except UnicodeEncodeError:
raise SerializationError("strings must be utf-8 encodable")
self._write_byte_sequence(as_bytes)
def _write_byte_sequence(self, bytes_):
self._write_int4(len(bytes_), "string is too long")
self.stream.write(bytes_)
def save_int(self, i):
self.stream.write(INT)
self._write_int4(i)
dispatch[int] = save_int
def save_float(self, flt):
self.stream.write(FLOAT)
self.stream.write(_float_format.pack(flt))
dispatch[float] = save_float
def _write_int4(self, i, error="int must be less than %i" %
(FOUR_BYTE_INT_MAX,)):
if i > FOUR_BYTE_INT_MAX:
raise SerializationError(error)
self.stream.write(_int4_format.pack(i))
def save_list(self, L):
self.stream.write(NEWLIST)
self._write_int4(len(L), "list is too long")
for i, item in enumerate(L):
self._write_setitem(i, item)
dispatch[list] = save_list
def _write_setitem(self, key, value):
self._save(key)
self._save(value)
self.stream.write(SETITEM)
def save_dict(self, d):
self.stream.write(NEWDICT)
for key, value in d.items():
self._write_setitem(key, value)
dispatch[dict] = save_dict
def save_tuple(self, tup):
for item in tup:
self._save(item)
self.stream.write(BUILDTUPLE)
self._write_int4(len(tup), "tuple is too long")
dispatch[tuple] = save_tuple
class _UnserializationOptions(object):
pass
class _Py2UnserializationOptions(_UnserializationOptions):
def __init__(self, py3_strings_as_str=False):
self.py3_strings_as_str = py3_strings_as_str
class _Py3UnserializationOptions(_UnserializationOptions):
def __init__(self, py2_strings_as_str=False):
self.py2_strings_as_str = py2_strings_as_str
if _INPY3:
UnserializationOptions = _Py3UnserializationOptions
else:
UnserializationOptions = _Py2UnserializationOptions
class _Stop(Exception):
pass
class Unserializer(object):
def __init__(self, stream, options=UnserializationOptions()):
self.stream = stream
self.options = options
def load(self):
self.stack = []
version = ord(self.stream.read(1))
if version != VERSION_NUMBER:
raise VersionMismatch("%i != %i" % (version, VERSION_NUMBER))
try:
while True:
opcode = self.stream.read(1)
if not opcode:
raise EOFError
try:
loader = self.opcodes[opcode]
except KeyError:
raise Corruption("unkown opcode %s" % (opcode,))
loader(self)
except _Stop:
if len(self.stack) != 1:
raise UnserializationError("internal unserialization error")
return self.stack[0]
else:
raise Corruption("didn't get STOP")
opcodes = {}
def load_int(self):
i = self._read_int4()
self.stack.append(i)
opcodes[INT] = load_int
def load_float(self):
binary = self.stream.read(_float_format.size)
self.stack.append(_float_format.unpack(binary)[0])
opcodes[FLOAT] = load_float
def _read_int4(self):
return _int4_format.unpack(self.stream.read(4))[0]
def _read_byte_string(self):
length = self._read_int4()
as_bytes = self.stream.read(length)
return as_bytes
def load_py3string(self):
as_bytes = self._read_byte_string()
if not _INPY3 and self.options.py3_strings_as_str:
# XXX Should we try to decode into latin-1?
self.stack.append(as_bytes)
else:
self.stack.append(as_bytes.decode("utf-8"))
opcodes[PY3STRING] = load_py3string
def load_py2string(self):
as_bytes = self._read_byte_string()
if _INPY3 and self.options.py2_strings_as_str:
s = as_bytes.decode("latin-1")
else:
s = as_bytes
self.stack.append(s)
opcodes[PY2STRING] = load_py2string
def load_bytes(self):
s = self._read_byte_string()
self.stack.append(s)
opcodes[BYTES] = load_bytes
def load_unicode(self):
self.stack.append(self._read_byte_string().decode("utf-8"))
opcodes[UNICODE] = load_unicode
def load_newlist(self):
length = self._read_int4()
self.stack.append([None] * length)
opcodes[NEWLIST] = load_newlist
def load_setitem(self):
if len(self.stack) < 3:
raise Corruption("not enough items for setitem")
value = self.stack.pop()
key = self.stack.pop()
self.stack[-1][key] = value
opcodes[SETITEM] = load_setitem
def load_newdict(self):
self.stack.append({})
opcodes[NEWDICT] = load_newdict
def load_buildtuple(self):
length = self._read_int4()
tup = tuple(self.stack[-length:])
del self.stack[-length:]
self.stack.append(tup)
opcodes[BUILDTUPLE] = load_buildtuple
def load_stop(self):
raise _Stop
opcodes[STOP] = load_stop