1# Implementat marshal.loads() in pure Python
2
3import ast
4
5from typing import Any, Tuple
6
7
8class Type:
9    # Adapted from marshal.c
10    NULL                = ord('0')
11    NONE                = ord('N')
12    FALSE               = ord('F')
13    TRUE                = ord('T')
14    STOPITER            = ord('S')
15    ELLIPSIS            = ord('.')
16    INT                 = ord('i')
17    INT64               = ord('I')
18    FLOAT               = ord('f')
19    BINARY_FLOAT        = ord('g')
20    COMPLEX             = ord('x')
21    BINARY_COMPLEX      = ord('y')
22    LONG                = ord('l')
23    STRING              = ord('s')
24    INTERNED            = ord('t')
25    REF                 = ord('r')
26    TUPLE               = ord('(')
27    LIST                = ord('[')
28    DICT                = ord('{')
29    CODE                = ord('c')
30    UNICODE             = ord('u')
31    UNKNOWN             = ord('?')
32    SET                 = ord('<')
33    FROZENSET           = ord('>')
34    ASCII               = ord('a')
35    ASCII_INTERNED      = ord('A')
36    SMALL_TUPLE         = ord(')')
37    SHORT_ASCII         = ord('z')
38    SHORT_ASCII_INTERNED = ord('Z')
39
40
41FLAG_REF = 0x80  # with a type, add obj to index
42
43NULL = object()  # marker
44
45# Cell kinds
46CO_FAST_LOCAL = 0x20
47CO_FAST_CELL = 0x40
48CO_FAST_FREE = 0x80
49
50
51class Code:
52    def __init__(self, **kwds: Any):
53        self.__dict__.update(kwds)
54
55    def __repr__(self) -> str:
56        return f"Code(**{self.__dict__})"
57
58    co_localsplusnames: Tuple[str]
59    co_localspluskinds: Tuple[int]
60
61    def get_localsplus_names(self, select_kind: int) -> Tuple[str, ...]:
62        varnames: list[str] = []
63        for name, kind in zip(self.co_localsplusnames,
64                              self.co_localspluskinds):
65            if kind & select_kind:
66                varnames.append(name)
67        return tuple(varnames)
68
69    @property
70    def co_varnames(self) -> Tuple[str, ...]:
71        return self.get_localsplus_names(CO_FAST_LOCAL)
72
73    @property
74    def co_cellvars(self) -> Tuple[str, ...]:
75        return self.get_localsplus_names(CO_FAST_CELL)
76
77    @property
78    def co_freevars(self) -> Tuple[str, ...]:
79        return self.get_localsplus_names(CO_FAST_FREE)
80
81    @property
82    def co_nlocals(self) -> int:
83        return len(self.co_varnames)
84
85
86class Reader:
87    # A fairly literal translation of the marshal reader.
88
89    def __init__(self, data: bytes):
90        self.data: bytes = data
91        self.end: int = len(self.data)
92        self.pos: int = 0
93        self.refs: list[Any] = []
94        self.level: int = 0
95
96    def r_string(self, n: int) -> bytes:
97        assert 0 <= n <= self.end - self.pos
98        buf = self.data[self.pos : self.pos + n]
99        self.pos += n
100        return buf
101
102    def r_byte(self) -> int:
103        buf = self.r_string(1)
104        return buf[0]
105
106    def r_short(self) -> int:
107        buf = self.r_string(2)
108        x = buf[0]
109        x |= buf[1] << 8
110        x |= -(x & (1<<15))  # Sign-extend
111        return x
112
113    def r_long(self) -> int:
114        buf = self.r_string(4)
115        x = buf[0]
116        x |= buf[1] << 8
117        x |= buf[2] << 16
118        x |= buf[3] << 24
119        x |= -(x & (1<<31))  # Sign-extend
120        return x
121
122    def r_long64(self) -> int:
123        buf = self.r_string(8)
124        x = buf[0]
125        x |= buf[1] << 8
126        x |= buf[2] << 16
127        x |= buf[3] << 24
128        x |= buf[1] << 32
129        x |= buf[1] << 40
130        x |= buf[1] << 48
131        x |= buf[1] << 56
132        x |= -(x & (1<<63))  # Sign-extend
133        return x
134
135    def r_PyLong(self) -> int:
136        n = self.r_long()
137        size = abs(n)
138        x = 0
139        # Pray this is right
140        for i in range(size):
141            x |= self.r_short() << i*15
142        if n < 0:
143            x = -x
144        return x
145
146    def r_float_bin(self) -> float:
147        buf = self.r_string(8)
148        import struct  # Lazy import to avoid breaking UNIX build
149        return struct.unpack("d", buf)[0]
150
151    def r_float_str(self) -> float:
152        n = self.r_byte()
153        buf = self.r_string(n)
154        return ast.literal_eval(buf.decode("ascii"))
155
156    def r_ref_reserve(self, flag: int) -> int:
157        if flag:
158            idx = len(self.refs)
159            self.refs.append(None)
160            return idx
161        else:
162            return 0
163
164    def r_ref_insert(self, obj: Any, idx: int, flag: int) -> Any:
165        if flag:
166            self.refs[idx] = obj
167        return obj
168
169    def r_ref(self, obj: Any, flag: int) -> Any:
170        assert flag & FLAG_REF
171        self.refs.append(obj)
172        return obj
173
174    def r_object(self) -> Any:
175        old_level = self.level
176        try:
177            return self._r_object()
178        finally:
179            self.level = old_level
180
181    def _r_object(self) -> Any:
182        code = self.r_byte()
183        flag = code & FLAG_REF
184        type = code & ~FLAG_REF
185        # print("  "*self.level + f"{code} {flag} {type} {chr(type)!r}")
186        self.level += 1
187
188        def R_REF(obj: Any) -> Any:
189            if flag:
190                obj = self.r_ref(obj, flag)
191            return obj
192
193        if type == Type.NULL:
194            return NULL
195        elif type == Type.NONE:
196            return None
197        elif type == Type.ELLIPSIS:
198            return Ellipsis
199        elif type == Type.FALSE:
200            return False
201        elif type == Type.TRUE:
202            return True
203        elif type == Type.INT:
204            return R_REF(self.r_long())
205        elif type == Type.INT64:
206            return R_REF(self.r_long64())
207        elif type == Type.LONG:
208            return R_REF(self.r_PyLong())
209        elif type == Type.FLOAT:
210            return R_REF(self.r_float_str())
211        elif type == Type.BINARY_FLOAT:
212            return R_REF(self.r_float_bin())
213        elif type == Type.COMPLEX:
214            return R_REF(complex(self.r_float_str(),
215                                    self.r_float_str()))
216        elif type == Type.BINARY_COMPLEX:
217            return R_REF(complex(self.r_float_bin(),
218                                    self.r_float_bin()))
219        elif type == Type.STRING:
220            n = self.r_long()
221            return R_REF(self.r_string(n))
222        elif type == Type.ASCII_INTERNED or type == Type.ASCII:
223            n = self.r_long()
224            return R_REF(self.r_string(n).decode("ascii"))
225        elif type == Type.SHORT_ASCII_INTERNED or type == Type.SHORT_ASCII:
226            n = self.r_byte()
227            return R_REF(self.r_string(n).decode("ascii"))
228        elif type == Type.INTERNED or type == Type.UNICODE:
229            n = self.r_long()
230            return R_REF(self.r_string(n).decode("utf8", "surrogatepass"))
231        elif type == Type.SMALL_TUPLE:
232            n = self.r_byte()
233            idx = self.r_ref_reserve(flag)
234            retval: Any = tuple(self.r_object() for _ in range(n))
235            self.r_ref_insert(retval, idx, flag)
236            return retval
237        elif type == Type.TUPLE:
238            n = self.r_long()
239            idx = self.r_ref_reserve(flag)
240            retval = tuple(self.r_object() for _ in range(n))
241            self.r_ref_insert(retval, idx, flag)
242            return retval
243        elif type == Type.LIST:
244            n = self.r_long()
245            retval = R_REF([])
246            for _ in range(n):
247                retval.append(self.r_object())
248            return retval
249        elif type == Type.DICT:
250            retval = R_REF({})
251            while True:
252                key = self.r_object()
253                if key == NULL:
254                    break
255                val = self.r_object()
256                retval[key] = val
257            return retval
258        elif type == Type.SET:
259            n = self.r_long()
260            retval = R_REF(set())
261            for _ in range(n):
262                v = self.r_object()
263                retval.add(v)
264            return retval
265        elif type == Type.FROZENSET:
266            n = self.r_long()
267            s: set[Any] = set()
268            idx = self.r_ref_reserve(flag)
269            for _ in range(n):
270                v = self.r_object()
271                s.add(v)
272            retval = frozenset(s)
273            self.r_ref_insert(retval, idx, flag)
274            return retval
275        elif type == Type.CODE:
276            retval = R_REF(Code())
277            retval.co_argcount = self.r_long()
278            retval.co_posonlyargcount = self.r_long()
279            retval.co_kwonlyargcount = self.r_long()
280            retval.co_stacksize = self.r_long()
281            retval.co_flags = self.r_long()
282            retval.co_code = self.r_object()
283            retval.co_consts = self.r_object()
284            retval.co_names = self.r_object()
285            retval.co_localsplusnames = self.r_object()
286            retval.co_localspluskinds = self.r_object()
287            retval.co_filename = self.r_object()
288            retval.co_name = self.r_object()
289            retval.co_qualname = self.r_object()
290            retval.co_firstlineno = self.r_long()
291            retval.co_linetable = self.r_object()
292            retval.co_exceptiontable = self.r_object()
293            return retval
294        elif type == Type.REF:
295            n = self.r_long()
296            retval = self.refs[n]
297            assert retval is not None
298            return retval
299        else:
300            breakpoint()
301            raise AssertionError(f"Unknown type {type} {chr(type)!r}")
302
303
304def loads(data: bytes) -> Any:
305    assert isinstance(data, bytes)
306    r = Reader(data)
307    return r.r_object()
308
309
310def main():
311    # Test
312    import marshal, pprint
313    sample = {'foo': {(42, "bar", 3.14)}}
314    data = marshal.dumps(sample)
315    retval = loads(data)
316    assert retval == sample, retval
317    sample = main.__code__
318    data = marshal.dumps(sample)
319    retval = loads(data)
320    assert isinstance(retval, Code), retval
321    pprint.pprint(retval.__dict__)
322
323
324if __name__ == "__main__":
325    main()
326