1"""Deep freeze
2
3The script may be executed by _bootstrap_python interpreter.
4Shared library extension modules are not available in that case.
5On Windows, and in cross-compilation cases, it is executed
6by Python 3.10, and 3.11 features are not available.
7"""
import argparse
import ast
import builtins
import collections
import contextlib
import os
import re
import time
import types
from typing import Dict, FrozenSet, Iterator, TextIO, Tuple

import umarshal
from generate_global_objects import get_identifiers_and_strings
21
# Diagnostics switch; main() overwrites it from the -v/--verbose option.
verbose = False
# Consulted by Printer.generate_unicode(): a string found in either
# collection is emitted as a _Py_STR()/_Py_ID() reference instead of a new
# static object.
identifiers, strings = get_identifiers_and_strings()

# This must be kept in sync with opcode.py
RESUME = 151
27
def isprintable(b: bytes) -> bool:
    """Return True when every byte of *b* lies in the range 0x20..0x7e.

    An empty bytes object counts as printable.
    """
    for byte in b:
        if byte < 0x20 or byte >= 0x7f:
            return False
    return True
30
31
def make_string_literal(b: bytes) -> str:
    """Return the text of a double-quoted C string literal holding *b*.

    Printable input is written verbatim with backslashes and double quotes
    escaped; any other input is written entirely as ``\\xNN`` escapes.
    """
    if isprintable(b):
        text = b.decode("ascii")
        body = text.replace("\\", "\\\\").replace("\"", "\\\"")
    else:
        body = "".join(f"\\x{byte:02x}" for byte in b)
    return '"' + body + '"'
41
42
# Bit flags OR-ed together into the per-name "kind" byte that
# get_localsplus() builds and get_localsplus_counts() inspects.
CO_FAST_LOCAL = 0x20
CO_FAST_CELL = 0x40
CO_FAST_FREE = 0x80
46
47
def get_localsplus(code: types.CodeType):
    """Build the combined name table and kind flags for *code*.

    Names are taken from co_varnames, co_cellvars and co_freevars, in that
    order.  A name that appears in more than one of them is listed once, at
    its first position, with the flags of every group it belongs to.

    Returns a ``(names, kinds)`` pair: a tuple of names and a bytes object
    holding one flag byte per name.
    """
    flags_by_name = {}
    groups = (
        (code.co_varnames, CO_FAST_LOCAL),
        (code.co_cellvars, CO_FAST_CELL),
        (code.co_freevars, CO_FAST_FREE),
    )
    for group, flag in groups:
        for name in group:
            flags_by_name[name] = flags_by_name.get(name, 0) | flag
    return tuple(flags_by_name), bytes(flags_by_name.values())
57
58
def get_localsplus_counts(code: types.CodeType,
                          names: Tuple[str, ...],
                          kinds: bytes) -> Tuple[int, int, int, int]:
    """Count the entries of each kind in a localsplus table.

    *names* and *kinds* are the pair produced by get_localsplus().  The
    result is ``(nlocals, nplaincellvars, ncellvars, nfreevars)``, where
    "plain" cell variables are cells that are not also locals.  The counts
    are cross-checked against *code* with assertions.
    """
    nlocals = nplaincellvars = ncellvars = nfreevars = 0
    assert len(names) == len(kinds)
    for _name, kind in zip(names, kinds):
        is_local = kind & CO_FAST_LOCAL
        is_cell = kind & CO_FAST_CELL
        if is_local or is_cell:
            if is_local:
                nlocals += 1
            else:
                # A cell that is not also a local variable.
                nplaincellvars += 1
            if is_cell:
                ncellvars += 1
        elif kind & CO_FAST_FREE:
            nfreevars += 1
    assert nlocals == len(code.co_varnames) == code.co_nlocals, \
        (nlocals, len(code.co_varnames), code.co_nlocals)
    assert ncellvars == len(code.co_cellvars)
    assert nfreevars == len(code.co_freevars)
    assert len(names) == nlocals + nplaincellvars + nfreevars
    return nlocals, nplaincellvars, ncellvars, nfreevars
83
84
# String "kind" values returned by analyze_character_width().
# Printer.generate_unicode() picks uint8_t / uint16_t / uint32_t storage
# for them respectively and writes the value into the .kind state field.
PyUnicode_1BYTE_KIND = 1
PyUnicode_2BYTE_KIND = 2
PyUnicode_4BYTE_KIND = 4
88
89
def analyze_character_width(s: str) -> Tuple[int, bool]:
    """Classify *s* by its largest code point.

    Returns ``(kind, ascii)``: *kind* is one of the PyUnicode_*_KIND
    constants, and *ascii* is True only when every character is at most
    U+007F (it is always False for the 2- and 4-byte kinds).
    """
    widest = max(s, default=' ')
    if widest < ' ':
        # The scan starts from a space, so that is the smallest result.
        widest = ' '
    if widest > '\uFFFF':
        return PyUnicode_4BYTE_KIND, False
    if widest > '\xFF':
        return PyUnicode_2BYTE_KIND, False
    return PyUnicode_1BYTE_KIND, widest <= '\x7F'
103
104
def removesuffix(base: str, suffix: str) -> str:
    """Return *base* without a trailing *suffix*; unchanged if it has none."""
    if not base.endswith(suffix):
        return base
    keep = len(base) - len(suffix)
    return base[:keep]
109
class Printer:
    """Write C source that defines Python constants as static C objects.

    Each ``generate_*`` method writes the C definition of one object to
    ``file`` (when a definition is needed) and returns a C expression, as a
    string, that refers to the object.  ``generate()`` dispatches on the
    object's type and memoizes the results.

    Attributes:
        level: current indentation depth; write() emits four spaces per level.
        file: text stream that receives the generated code.
        cache: maps ``(type, value, repr)`` keys to the C expression already
            returned for that constant.
        hits, misses: cache statistics maintained by generate().
        patchups: C statements written into ``<module>_do_patchups()`` by
            generate_file(), which then clears the list.
        deallocs: C statements collected for each generated unicode and code
            object; this class only collects them.
        interns: C expressions collected for each generated code object;
            this class only collects them.
    """

    def __init__(self, file: TextIO) -> None:
        self.level = 0
        self.file = file
        self.cache: Dict[tuple[type, object, str], str] = {}
        self.hits, self.misses = 0, 0
        self.patchups: list[str] = []
        self.deallocs: list[str] = []
        self.interns: list[str] = []
        # The output always starts with these includes and a blank line.
        self.write('#include "Python.h"')
        self.write('#include "internal/pycore_gc.h"')
        self.write('#include "internal/pycore_code.h"')
        self.write('#include "internal/pycore_long.h"')
        self.write("")

    @contextlib.contextmanager
    def indent(self) -> Iterator[None]:
        """Indent lines written inside the ``with`` body one level deeper.

        The previous level is restored on exit, even if the body raises.
        """
        save_level = self.level
        try:
            self.level += 1
            yield
        finally:
            self.level = save_level

    def write(self, arg: str) -> None:
        """Write *arg* as one line at the current indentation level."""
        self.file.writelines(("    "*self.level, arg, "\n"))

    @contextlib.contextmanager
    def block(self, prefix: str, suffix: str = "") -> Iterator[None]:
        """Write ``prefix {``, the indented ``with`` body, then ``}suffix``.

        The closing line is not written if the body raises.
        """
        self.write(prefix + " {")
        with self.indent():
            yield
        self.write("}" + suffix)

    def object_head(self, typename: str) -> None:
        """Write the ``.ob_base`` initializer of an object of C type *typename*."""
        with self.block(".ob_base =", ","):
            self.write(".ob_refcnt = 999999999,")
            self.write(f".ob_type = &{typename},")

    def object_var_head(self, typename: str, size: int) -> None:
        """Write the ``.ob_base`` initializer of a variable-size object."""
        with self.block(".ob_base =", ","):
            self.object_head(typename)
            self.write(f".ob_size = {size},")

    def field(self, obj: object, name: str) -> None:
        """Write ``.name = value,`` using the attribute *name* of *obj*."""
        self.write(f".{name} = {getattr(obj, name)},")

    def generate_bytes(self, name: str, b: bytes) -> str:
        """Return a C reference to the bytes object *b*.

        Empty and one-byte values refer to existing singletons and write
        nothing; longer values are defined as a static struct called *name*.
        """
        if b == b"":
            return "(PyObject *)&_Py_SINGLETON(bytes_empty)"
        if len(b) == 1:
            return f"(PyObject *)&_Py_SINGLETON(bytes_characters[{b[0]}])"
        self.write("static")
        with self.indent():
            with self.block("struct"):
                self.write("PyObject_VAR_HEAD")
                self.write("Py_hash_t ob_shash;")
                # One extra byte beyond the payload.
                self.write(f"char ob_sval[{len(b) + 1}];")
        with self.block(f"{name} =", ";"):
            self.object_var_head("PyBytes_Type", len(b))
            self.write(".ob_shash = -1,")
            self.write(f".ob_sval = {make_string_literal(b)},")
        return f"& {name}.ob_base.ob_base"

    def generate_unicode(self, name: str, s: str) -> str:
        """Return a C reference to the str object *s*.

        Strings found in the module-level ``strings`` or ``identifiers``
        collections become ``_Py_STR()`` / ``_Py_ID()`` references and write
        nothing.  Otherwise a static struct is defined; identifier-like
        strings are named ``const_str_<s>`` instead of *name*.
        """
        if s in strings:
            return f"&_Py_STR({strings[s]})"
        if s in identifiers:
            return f"&_Py_ID({s})"
        if re.match(r'\A[A-Za-z0-9_]+\Z', s):
            name = f"const_str_{s}"
        kind, is_ascii = analyze_character_width(s)
        if kind == PyUnicode_1BYTE_KIND:
            datatype = "uint8_t"
        elif kind == PyUnicode_2BYTE_KIND:
            datatype = "uint16_t"
        else:
            datatype = "uint32_t"
        self.write("static")
        with self.indent():
            with self.block("struct"):
                if is_ascii:
                    self.write("PyASCIIObject _ascii;")
                else:
                    self.write("PyCompactUnicodeObject _compact;")
                # One extra element beyond the characters of s.
                self.write(f"{datatype} _data[{len(s)+1}];")
        self.deallocs.append(f"_PyStaticUnicode_Dealloc((PyObject *)&{name});")
        with self.block(f"{name} =", ";"):
            if is_ascii:
                with self.block("._ascii =", ","):
                    self.object_head("PyUnicode_Type")
                    self.write(f".length = {len(s)},")
                    self.write(".hash = -1,")
                    with self.block(".state =", ","):
                        self.write(".kind = 1,")
                        self.write(".compact = 1,")
                        self.write(".ascii = 1,")
                        self.write(".ready = 1,")
                self.write(f"._data = {make_string_literal(s.encode('ascii'))},")
                return f"& {name}._ascii.ob_base"
            else:
                with self.block("._compact =", ","):
                    with self.block("._base =", ","):
                        self.object_head("PyUnicode_Type")
                        self.write(f".length = {len(s)},")
                        self.write(".hash = -1,")
                        with self.block(".state =", ","):
                            self.write(f".kind = {kind},")
                            self.write(".compact = 1,")
                            self.write(".ascii = 0,")
                            self.write(".ready = 1,")
                with self.block("._data =", ","):
                    # Code points as decimal numbers, sixteen per line.
                    for i in range(0, len(s), 16):
                        data = s[i:i+16]
                        self.write(", ".join(map(str, map(ord, data))) + ",")
                # For 2- and 4-byte strings, queue a run-time patch-up that
                # points wstr at the data when wchar_t has the same width.
                wchar_size = {
                    PyUnicode_2BYTE_KIND: 2,
                    PyUnicode_4BYTE_KIND: 4,
                }.get(kind)
                if wchar_size is not None:
                    self.patchups.extend([
                        f"if (sizeof(wchar_t) == {wchar_size}) {{",
                        f"    {name}._compact._base.wstr = (wchar_t *) {name}._data;",
                        f"    {name}._compact.wstr_length = {len(s)};",
                        "}",
                    ])
                return f"& {name}._compact._base.ob_base"

    def generate_code(self, name: str, code: types.CodeType) -> str:
        """Define the code object *code* as a static struct called *name*.

        The objects it refers to are generated first.  Dealloc and
        string-interning calls for the new object are queued in
        ``deallocs`` and ``interns``.
        """
        # The ordering here matches PyCode_NewWithPosOnlyArgs()
        # (but see below).
        co_consts = self.generate(name + "_consts", code.co_consts)
        co_names = self.generate(name + "_names", code.co_names)
        co_filename = self.generate(name + "_filename", code.co_filename)
        co_name = self.generate(name + "_name", code.co_name)
        co_qualname = self.generate(name + "_qualname", code.co_qualname)
        co_linetable = self.generate(name + "_linetable", code.co_linetable)
        co_exceptiontable = self.generate(name + "_exceptiontable", code.co_exceptiontable)
        # These fields are not directly accessible
        localsplusnames, localspluskinds = get_localsplus(code)
        co_localsplusnames = self.generate(name + "_localsplusnames", localsplusnames)
        co_localspluskinds = self.generate(name + "_localspluskinds", localspluskinds)
        # Derived values
        nlocals, nplaincellvars, ncellvars, nfreevars = \
            get_localsplus_counts(code, localsplusnames, localspluskinds)
        co_code_adaptive = make_string_literal(code.co_code)
        self.write("static")
        with self.indent():
            self.write(f"struct _PyCode_DEF({len(code.co_code)})")
        with self.block(f"{name} =", ";"):
            # ob_size is half the byte length of co_code.
            self.object_var_head("PyCode_Type", len(code.co_code) // 2)
            # But the ordering here must match that in cpython/code.h
            # (which is a pain because we tend to reorder those for perf)
            # otherwise MSVC doesn't like it.
            self.write(f".co_consts = {co_consts},")
            self.write(f".co_names = {co_names},")
            self.write(f".co_exceptiontable = {co_exceptiontable},")
            self.field(code, "co_flags")
            self.write(".co_warmup = QUICKENING_INITIAL_WARMUP_VALUE,")
            self.write("._co_linearray_entry_size = 0,")
            self.field(code, "co_argcount")
            self.field(code, "co_posonlyargcount")
            self.field(code, "co_kwonlyargcount")
            self.field(code, "co_stacksize")
            self.field(code, "co_firstlineno")
            self.write(f".co_nlocalsplus = {len(localsplusnames)},")
            self.field(code, "co_nlocals")
            self.write(f".co_nplaincellvars = {nplaincellvars},")
            self.write(f".co_ncellvars = {ncellvars},")
            self.write(f".co_nfreevars = {nfreevars},")
            self.write(f".co_localsplusnames = {co_localsplusnames},")
            self.write(f".co_localspluskinds = {co_localspluskinds},")
            self.write(f".co_filename = {co_filename},")
            self.write(f".co_name = {co_name},")
            self.write(f".co_qualname = {co_qualname},")
            self.write(f".co_linetable = {co_linetable},")
            self.write("._co_code = NULL,")
            self.write("._co_linearray = NULL,")
            self.write(f".co_code_adaptive = {co_code_adaptive},")
            # Index of the first RESUME among the even-offset bytes of
            # co_code; the field is not written if there is none.
            for i, op in enumerate(code.co_code[::2]):
                if op == RESUME:
                    self.write(f"._co_firsttraceable = {i},")
                    break
        name_as_code = f"(PyCodeObject *)&{name}"
        self.deallocs.append(f"_PyStaticCode_Dealloc({name_as_code});")
        self.interns.append(f"_PyStaticCode_InternStrings({name_as_code})")
        return f"& {name}.ob_base.ob_base"

    def generate_tuple(self, name: str, t: Tuple[object, ...]) -> str:
        """Return a C reference to the tuple *t*.

        The empty tuple refers to the existing singleton and writes nothing.
        Otherwise the items are generated first (as ``<name>_<index>``) and
        then a static struct called *name* that points at them.
        """
        if len(t) == 0:
            return "(PyObject *)& _Py_SINGLETON(tuple_empty)"
        items = [self.generate(f"{name}_{i}", it) for i, it in enumerate(t)]
        # From here on the tuple is known to be non-empty.
        self.write("static")
        with self.indent():
            with self.block("struct"):
                self.write("PyGC_Head _gc_head;")
                with self.block("struct", "_object;"):
                    self.write("PyObject_VAR_HEAD")
                    self.write(f"PyObject *ob_item[{len(t)}];")
        with self.block(f"{name} =", ";"):
            with self.block("._object =", ","):
                self.object_var_head("PyTuple_Type", len(t))
                with self.block(".ob_item =", ","):
                    for item in items:
                        self.write(item + ",")
        return f"& {name}._object.ob_base.ob_base"

    def _generate_int_for_bits(self, name: str, i: int, digit: int) -> None:
        """Define the int *i* as a static struct called *name*.

        *digit* is the numeric base of one stored digit (for example
        ``2**15``).  Digits are written least significant first, and the
        sign is carried by ``ob_size``.
        """
        sign = -1 if i < 0 else 0 if i == 0 else +1
        i = abs(i)
        digits: list[int] = []
        while i:
            i, rem = divmod(i, digit)
            digits.append(rem)
        self.write("static")
        with self.indent():
            with self.block("struct"):
                self.write("PyObject_VAR_HEAD")
                # Reserve at least one digit even when the value is zero.
                self.write(f"digit ob_digit[{max(1, len(digits))}];")
        with self.block(f"{name} =", ";"):
            self.object_var_head("PyLong_Type", sign*len(digits))
            if digits:
                ds = ", ".join(map(str, digits))
                self.write(f".ob_digit = {{ {ds} }},")

    def generate_int(self, name: str, i: int) -> str:
        """Return a C reference to the int *i*.

        Values from -5 to 256 refer to the small-int table and write
        nothing.  Other values are defined under a name derived from the
        value itself, so the *name* argument is not used for them.
        """
        if -5 <= i <= 256:
            return f"(PyObject *)&_PyLong_SMALL_INTS[_PY_NSMALLNEGINTS + {i}]"
        if i >= 0:
            name = f"const_int_{i}"
        else:
            name = f"const_int_negative_{abs(i)}"
        if abs(i) < 2**15:
            # Fits in one digit for both digit sizes handled below, so a
            # single definition is enough.
            self._generate_int_for_bits(name, i, 2**15)
        else:
            connective = "if"
            for bits_in_digit in 15, 30:
                self.write(f"#{connective} PYLONG_BITS_IN_DIGIT == {bits_in_digit}")
                self._generate_int_for_bits(name, i, 2**bits_in_digit)
                connective = "elif"
            self.write("#else")
            self.write('#error "PYLONG_BITS_IN_DIGIT should be 15 or 30"')
            self.write("#endif")
            # If neither clause applies, it won't compile
        return f"& {name}.ob_base.ob_base"

    def generate_float(self, name: str, x: float) -> str:
        """Define the float *x* as a static PyFloatObject called *name*."""
        with self.block(f"static PyFloatObject {name} =", ";"):
            self.object_head("PyFloat_Type")
            # NOTE(review): a non-finite x is written as "inf" or "nan",
            # which is presumably not valid C - confirm it cannot occur.
            self.write(f".ob_fval = {x},")
        return f"&{name}.ob_base"

    def generate_complex(self, name: str, z: complex) -> str:
        """Define the complex *z* as a static PyComplexObject called *name*."""
        with self.block(f"static PyComplexObject {name} =", ";"):
            self.object_head("PyComplex_Type")
            self.write(f".cval = {{ {z.real}, {z.imag} }},")
        return f"&{name}.ob_base"

    def generate_frozenset(self, name: str, fs: FrozenSet[object]) -> str:
        """Emit the frozenset *fs* as a tuple of its sorted elements.

        The elements are sorted so that the output does not depend on set
        iteration order.  See the TODO below: the result is a tuple.
        """
        try:
            ordered = sorted(fs)
        except TypeError:
            # frozen set with incompatible types, fallback to repr()
            ordered = sorted(fs, key=repr)
        ret = self.generate_tuple(name, tuple(ordered))
        self.write("// TODO: The above tuple should be a frozenset")
        return ret

    def generate_file(self, module: str, code: object)-> None:
        """Write everything needed for one module.

        Generates *code* as ``<module>_toplevel`` (dots in *module* become
        underscores), then the ``<module>_do_patchups()`` function holding
        the patch-ups queued so far, then the EPILOGUE text.
        """
        module = module.replace(".", "_")
        self.generate(f"{module}_toplevel", code)
        with self.block(f"static void {module}_do_patchups(void)"):
            for p in self.patchups:
                self.write(p)
        # The patch-ups belong to this module only.
        self.patchups.clear()
        self.write(EPILOGUE.replace("%%NAME%%", module))

    def generate(self, name: str, obj: object) -> str:
        """Return a C reference to *obj*, defining it on first use.

        Results are memoized, so an equal constant of the same type and
        repr is defined only once.  True, False, None and Ellipsis are
        returned directly and never cached.

        Raises:
            TypeError: if *obj* is of a type that cannot be generated.
        """
        # Use repr() in the key to distinguish -0.0 from +0.0
        key = (type(obj), obj, repr(obj))
        if key in self.cache:
            self.hits += 1
            # print(f"Cache hit {key!r:.40}: {self.cache[key]!r:.40}")
            return self.cache[key]
        self.misses += 1
        if isinstance(obj, (types.CodeType, umarshal.Code)) :
            val = self.generate_code(name, obj)
        elif isinstance(obj, tuple):
            val = self.generate_tuple(name, obj)
        elif isinstance(obj, str):
            val = self.generate_unicode(name, obj)
        elif isinstance(obj, bytes):
            val = self.generate_bytes(name, obj)
        # The bool checks must come before the int check: bools are ints.
        elif obj is True:
            return "Py_True"
        elif obj is False:
            return "Py_False"
        elif isinstance(obj, int):
            val = self.generate_int(name, obj)
        elif isinstance(obj, float):
            val = self.generate_float(name, obj)
        elif isinstance(obj, complex):
            val = self.generate_complex(name, obj)
        elif isinstance(obj, frozenset):
            val = self.generate_frozenset(name, obj)
        elif obj is builtins.Ellipsis:
            return "Py_Ellipsis"
        elif obj is None:
            return "Py_None"
        else:
            raise TypeError(
                f"Cannot generate code for {type(obj).__name__} object")
        # print(f"Cache store {key!r:.40}: {val!r:.40}")
        self.cache[key] = val
        return val
428
429
# Written after each module's objects by Printer.generate_file(), which
# replaces every %%NAME%% placeholder with the module name.
EPILOGUE = """
PyObject *
_Py_get_%%NAME%%_toplevel(void)
{
    %%NAME%%_do_patchups();
    return Py_NewRef((PyObject *) &%%NAME%%_toplevel);
}
"""

# Leading comments that is_frozen_header() looks for at the start of an
# input file.
FROZEN_COMMENT_C = "/* Auto-generated by Programs/_freeze_module.c */"
FROZEN_COMMENT_PY = "/* Auto-generated by Programs/_freeze_module.py */"

# Matches, from the start of a line, one or more comma-terminated decimal
# numbers; decode_frozen_data() uses it to find the first and last data lines.
FROZEN_DATA_LINE = r"\s*(\d+,\s*)+\s*"
443
444
def is_frozen_header(source: str) -> bool:
    """Return True if *source* begins with one of the frozen-file comments."""
    markers = (FROZEN_COMMENT_C, FROZEN_COMMENT_PY)
    return any(source.startswith(marker) for marker in markers)
447
448
def decode_frozen_data(source: str) -> types.CodeType:
    """Extract and unmarshal the byte values embedded in *source*.

    Lines before the first and after the last line that starts with data
    (per FROZEN_DATA_LINE) are dropped; the rest is evaluated as a
    comma-separated sequence of integers and handed to umarshal.loads().
    """
    lines = source.splitlines()
    first, stop = 0, len(lines)
    # Skip leading lines that do not start with data ...
    while first < stop and re.match(FROZEN_DATA_LINE, lines[first]) is None:
        first += 1
    # ... and trailing ones.
    while stop > first and re.match(FROZEN_DATA_LINE, lines[stop - 1]) is None:
        stop -= 1
    literal = "".join(lines[first:stop]).strip()
    values: Tuple[int, ...] = ast.literal_eval(literal)
    return umarshal.loads(bytes(values))
458
459
def generate(args: list[str], output: TextIO) -> None:
    """Deep-freeze the modules named in *args* and write the C code to *output*.

    Each entry of *args* has the form ``file:modname``.  A file that starts
    with a frozen-file header is decoded from its embedded data; any other
    file is compiled as Python source.  After all modules, the
    ``_Py_Deepfreeze_Fini()`` and ``_Py_Deepfreeze_Init()`` functions are
    written from the statements the printer collected.

    Raises:
        ValueError: if an entry of *args* contains no ``:``.
    """
    printer = Printer(output)
    for arg in args:
        file, modname = arg.rsplit(':', 1)
        with open(file, "r", encoding="utf8") as fd:
            source = fd.read()
        if is_frozen_header(source):
            code = decode_frozen_data(source)
        else:
            # Compile the text already read.  Reading the file object again
            # here would return "" (it is at end of file) and silently
            # freeze an empty module.
            code = compile(source, f"<frozen {modname}>", "exec")
        printer.generate_file(modname, code)
    with printer.block("void\n_Py_Deepfreeze_Fini(void)"):
        for p in printer.deallocs:
            printer.write(p)
    with printer.block("int\n_Py_Deepfreeze_Init(void)"):
        for p in printer.interns:
            with printer.block(f"if ({p} < 0)"):
                printer.write("return -1;")
        printer.write("return 0;")
    if verbose:
        print(f"Cache hits: {printer.hits}, misses: {printer.misses}")
481
482
# Command-line interface; main() calls parser.parse_args().
parser = argparse.ArgumentParser()
parser.add_argument("-o", "--output", help="Defaults to deepfreeze.c", default="deepfreeze.c")
parser.add_argument("-v", "--verbose", action="store_true", help="Print diagnostics")
# One or more positional "file:modname" entries, passed on to generate().
parser.add_argument('args', nargs="+", help="Input file and module name (required) in file:modname format")
487
@contextlib.contextmanager
def report_time(label: str) -> Iterator[None]:
    """Time the ``with`` body and, in verbose mode, print the elapsed seconds.

    Nothing is printed when the body raises; the exception propagates.
    """
    # perf_counter() is a monotonic clock, so the interval cannot be skewed
    # by system clock adjustments the way a time.time() difference can.
    t0 = time.perf_counter()
    yield
    elapsed = time.perf_counter() - t0
    if verbose:
        print(f"{label}: {elapsed:.3f} sec")
497
498
def main() -> None:
    """Parse the command line and write the deep-frozen C file.

    Also sets the module-level ``verbose`` flag from the -v option.
    """
    global verbose
    options = parser.parse_args()
    verbose = options.verbose
    out_path = options.output
    with open(out_path, "w", encoding="utf-8") as file, report_time("generate"):
        generate(options.args, file)
    if verbose:
        size = os.path.getsize(out_path)
        print(f"Wrote {size} bytes to {out_path}")
509
510
# Run main() only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
513