xref: /aosp_15_r20/external/pytorch/torch/utils/_freeze.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-decorators
2# mypy: allow-untyped-defs
3"""
4Freeze Python packages.
5
6Freezing makes it possible to ship arbitrary Python modules as part of a C++
7library. The Python source of the module is compiled to bytecode and written
8to `.c` files, to be imported by Python's built-in FrozenImporter.
9
10In a normal Python installation, FrozenImporter is only used to bootstrap the
11initialization of the import machinery. Python's importers are defined in
12Python (see `_bootstrap.py` and `_bootstrap_external.py`) but need to be
13retrieved before any importers are available. Freezing the module bytecode
14resolves this circular dependency.
15
16This script will freeze the Python standard library. It produces two things:
17- Bytecode files: A set of `.c` files that define C variables containing Python bytecode.
18- Main file: A `main.c` file listing all of these modules in the right form to be
19  consumed by FrozenImporter.
20
21The library that wishes to use these modules makes them available to the local
22Python instance by extending `PyImport_FrozenModules` appropriately (see
23https://docs.python.org/3/c-api/import.html#c.PyImport_FrozenModules).
24"""
25
26import argparse
27import functools
28import itertools
29import marshal
30import os
31import types
32from dataclasses import dataclass
33from pathlib import Path
34from typing import List
35
36
# Filename recorded in every compiled code object in place of the real
# build-time path (see `Freezer.compile_string`).
PATH_MARKER = "<Generated by torch::deploy>"

# Preamble written at the top of the generated `main.c`.
MAIN_INCLUDES = "#include <Python.h>\n\n"

# Opening of a frozen-module table. The `{}` placeholder receives the name of
# the C array symbol; the literal brace is doubled to survive `str.format`.
MAIN_PREFIX_TEMPLATE = (
    "\n"
    "// Compiled standard library modules. These should be appended to the existing\n"
    "// `PyImport_FrozenModules` that ships with CPython.\n"
    "struct _frozen {}[] = {{\n"
)

# Table opening for the fake `_PyImport_FrozenModules` array that
# `Freezer.write_main` emits when its `oss` flag is set.
FAKE_PREFIX = MAIN_PREFIX_TEMPLATE.format("_PyImport_FrozenModules")

# Sentinel row plus closing brace that terminate a frozen-module table.
MAIN_SUFFIX = "    {0, 0, 0} /* sentinel */\n" + "};\n"
54
# Exclude some standard library modules to:
# 1. Slim down the final frozen lib.
# 2. Remove functionality we don't want to support.
#
# Entries are compared against the last path component (`Path.name`) of every
# directory and file the freezer visits, so a bare name excludes a whole
# directory and a `*.py` name excludes a single source file.
DENY_LIST = [
    # Interface to unix databases
    "dbm",
    # ncurses bindings (terminal interfaces)
    "curses",
    # Tcl/Tk GUI
    "tkinter",
    # Tests for the standard library
    "test",
    "tests",
    "idle_test",
    "__phello__.foo.py",
    # importlib frozen modules. These are already baked into CPython.
    "_bootstrap.py",
    "_bootstrap_external.py",
]

# Number of `bytecode_<i>.c` files the frozen modules are spread across.
NUM_BYTECODE_FILES = 5
77
78
def indent_msg(fn):
    """Decorate a method so that it runs with its owner's `indent` raised by one.

    The first positional argument of the wrapped callable (normally `self`)
    must expose a mutable integer attribute named `indent`. The attribute is
    incremented before `fn` runs and decremented afterwards.

    Args:
        fn: The callable to wrap.

    Returns:
        A wrapper with the same signature and metadata as `fn` that returns
        whatever `fn` returns.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        owner = args[0]
        owner.indent += 1
        try:
            return fn(*args, **kwargs)
        finally:
            # Restore the level even when `fn` raises; otherwise one failure
            # would leave every later message over-indented.
            owner.indent -= 1

    return wrapper
88
89
@dataclass
class FrozenModule:
    """Record describing one module that has been compiled to frozen bytecode."""

    # Dotted, fully qualified import name (for example 'foo.bar.baz').
    module_name: str
    # Identifier of the C array that stores the bytecode
    # (for example 'M_foo__bar__baz').
    c_name: str
    # Length of the C array; stored as a negative number when the module is a
    # package.
    size: int
    # Marshalled code object of the module.
    bytecode: bytes
100
101
class Freezer:
    """Compile Python sources to bytecode and emit them as C source files.

    Typical use: call `compile_path` for every input path to populate
    `frozen_modules`, then call `write_bytecode` and `write_main` to produce
    the `.c` output files.
    """

    def __init__(self, verbose: bool):
        # Every module compiled so far, in visiting order.
        self.frozen_modules: List[FrozenModule] = []
        # Current nesting depth of the verbose log; adjusted by `@indent_msg`.
        self.indent: int = 0
        # When False, `msg` prints nothing.
        self.verbose: bool = verbose

    def msg(self, path: Path, code: str):
        """Print a one-line, indented status message for `path` if verbose.

        `code` is a single-letter tag:
            P: package dir
            F: python file
            S: skipped (not a package dir)
            X: skipped (deny-listed)
            N: skipped (not a python file)
        """
        if not self.verbose:
            return
        print(f"{'    ' * self.indent}{code} {path}")

    def write_bytecode(self, install_root):
        """
        Write the `.c` files containing the frozen bytecode.

        The frozen modules are distributed round-robin across
        `NUM_BYTECODE_FILES` files named `bytecode_<i>.c` inside
        `install_root`.
        """
        bytecode_files = []
        try:
            for i in range(NUM_BYTECODE_FILES):
                bytecode_files.append(
                    open(os.path.join(install_root, f"bytecode_{i}.c"), "w")
                )
            targets = itertools.cycle(bytecode_files)
            for m in self.frozen_modules:
                self.write_frozen(m, next(targets))
        finally:
            # Close whatever was opened, even if a later open() or a write
            # failed part-way through.
            for f in bytecode_files:
                f.close()

    def write_main(self, install_root, oss, symbol_name):
        """Write the `main.c` file containing a table enumerating all the frozen modules.

        Args:
            install_root: Directory in which `main.c` is created.
            oss: When true, an additional `_PyImport_FrozenModules` table
                holding only the sentinel row is appended.
            symbol_name: Name of the C array that lists the frozen modules.
        """
        with open(os.path.join(install_root, "main.c"), "w") as outfp:
            outfp.write(MAIN_INCLUDES)
            # Declarations of the arrays defined in the bytecode files.
            outfp.writelines(
                f"extern unsigned char {m.c_name}[];\n" for m in self.frozen_modules
            )

            outfp.write(MAIN_PREFIX_TEMPLATE.format(symbol_name))
            outfp.writelines(
                f'\t{{"{m.module_name}", {m.c_name}, {m.size}}},\n'
                for m in self.frozen_modules
            )
            outfp.write(MAIN_SUFFIX)
            if oss:
                outfp.write(FAKE_PREFIX)
                outfp.write(MAIN_SUFFIX)

    def write_frozen(self, m: FrozenModule, outfp):
        """Write a single frozen module's bytecode out to a C variable.

        The bytes are emitted as decimal literals, 16 per line.
        """
        outfp.write(f"unsigned char {m.c_name}[] = {{")
        for start in range(0, len(m.bytecode), 16):
            row = m.bytecode[start : start + 16]
            # One write per row instead of one per byte.
            outfp.write("\n\t" + "".join(f"{byte}," for byte in row))
        outfp.write("\n};\n")

    def compile_path(self, path: Path, top_package_path: Path):
        """Entry point for compiling a Path object (directory or file)."""
        if path.is_dir():
            self.compile_package(path, top_package_path)
        else:
            self.compile_file(path, top_package_path)

    @indent_msg
    def compile_package(self, path: Path, top_package_path: Path):
        """Compile all the files within a Python package dir.

        Deny-listed directories and directories without an `__init__.py` are
        skipped (and reported via `msg`).
        """
        assert path.is_dir()
        if path.name in DENY_LIST:
            self.msg(path, "X")
            return

        # `iterdir` yields entries in arbitrary order; sort so the generated
        # files are identical from one run to the next.
        children = sorted(path.iterdir())

        # Python packages are directories that have __init__.py in them.
        if not any(child.name == "__init__.py" for child in children):
            self.msg(path, "S")
            return

        self.msg(path, "P")
        # Recursively compile all children in this dir
        for child in children:
            self.compile_path(child, top_package_path)

    def get_module_qualname(self, file_path: Path, top_package_path: Path) -> List[str]:
        """Return the components of the dotted module name for `file_path`.

        The name is derived from `file_path` relative to the parent of
        `top_package_path`; e.g. 'Lib/foo/bar/baz.py' with top package
        'Lib/foo' yields ['foo', 'bar', 'baz'].
        """
        # chop off 'Lib/' to get something that represents a Python module hierarchy.
        # e.g. 'foo/bar/baz.py', which maps to 'foo.bar.baz'
        normalized_path = file_path.relative_to(top_package_path.parent)

        if normalized_path.name == "__init__.py":
            # Special handling for `__init__.py`. In this case, this file
            # specifies that the containing directory should be treated as a package.
            # For 'foo/bar/baz/__init__.py':
            # - The module name is 'baz'
            module_basename = normalized_path.parent.name
            # - The parent is foo.bar (need to shave off the 'baz')
            module_parent = normalized_path.parent.parent.parts
        else:
            module_basename = normalized_path.stem
            module_parent = normalized_path.parent.parts
        return list(module_parent) + [module_basename]

    def compile_string(self, file_content: str) -> types.CodeType:
        """Compile module source text into a code object."""
        # Instead of passing the real build-time path to `compile`, pass a
        # marker. This prevents the build-time path from leaking to runtime,
        # where it may not exist. Using a marker makes it a hard error rather
        # than a flaky one when the inspect module tries to retrieve Python
        # source code during torchscripting.
        return compile(file_content, PATH_MARKER, "exec")

    @indent_msg
    def compile_file(self, path: Path, top_package_path: Path):
        """
        Compile a Python source file to frozen bytecode.

        Append the result to `self.frozen_modules`. Files that are not `.py`
        or that are deny-listed are skipped (and reported via `msg`).
        """
        assert path.is_file()
        if path.suffix != ".py":
            self.msg(path, "N")
            return

        if path.name in DENY_LIST:
            self.msg(path, "X")
            return

        self.msg(path, "F")
        module_qualname = self.get_module_qualname(path, top_package_path)
        module_mangled_name = "__".join(module_qualname)
        c_name = "M_" + module_mangled_name

        # Python source defaults to UTF-8; do not depend on the build
        # machine's locale encoding.
        with open(path, encoding="utf-8") as src_file:
            co = self.compile_string(src_file.read())

        bytecode = marshal.dumps(co)
        size = len(bytecode)
        if path.name == "__init__.py":
            # Python packages are signified by negative size.
            size = -size
        self.frozen_modules.append(
            FrozenModule(".".join(module_qualname), c_name, size, bytecode)
        )
249
250
def main() -> None:
    """Command-line entry point: freeze the given paths into C source files.

    Parses the command line, compiles every requested path with a `Freezer`,
    then writes the bytecode files and `main.c` into the install directory.
    """
    parser = argparse.ArgumentParser(description="Compile py source")
    parser.add_argument("paths", nargs="*", help="Paths to freeze.")
    parser.add_argument("--verbose", action="store_true", help="Print debug logs")
    parser.add_argument(
        "--install-dir",
        "--install_dir",
        # Every run writes output files, so a missing directory can never
        # work; fail at argument parsing instead of with a TypeError later.
        required=True,
        help="Root directory for all output files",
    )
    parser.add_argument(
        "--oss",
        action="store_true",
        help="If it's OSS build, add a fake _PyImport_FrozenModules",
    )
    parser.add_argument(
        "--symbol-name",
        "--symbol_name",
        help="The name of the frozen module array symbol to generate",
        default="_PyImport_FrozenModules_torch",
    )

    args = parser.parse_args()

    f = Freezer(args.verbose)

    for p in args.paths:
        path = Path(p)
        if path.is_dir() and not (path / "__init__.py").exists():
            # this 'top level path p' is a standard directory containing modules,
            # not a module itself
            # each 'mod' could be a dir containing __init__.py or .py file
            # NB: sorted to make sure this is deterministic
            for mod in sorted(path.glob("*")):
                f.compile_path(mod, mod)
        else:
            f.compile_path(path, path)

    f.write_bytecode(args.install_dir)
    f.write_main(args.install_dir, args.oss, args.symbol_name)
288
289
# Run the freezer only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()  # pragma: no cover
292