xref: /aosp_15_r20/external/executorch/exir/_serialize/_flatbuffer.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8
9import importlib.resources
10import os
11import re
12import shutil
13import subprocess
14
15import tempfile
16
17from dataclasses import dataclass
18from typing import Callable, Dict, List, Optional, Sequence
19
# Name of the environment variable that controls whether the flatc input files
# are saved when serialization fails. Saving is enabled when the variable is set
# to any value other than the empty string or "0" (surrounding whitespace is
# ignored).
_SAVE_FLATC_ENV: str = "ET_EXIR_SAVE_FLATC_INPUTS_ON_FAILURE"
23
24
def _is_valid_alignment(alignment: Optional[int]) -> bool:
    """Returns True if the alignment is valid, or is None.

    An alignment is considered valid when it is a positive power of two.

    Args:
        alignment: The alignment value to check. None is accepted and is
            always treated as valid.

    Returns:
        True if `alignment` is None or a positive power of two; False
        otherwise.
    """
    if alignment is None:
        return True
    # A positive power of two has exactly one bit set, so clearing its lowest
    # set bit (x & (x - 1)) must leave zero.
    return alignment > 0 and (alignment & (alignment - 1)) == 0
30
31
def _patch_schema_alignment(
    schema: bytes,
    constant_tensor_alignment: Optional[int],
    delegate_alignment: Optional[int],
) -> bytes:
    """Rewrites annotated "force_align" values in a flatbuffer schema.

    Only lines carrying one of the marker annotations are touched; on those
    lines every `(force_align: N)` attribute is replaced with the requested
    alignment. If a line carries both markers, the tensor marker wins whenever
    a tensor alignment was provided.

    Note that the result is re-assembled from `schema.splitlines()` joined with
    "\\n", so line endings are normalized and a trailing newline is not kept.

    Args:
        schema: The flatbuffer schema to modify.
        constant_tensor_alignment: If provided, the alignment to use for lines
            annotated with "@executorch-tensor-alignment". If not provided,
            does not patch tensor alignment.
        delegate_alignment: If provided, the alignment to use for lines
            annotated with "@executorch-delegate-alignment". If not provided,
            does not patch delegate alignment.

    Returns:
        The possibly-modified flatbuffer schema.

    Raises:
        ValueError: If either alignment is not None and is rejected by
            `_is_valid_alignment`.
    """
    # Validate in declaration order so the first bad argument is the one
    # reported.
    for arg_name, arg_value in (
        ("constant_tensor_alignment", constant_tensor_alignment),
        ("delegate_alignment", delegate_alignment),
    ):
        if arg_value is not None and not _is_valid_alignment(arg_value):
            raise ValueError(f"Bad {arg_name} {arg_value}")

    force_align_attr = re.compile(rb"\(\s*force_align\s*:\s*\d+\s*\)")

    def rewrite_line(text: bytes) -> bytes:
        """Returns `text` with its alignment replaced if it is annotated."""
        if constant_tensor_alignment and b"@executorch-tensor-alignment" in text:
            new_alignment = constant_tensor_alignment
        elif delegate_alignment and b"@executorch-delegate-alignment" in text:
            new_alignment = delegate_alignment
        else:
            # Not annotated (or no alignment requested): leave untouched.
            return text
        return force_align_attr.sub(
            f"(force_align: {new_alignment})".encode("utf-8"),
            text,
        )

    return b"\n".join(rewrite_line(text) for text in schema.splitlines())
76
77
class _SchemaMaxAlignmentGetter:
    """Tracks the largest `(force_align: N)` value seen in flatbuffer schemas.

    Instances are callable with the same shape as a schema patching function
    (bytes in, bytes out), but never modify the schema; they only record the
    maximum alignment across every schema they are called with.
    """

    def __init__(self) -> None:
        # Largest alignment observed so far; remains 0 until a match is seen.
        self.max_alignment: int = 0

    def __call__(self, schema: bytes) -> bytes:
        """Scans `schema` for `(force_align: N)` and updates max_alignment.

        Args:
            schema: The contents of a flatbuffer schema file.

        Returns:
            The input schema, unmodified.
        """
        force_align_attr = re.compile(rb"\(\s*force_align\s*:\s*(\d+)\s*\)")
        found = [int(hit.group(1)) for hit in force_align_attr.finditer(schema)]
        # Include the running maximum so it can only grow.
        self.max_alignment = max([self.max_alignment, *found])
        return schema
95
96
class _ResourceFiles:
    """Holds the contents of named python resources so they can be patched in
    memory and then written out as files.
    """

    def __init__(self, resource_names: Sequence[str]) -> None:
        """Reads each named resource from this package into memory.

        Args:
            resource_names: Names of the resources to load.
        """
        # Resource name -> current (possibly patched) file contents.
        self._files: Dict[str, bytes] = {
            resource: importlib.resources.read_binary(__package__, resource)
            for resource in resource_names
        }

    def patch_files(self, patch_fn: Callable[[bytes], bytes]) -> None:
        """Replaces the contents of every file with `patch_fn(contents)`.

        Args:
            patch_fn: Takes the current contents of a file and returns the new
                contents.
        """
        # Snapshot the items so entries can be reassigned while looping.
        for resource, contents in list(self._files.items()):
            self._files[resource] = patch_fn(contents)

    def write_to(self, out_dir: str) -> None:
        """Writes every file into `out_dir`, named after its resource name.

        Args:
            out_dir: Directory to write the files into.
        """
        for resource in self._files:
            destination = os.path.join(out_dir, resource)
            with open(destination, "wb") as out_file:
                out_file.write(self._files[resource])
122
123
@dataclass
class _SchemaInfo:
    """Describes a set of schema files that were written to the filesystem."""

    # Path to a file containing the root schema. Other included schema files may
    # be present in the same directory.
    root_path: str

    # An alignment value that can satisfy all "force_align" entries found in the
    # schema files.
    max_alignment: int
133
134
def _prepare_schema(
    out_dir: str,
    constant_tensor_alignment: Optional[int] = None,
    delegate_alignment: Optional[int] = None,
) -> _SchemaInfo:
    """Copies the program schema and its dependencies into `out_dir`, patching
    alignments on the way if requested.

    Args:
        out_dir: Directory to write the schema files into.
        constant_tensor_alignment: If provided, passed to
            `_patch_schema_alignment` to override the tensor alignment.
        delegate_alignment: If provided, passed to `_patch_schema_alignment`
            to override the delegate alignment.

    Returns:
        The path of the root schema file inside `out_dir`, together with the
        largest "force_align" value found in the (patched) schema files.
    """
    root_schema_name = "program.fbs"
    # Included by the root program schema; must also be present.
    included_schema_names = ["scalar_type.fbs"]

    schema_files = _ResourceFiles([root_schema_name, *included_schema_names])

    def apply_alignment(contents: bytes) -> bytes:
        """Applies the requested alignment overrides to one schema file."""
        return _patch_schema_alignment(
            schema=contents,
            constant_tensor_alignment=constant_tensor_alignment,
            delegate_alignment=delegate_alignment,
        )

    # First pass: update annotated alignments in the schema files.
    schema_files.patch_files(apply_alignment)

    # Second pass: a read-only scan for the largest alignment, run after
    # patching so it reflects the values that will actually be written.
    alignment_scanner = _SchemaMaxAlignmentGetter()
    schema_files.patch_files(alignment_scanner)

    # Write the patched schema files to the filesystem.
    schema_files.write_to(out_dir)

    root_path = os.path.join(out_dir, root_schema_name)
    return _SchemaInfo(
        root_path=root_path,
        max_alignment=alignment_scanner.max_alignment,
    )
169
170
@dataclass
class _FlatbufferResult:
    """Serialized flatbuffer data together with metadata about its schema."""

    # Serialized flatbuffer data.
    data: bytes

    # The maximum "force_align" value from the schema used to serialize the data.
    max_alignment: int
178
179
# Name of an optional package resource containing the `flatc` executable. When
# a resource with this name exists, it is run in preference to any other
# `flatc`.
_FLATC_RESOURCE_NAME: str = "flatbuffers-flatc"
182
183
def _run_flatc(args: Sequence[str]) -> None:
    """Runs the `flatc` command with the provided args.

    The executable is chosen in this order:
      1. The package resource named _FLATC_RESOURCE_NAME, if it exists.
      2. The value of the FLATC_EXECUTABLE environment variable, if it is set
         to a non-empty value.
      3. Plain `flatc`, which is then expected to be on the system path.

    Args:
        args: Command-line arguments to pass to `flatc`.

    Raises:
        subprocess.CalledProcessError: If `flatc` exits with a non-zero status.
    """
    if not importlib.resources.is_resource(__package__, _FLATC_RESOURCE_NAME):
        # No bundled binary: fall back to the env var, then the system path.
        executable = os.getenv("FLATC_EXECUTABLE") or "flatc"
        subprocess.run([executable] + list(args), check=True)
        return

    # Use the provided flatc binary. The path is only guaranteed to exist
    # inside the context manager, so run the command within it.
    with importlib.resources.path(__package__, _FLATC_RESOURCE_NAME) as bundled_flatc:
        subprocess.run([bundled_flatc] + list(args), check=True)
200
201
def _flatc_compile(output_dir: str, schema_path: str, json_path: str) -> None:
    """Serializes JSON data to a binary flatbuffer file using `flatc`.

    Args:
        output_dir: Directory under which to create the binary flatbuffer file.
        schema_path: Path to the flatbuffer schema to use for serialization.
            If the schema includes other schema files, they must be present in
            the same directory.
        json_path: Path to the data to serialize, as JSON data whose structure
            matches the schema.
    """
    compile_args = ["--binary", "-o", output_dir, schema_path, json_path]
    _run_flatc(compile_args)
222
223
def _flatc_decompile(
    output_dir: str,
    schema_path: str,
    bin_path: str,
    flatc_additional_args: Optional[List[str]] = None,
) -> None:
    """Deserializes binary flatbuffer data to a JSON file using `flatc`.

    Args:
        output_dir: Directory under which to create the JSON file.
        schema_path: Path to the flatbuffer schema to use for deserialization.
            If the schema includes other schema files, they must be present in
            the same directory.
        bin_path: Path to the data to deserialize, as binary data compatible
            with the schema.
        flatc_additional_args: Optional extra arguments for `flatc`. They are
            placed before the arguments this function always passes.
    """
    leading_args = flatc_additional_args or []
    decompile_args = [
        "--json",
        "--defaults-json",
        "--strict-json",
        "-o",
        output_dir,
        schema_path,
        "--",
        bin_path,
    ]
    _run_flatc(leading_args + decompile_args)
254
255
def _program_json_to_flatbuffer(
    program_json: str,
    *,
    constant_tensor_alignment: Optional[int] = None,
    delegate_alignment: Optional[int] = None,
) -> _FlatbufferResult:
    """Converts Program-compatible JSON into binary flatbuffer data.

    The schema files and the JSON are written to a temporary directory and
    compiled there with `flatc`.

    Args:
        program_json: The JSON to convert. Must be compatible with the root
            table type of //executorch/schema/program.fbs.
        constant_tensor_alignment: If provided, the alignment to use for tensor
            data embedded in the output flatbuffer data. If not provided, uses
            the alignment in the schema.
        delegate_alignment: If provided, the alignment to use for delegate
            data embedded in the output flatbuffer data. If not provided, uses
            the alignment in the schema.

    Returns: The flatbuffer data and associated metadata.

    Raises:
        RuntimeError: If compiling the JSON fails. The original exception is
            attached as the cause.
    """
    with tempfile.TemporaryDirectory() as work_dir:
        schema_info = _prepare_schema(
            out_dir=work_dir,
            constant_tensor_alignment=constant_tensor_alignment,
            delegate_alignment=delegate_alignment,
        )
        stem = "data"
        json_path = os.path.join(work_dir, f"{stem}.json")
        output_path = os.path.join(work_dir, f"{stem}.pte")

        with open(json_path, "wb") as json_file:
            json_file.write(program_json.encode("ascii"))

        try:
            _flatc_compile(work_dir, schema_info.root_path, json_path)
        except Exception as err:
            # It's helpful to save the breaking files for debugging, so they
            # can optionally be moved out of the auto-deleting temporary
            # directory. This is opt-in because some input files can be many
            # GB in size, and the saved copies won't be auto-deleted.
            env_value = os.getenv(_SAVE_FLATC_ENV, "").strip()
            if env_value in {"", "0"}:
                hint = f" Set {_SAVE_FLATC_ENV}=1 to save input files on failure."
            else:
                try:
                    saved_dir = tempfile.mkdtemp(prefix="exir-saved-flatc-")
                    for entry in os.listdir(work_dir):
                        shutil.move(src=os.path.join(work_dir, entry), dst=saved_dir)
                    hint = f" Moved input files to '{saved_dir}'."
                except Exception as save_err:
                    # Saving is best-effort; report the problem but still
                    # raise the compile failure below.
                    hint = f" (Failed to save input files for debugging: {save_err})"

            raise RuntimeError(
                f"Failed to compile {json_path} to {output_path}." + hint
            ) from err

        with open(output_path, "rb") as output_file:
            flatbuffer_bytes = output_file.read()
        return _FlatbufferResult(
            data=flatbuffer_bytes, max_alignment=schema_info.max_alignment
        )
320
321
def _replace_infinity_in_json_file(content: bytes) -> bytes:
    """Quotes bare infinity values of "double_val" fields in JSON content.

    program.fbs is used to convert from flatbuffer to JSON. +-inf float values
    are not supported by JSON, so an unquoted `inf` or `-inf` that follows a
    "double_val" key is replaced with the string "inf" or "-inf". When
    converting from JSON to python dataclasses, the string is read as a Union
    of float and string (see schema.py).

    Args:
        content: The JSON content to patch.

    Returns:
        The content with bare infinities quoted. Whitespace around the colon
        of each rewritten field is normalized to `"double_val": `.
    """
    bare_infinity = re.compile(rb'"double_val"\s*:\s*(-)?inf')

    def quote_infinity(hit: "re.Match[bytes]") -> bytes:
        # Group 1 is the optional minus sign; it is None for positive inf.
        sign = hit.group(1) or b""
        return b'"double_val": "' + sign + b'inf"'

    return bare_infinity.sub(quote_infinity, content)
333
334
def _program_flatbuffer_to_json(program_flatbuffer: bytes) -> bytes:
    """Converts binary flatbuffer data into Program-compatible JSON.

    The binary is parsed using the schema in //executorch/schema/program.fbs.

    Args:
        program_flatbuffer: The binary flatbuffer data to convert.

    Returns:
        The JSON produced by `flatc`, after bare infinity values have been
        quoted by `_replace_infinity_in_json_file`.
    """
    with tempfile.TemporaryDirectory() as work_dir:
        # No need to patch the alignment when reading. "force_align" is only
        # used during serialization.
        schema_info = _prepare_schema(work_dir)

        stem = "data"
        bin_path = os.path.join(work_dir, f"{stem}.bin")
        json_path = os.path.join(work_dir, f"{stem}.json")

        with open(bin_path, "wb") as bin_file:
            bin_file.write(program_flatbuffer)

        _flatc_decompile(work_dir, schema_info.root_path, bin_path)

        with open(json_path, "rb") as json_file:
            raw_json = json_file.read()
        return _replace_infinity_in_json_file(raw_json)
355