# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import importlib.resources
import os
import re
import shutil
import subprocess

import tempfile

from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Sequence

# If this environment variable is set to a non-empty value other than "0",
# save the flatc input files when serialization fails.
_SAVE_FLATC_ENV: str = "ET_EXIR_SAVE_FLATC_INPUTS_ON_FAILURE"


def _is_valid_alignment(alignment: int) -> bool:
    """Returns True if the alignment is valid, or is None."""
    if alignment is None:
        return True
    return alignment > 0 and (alignment & (alignment - 1)) == 0


def _patch_schema_alignment(
    schema: bytes,
    constant_tensor_alignment: Optional[int],
    delegate_alignment: Optional[int],
) -> bytes:
    """Modifies annotated "force_align" values in a flatbuffer schema.

    Args:
        schema: The flatbuffer schema to modify.
        constant_tensor_alignment: If provided, the alignment to use for lines
            annotated with "@executorch-tensor-alignment". If not provided,
            does not patch tensor alignment.
        delegate_alignment: If provided, the alignment to use for lines
            annotated with "@executorch-delegate-alignment". If not provided,
            does not patch delegate alignment.

    Returns:
        The possibly-modified flatbuffer schema.
    """

    def assert_valid_alignment(alignment: Optional[int], name: str) -> None:
        if not (alignment is None or _is_valid_alignment(alignment)):
            raise ValueError(f"Bad {name} {alignment}")

    assert_valid_alignment(constant_tensor_alignment, "constant_tensor_alignment")
    assert_valid_alignment(delegate_alignment, "delegate_alignment")

    def patch_alignment(line: bytes, alignment: int) -> bytes:
        """Replaces an existing alignment with a new alignment."""
        return re.sub(
            rb"\(\s*force_align\s*:\s*\d+\s*\)",
            f"(force_align: {alignment})".encode("utf-8"),
            line,
        )

    lines = []
    for line in schema.splitlines():
        if constant_tensor_alignment and b"@executorch-tensor-alignment" in line:
            lines.append(patch_alignment(line, constant_tensor_alignment))
        elif delegate_alignment and b"@executorch-delegate-alignment" in line:
            lines.append(patch_alignment(line, delegate_alignment))
        else:
            lines.append(line)
    return b"\n".join(lines)


class _SchemaMaxAlignmentGetter:
    """Finds the largest (force_align: N) N value in flatbuffer schemas."""

    def __init__(self) -> None:
        self.max_alignment: int = 0

    def __call__(self, schema: bytes) -> bytes:
        """Finds all `(force_align: N)` instances and updates max_alignment.

        Returns the input schema unmodified.
        """
        regex = re.compile(rb"\(\s*force_align\s*:\s*(\d+)\s*\)")
        matches = regex.findall(schema)
        for alignment in [int(match) for match in matches]:
            if alignment > self.max_alignment:
                self.max_alignment = alignment
        return schema
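

# Illustrative example (hypothetical schema fragment, not the real program.fbs;
# never invoked by the serialization flow): shows how _patch_schema_alignment
# rewrites only the "force_align" value on an annotated line.
def _example_patch_tensor_alignment() -> bytes:
    schema = b"data: [ubyte] (force_align: 16);  // @executorch-tensor-alignment"
    # The result replaces "(force_align: 16)" with "(force_align: 64)" and
    # leaves the annotation comment untouched.
    return _patch_schema_alignment(
        schema, constant_tensor_alignment=64, delegate_alignment=None
    )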


class _ResourceFiles:
    """Manages a collection of Python resources that will be written to files."""

    def __init__(self, resource_names: Sequence[str]) -> None:
        """Loads the resources with the provided names."""
        # Map each name to its contents.
        self._files: Dict[str, bytes] = {}
        for name in resource_names:
            self._files[name] = importlib.resources.read_binary(__package__, name)

    def patch_files(self, patch_fn: Callable[[bytes], bytes]) -> None:
        """Uses the provided patching function to update the contents of all
        files. `patch_fn` takes the current contents of a file as input and
        returns the new contents.
        """
        for name in self._files.keys():
            self._files[name] = patch_fn(self._files[name])

    def write_to(self, out_dir: str) -> None:
        """Writes the files to the specified directory. File names are based on
        the original resource names.
        """
        for name, data in self._files.items():
            with open(os.path.join(out_dir, name), "wb") as fp:
                fp.write(data)


@dataclass
class _SchemaInfo:
    # Path to a file containing the root schema. Other included schema files may
    # be present in the same directory.
    root_path: str

    # An alignment value that can satisfy all "force_align" entries found in the
    # schema files.
    max_alignment: int


def _prepare_schema(
    out_dir: str,
    constant_tensor_alignment: Optional[int] = None,
    delegate_alignment: Optional[int] = None,
) -> _SchemaInfo:
    """Returns the path to the program schema file after copying it and its deps
    into out_dir. May patch the schema contents depending on the parameters to
    this function.
    """
    program_schema = "program.fbs"
    # Included by the root program schema; must also be present.
    deps = ["scalar_type.fbs"]

    schemas = _ResourceFiles([program_schema] + deps)

    # Update annotated alignments in the schema files.
    schemas.patch_files(
        lambda data: _patch_schema_alignment(
            schema=data,
            constant_tensor_alignment=constant_tensor_alignment,
            delegate_alignment=delegate_alignment,
        ),
    )
    # Find the largest alignment used in the patched schema files.
    get_alignments = _SchemaMaxAlignmentGetter()
    schemas.patch_files(get_alignments)

    # Write the patched schema files to the filesystem.
    schemas.write_to(out_dir)

    return _SchemaInfo(
        root_path=os.path.join(out_dir, program_schema),
        max_alignment=get_alignments.max_alignment,
    )


@dataclass
class _FlatbufferResult:
    # Serialized flatbuffer data.
    data: bytes

    # The maximum "force_align" value from the schema used to serialize the data.
    max_alignment: int


# Name of an optional resource containing the `flatc` executable.
_FLATC_RESOURCE_NAME: str = "flatbuffers-flatc"


def _run_flatc(args: Sequence[str]) -> None:
    """Runs the `flatc` command with the provided args.

    If a resource matching _FLATC_RESOURCE_NAME exists, uses that executable.
    Otherwise, uses the binary named by the FLATC_EXECUTABLE environment
    variable if set, or expects the `flatc` tool to be on the system PATH.
    """
    if importlib.resources.is_resource(__package__, _FLATC_RESOURCE_NAME):
        # Use the provided flatc binary.
        with importlib.resources.path(__package__, _FLATC_RESOURCE_NAME) as flatc_path:
            subprocess.run([flatc_path] + list(args), check=True)
    else:
        # Expect the `flatc` tool to be on the system path or set via the
        # FLATC_EXECUTABLE environment variable.
        flatc_path = os.getenv("FLATC_EXECUTABLE")
        if not flatc_path:
            flatc_path = "flatc"
        subprocess.run([flatc_path] + list(args), check=True)
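

# For reference, _flatc_compile and _flatc_decompile below build flatc command
# lines equivalent to the following (directory and file names are illustrative
# placeholders, not fixed by this module):
#
#   flatc --binary -o <output_dir> program.fbs data.json
#   flatc --json --defaults-json --strict-json -o <output_dir> program.fbs -- data.bin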


def _flatc_compile(output_dir: str, schema_path: str, json_path: str) -> None:
    """Serializes JSON data to a binary flatbuffer file.

    Args:
        output_dir: Directory under which to create the binary flatbuffer file.
        schema_path: Path to the flatbuffer schema to use for serialization.
            If the schema includes other schema files, they must be present in
            the same directory.
        json_path: Path to the data to serialize, as JSON data whose structure
            matches the schema.
    """
    _run_flatc(
        [
            "--binary",
            "-o",
            output_dir,
            schema_path,
            json_path,
        ]
    )


def _flatc_decompile(
    output_dir: str,
    schema_path: str,
    bin_path: str,
    flatc_additional_args: Optional[List[str]] = None,
) -> None:
    """Deserializes binary flatbuffer data to a JSON file.

    Args:
        output_dir: Directory under which to create the JSON file.
        schema_path: Path to the flatbuffer schema to use for deserialization.
            If the schema includes other schema files, they must be present in
            the same directory.
        bin_path: Path to the data to deserialize, as binary data compatible
            with the schema.
        flatc_additional_args: Optional extra arguments to pass to flatc before
            the standard JSON-generation arguments.
    """
    flatc_additional_args = flatc_additional_args if flatc_additional_args else []
    _run_flatc(
        flatc_additional_args
        + [
            "--json",
            "--defaults-json",
            "--strict-json",
            "-o",
            output_dir,
            schema_path,
            "--",
            bin_path,
        ]
    )


def _program_json_to_flatbuffer(
    program_json: str,
    *,
    constant_tensor_alignment: Optional[int] = None,
    delegate_alignment: Optional[int] = None,
) -> _FlatbufferResult:
    """Converts Program-compatible JSON into binary flatbuffer data.

    Args:
        program_json: The JSON to convert. Must be compatible with the root
            table type of //executorch/schema/program.fbs.
        constant_tensor_alignment: If provided, the alignment to use for tensor
            data embedded in the output flatbuffer data. If not provided, uses
            the alignment in the schema.
        delegate_alignment: If provided, the alignment to use for delegate
            data embedded in the output flatbuffer data. If not provided, uses
            the alignment in the schema.

    Returns: The flatbuffer data and associated metadata.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        schema_info = _prepare_schema(
            out_dir=temp_dir,
            constant_tensor_alignment=constant_tensor_alignment,
            delegate_alignment=delegate_alignment,
        )
        file_stem = "data"
        json_path = os.path.join(temp_dir, file_stem + ".json")
        output_path = os.path.join(temp_dir, file_stem + ".pte")

        with open(json_path, "wb") as json_file:
            json_file.write(program_json.encode("ascii"))

        try:
            _flatc_compile(temp_dir, schema_info.root_path, json_path)
        except Exception as err:
            # It's helpful to save the breaking files for debugging. Optionally
            # move them out of the auto-deleting temporary directory. Don't do
            # this by default because some input files can be many GB in size,
            # and these copies won't be auto-deleted.
            should_save = os.getenv(_SAVE_FLATC_ENV, "").strip() not in {"", "0"}
            extra_message = ""
            if should_save:
                try:
                    saved_dir = tempfile.mkdtemp(prefix="exir-saved-flatc-")
                    for f in os.listdir(temp_dir):
                        shutil.move(src=os.path.join(temp_dir, f), dst=saved_dir)
                    extra_message += f" Moved input files to '{saved_dir}'."
                except Exception as err2:
                    extra_message += (
                        f" (Failed to save input files for debugging: {err2})"
                    )
            else:
                extra_message += (
                    f" Set {_SAVE_FLATC_ENV}=1 to save input files on failure."
                )

            raise RuntimeError(
                f"Failed to compile {json_path} to {output_path}." + extra_message
            ) from err
        with open(output_path, "rb") as output_file:
            return _FlatbufferResult(
                data=output_file.read(), max_alignment=schema_info.max_alignment
            )
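

# Illustrative usage sketch (`program_json` is assumed to be a Program-compatible
# JSON string; the alignment values are arbitrary examples):
#
#   result = _program_json_to_flatbuffer(
#       program_json, constant_tensor_alignment=16, delegate_alignment=256
#   )
#   result.data           # serialized flatbuffer bytes
#   result.max_alignment  # largest force_align seen in the patched schema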


def _replace_infinity_in_json_file(content: bytes) -> bytes:
    """Replaces bare -inf and inf with the strings "-inf" and "inf" in the JSON
    file. program.fbs is used to convert from flatbuffer to JSON. +/-inf float
    values are not supported by JSON, so we replace them with the string
    equivalent. When converting from JSON to python dataclasses, the string is
    read as a Union of float and string (see schema.py).
    """
    content = re.sub(
        rb'"double_val"\s*:\s*(-)?inf', rb'"double_val": "\g<1>inf"', content
    )
    return content


def _program_flatbuffer_to_json(program_flatbuffer: bytes) -> bytes:
    """Converts binary flatbuffer data into Program-compatible JSON.

    The binary is parsed using the schema in //executorch/schema/program.fbs.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        # No need to patch the alignment when reading. "force_align" is only
        # used during serialization.
        schema_info = _prepare_schema(temp_dir)
        file_stem = "data"
        bin_path = os.path.join(temp_dir, file_stem + ".bin")
        json_path = os.path.join(temp_dir, file_stem + ".json")

        with open(bin_path, "wb") as bin_file:
            bin_file.write(program_flatbuffer)

        _flatc_decompile(temp_dir, schema_info.root_path, bin_path)
        with open(json_path, "rb") as output_file:
            json_data = output_file.read()
        return _replace_infinity_in_json_file(json_data)
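

# Illustrative example (hypothetical JSON fragment, not real flatc output):
# shows how _replace_infinity_in_json_file quotes bare infinities so the
# result is standard JSON.
def _example_replace_infinity() -> bytes:
    # Returns b'{"double_val": "-inf"}'.
    return _replace_infinity_in_json_file(b'{"double_val": -inf}')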