# Implementation of marshal.loads() in pure Python

import ast  # NOTE: no longer used by the reader; import retained for compatibility

from typing import Any, Tuple


class Type:
    """One-byte type codes that introduce each marshalled object."""

    # Adapted from marshal.c
    NULL                 = ord('0')
    NONE                 = ord('N')
    FALSE                = ord('F')
    TRUE                 = ord('T')
    STOPITER             = ord('S')
    ELLIPSIS             = ord('.')
    INT                  = ord('i')
    INT64                = ord('I')
    FLOAT                = ord('f')
    BINARY_FLOAT         = ord('g')
    COMPLEX              = ord('x')
    BINARY_COMPLEX       = ord('y')
    LONG                 = ord('l')
    STRING               = ord('s')
    INTERNED             = ord('t')
    REF                  = ord('r')
    TUPLE                = ord('(')
    LIST                 = ord('[')
    DICT                 = ord('{')
    CODE                 = ord('c')
    UNICODE              = ord('u')
    UNKNOWN              = ord('?')
    SET                  = ord('<')
    FROZENSET            = ord('>')
    ASCII                = ord('a')
    ASCII_INTERNED       = ord('A')
    SMALL_TUPLE          = ord(')')
    SHORT_ASCII          = ord('z')
    SHORT_ASCII_INTERNED = ord('Z')


FLAG_REF = 0x80  # with a type, add obj to index

NULL = object()  # marker

# Cell kinds
CO_FAST_LOCAL = 0x20
CO_FAST_CELL = 0x40
CO_FAST_FREE = 0x80


class Code:
    """Plain attribute bag standing in for a code object.

    Keyword arguments given to the constructor become attributes; the reader
    fills in the ``co_*`` fields one by one after creating an empty instance.
    """

    def __init__(self, **kwds: Any):
        self.__dict__.update(kwds)

    def __repr__(self) -> str:
        return f"Code(**{self.__dict__})"

    # Names of all "fast" locals, and one kind bitmask per name (same order).
    co_localsplusnames: Tuple[str, ...]
    co_localspluskinds: Tuple[int, ...]

    def get_localsplus_names(self, select_kind: int) -> Tuple[str, ...]:
        """Return the names whose kind has any bit of *select_kind* set."""
        return tuple(
            name
            for name, kind in zip(self.co_localsplusnames,
                                  self.co_localspluskinds)
            if kind & select_kind
        )

    @property
    def co_varnames(self) -> Tuple[str, ...]:
        return self.get_localsplus_names(CO_FAST_LOCAL)

    @property
    def co_cellvars(self) -> Tuple[str, ...]:
        return self.get_localsplus_names(CO_FAST_CELL)

    @property
    def co_freevars(self) -> Tuple[str, ...]:
        return self.get_localsplus_names(CO_FAST_FREE)

    @property
    def co_nlocals(self) -> int:
        return len(self.co_varnames)


class Reader:
    """Sequential decoder over a bytes buffer of marshalled data."""

    # A fairly literal translation of the marshal reader.

    def __init__(self, data: bytes):
        self.data: bytes = data
        self.end: int = len(self.data)
        self.pos: int = 0
        self.refs: list[Any] = []  # objects read with FLAG_REF, by index
        self.level: int = 0  # current nesting depth of r_object() calls

    def r_string(self, n: int) -> bytes:
        """Consume and return the next *n* bytes."""
        assert 0 <= n <= self.end - self.pos
        buf = self.data[self.pos : self.pos + n]
        self.pos += n
        return buf

    def r_byte(self) -> int:
        """Consume one byte and return it as an unsigned int."""
        buf = self.r_string(1)
        return buf[0]

    def r_short(self) -> int:
        """Consume a little-endian signed 16-bit integer."""
        buf = self.r_string(2)
        x = buf[0]
        x |= buf[1] << 8
        x |= -(x & (1 << 15))  # Sign-extend
        return x

    def r_long(self) -> int:
        """Consume a little-endian signed 32-bit integer."""
        buf = self.r_string(4)
        x = buf[0]
        x |= buf[1] << 8
        x |= buf[2] << 16
        x |= buf[3] << 24
        x |= -(x & (1 << 31))  # Sign-extend
        return x

    def r_long64(self) -> int:
        """Consume a little-endian signed 64-bit integer."""
        buf = self.r_string(8)
        x = 0
        # Byte i contributes bits [8*i, 8*i + 8).
        for i, byte in enumerate(buf):
            x |= byte << (8 * i)
        x |= -(x & (1 << 63))  # Sign-extend
        return x

    def r_PyLong(self) -> int:
        """Consume an arbitrary-precision integer.

        The encoding is a signed 32-bit digit count (its sign is the sign of
        the value) followed by that many 16-bit words, each holding one
        15-bit digit, least significant digit first.
        """
        n = self.r_long()
        size = abs(n)
        x = 0
        for i in range(size):
            digit = self.r_short()
            # r_short() sign-extends; a negative value means bit 15 was set,
            # which is not a valid 15-bit digit and would corrupt the result.
            assert digit >= 0, "digit out of range in long"
            x |= digit << i*15
        if n < 0:
            x = -x
        return x

    def r_float_bin(self) -> float:
        """Consume an 8-byte little-endian IEEE 754 double."""
        buf = self.r_string(8)
        import struct  # Lazy import to avoid breaking UNIX build
        # Explicit "<": the data is little-endian like every other
        # multi-byte field read here, regardless of the host byte order.
        return struct.unpack("<d", buf)[0]

    def r_float_str(self) -> float:
        """Consume a length-prefixed ASCII float (one length byte + text)."""
        n = self.r_byte()
        buf = self.r_string(n)
        # float() always yields a float (even for text like "1") and accepts
        # "inf"/"nan", neither of which holds for ast.literal_eval().
        return float(buf.decode("ascii"))

    def r_ref_reserve(self, flag: int) -> int:
        """Reserve a slot in the ref table if *flag* is set; return its index.

        Used for immutable containers, which can only be built after their
        items have been read but must occupy the ref index assigned when
        their type byte was seen.
        """
        if flag:
            idx = len(self.refs)
            self.refs.append(None)
            return idx
        else:
            return 0

    def r_ref_insert(self, obj: Any, idx: int, flag: int) -> Any:
        """Store *obj* in the slot reserved by r_ref_reserve(), if flagged."""
        if flag:
            self.refs[idx] = obj
        return obj

    def r_ref(self, obj: Any, flag: int) -> Any:
        """Append *obj* to the ref table and return it."""
        assert flag & FLAG_REF
        self.refs.append(obj)
        return obj

    def r_object(self) -> Any:
        """Read one object, restoring the nesting level afterwards."""
        old_level = self.level
        try:
            return self._r_object()
        finally:
            self.level = old_level

    def _r_object(self) -> Any:
        """Read one object: a type byte (optionally with FLAG_REF) + payload.

        Returns the module-level ``NULL`` marker for the NULL type code.
        Raises AssertionError for an unrecognised type code.
        """
        code = self.r_byte()
        flag = code & FLAG_REF
        kind = code & ~FLAG_REF
        # print("  "*self.level + f"{code} {flag} {kind} {chr(kind)!r}")
        self.level += 1

        def R_REF(obj: Any) -> Any:
            # Register obj in the ref table when the type byte carried FLAG_REF.
            if flag:
                obj = self.r_ref(obj, flag)
            return obj

        if kind == Type.NULL:
            return NULL
        elif kind == Type.NONE:
            return None
        elif kind == Type.ELLIPSIS:
            return Ellipsis
        elif kind == Type.FALSE:
            return False
        elif kind == Type.TRUE:
            return True
        elif kind == Type.STOPITER:
            return StopIteration
        elif kind == Type.INT:
            return R_REF(self.r_long())
        elif kind == Type.INT64:
            return R_REF(self.r_long64())
        elif kind == Type.LONG:
            return R_REF(self.r_PyLong())
        elif kind == Type.FLOAT:
            return R_REF(self.r_float_str())
        elif kind == Type.BINARY_FLOAT:
            return R_REF(self.r_float_bin())
        elif kind == Type.COMPLEX:
            return R_REF(complex(self.r_float_str(),
                                 self.r_float_str()))
        elif kind == Type.BINARY_COMPLEX:
            return R_REF(complex(self.r_float_bin(),
                                 self.r_float_bin()))
        elif kind == Type.STRING:
            n = self.r_long()
            return R_REF(self.r_string(n))
        elif kind == Type.ASCII_INTERNED or kind == Type.ASCII:
            n = self.r_long()
            return R_REF(self.r_string(n).decode("ascii"))
        elif kind == Type.SHORT_ASCII_INTERNED or kind == Type.SHORT_ASCII:
            n = self.r_byte()
            return R_REF(self.r_string(n).decode("ascii"))
        elif kind == Type.INTERNED or kind == Type.UNICODE:
            n = self.r_long()
            return R_REF(self.r_string(n).decode("utf8", "surrogatepass"))
        elif kind == Type.SMALL_TUPLE:
            n = self.r_byte()
            # Reserve the ref slot before reading the items so that items
            # registered while reading get later indices than the tuple.
            idx = self.r_ref_reserve(flag)
            retval: Any = tuple(self.r_object() for _ in range(n))
            self.r_ref_insert(retval, idx, flag)
            return retval
        elif kind == Type.TUPLE:
            n = self.r_long()
            idx = self.r_ref_reserve(flag)
            retval = tuple(self.r_object() for _ in range(n))
            self.r_ref_insert(retval, idx, flag)
            return retval
        elif kind == Type.LIST:
            n = self.r_long()
            retval = R_REF([])
            for _ in range(n):
                retval.append(self.r_object())
            return retval
        elif kind == Type.DICT:
            retval = R_REF({})
            # Key/value pairs follow until a NULL object ends the dict.
            while True:
                key = self.r_object()
                if key is NULL:
                    break
                val = self.r_object()
                retval[key] = val
            return retval
        elif kind == Type.SET:
            n = self.r_long()
            retval = R_REF(set())
            for _ in range(n):
                v = self.r_object()
                retval.add(v)
            return retval
        elif kind == Type.FROZENSET:
            n = self.r_long()
            s: set[Any] = set()
            idx = self.r_ref_reserve(flag)
            for _ in range(n):
                v = self.r_object()
                s.add(v)
            retval = frozenset(s)
            self.r_ref_insert(retval, idx, flag)
            return retval
        elif kind == Type.CODE:
            # Fields are read in the exact order they appear in the stream.
            retval = R_REF(Code())
            retval.co_argcount = self.r_long()
            retval.co_posonlyargcount = self.r_long()
            retval.co_kwonlyargcount = self.r_long()
            retval.co_stacksize = self.r_long()
            retval.co_flags = self.r_long()
            retval.co_code = self.r_object()
            retval.co_consts = self.r_object()
            retval.co_names = self.r_object()
            retval.co_localsplusnames = self.r_object()
            retval.co_localspluskinds = self.r_object()
            retval.co_filename = self.r_object()
            retval.co_name = self.r_object()
            retval.co_qualname = self.r_object()
            retval.co_firstlineno = self.r_long()
            retval.co_linetable = self.r_object()
            retval.co_exceptiontable = self.r_object()
            return retval
        elif kind == Type.REF:
            n = self.r_long()
            retval = self.refs[n]
            # None here is a slot reserved by r_ref_reserve() that has not
            # been filled in yet.
            assert retval is not None
            return retval
        else:
            raise AssertionError(f"Unknown type {kind} {chr(kind)!r}")


def loads(data: bytes) -> Any:
    """Decode and return the first object marshalled in *data*."""
    assert isinstance(data, bytes)
    r = Reader(data)
    return r.r_object()


def main():
    # Test
    import marshal
    import pprint
    sample = {'foo': {(42, "bar", 3.14)}}
    data = marshal.dumps(sample)
    retval = loads(data)
    assert retval == sample, retval
    sample = main.__code__
    data = marshal.dumps(sample)
    retval = loads(data)
    assert isinstance(retval, Code), retval
    pprint.pprint(retval.__dict__)


if __name__ == "__main__":
    main()