# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import copy
import json
import re

from dataclasses import dataclass
from typing import ClassVar, List, Literal, Optional, Tuple

from executorch.exir._serialize._cord import Cord
from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass
from executorch.exir._serialize._flatbuffer import (
    _FlatbufferResult,
    _program_flatbuffer_to_json,
    _program_json_to_flatbuffer,
)
from executorch.exir.schema import (
    BackendDelegateDataReference,
    BackendDelegateInlineData,
    Buffer,
    DataLocation,
    DataSegment,
    Program,
    SubsegmentOffsets,
)
from executorch.exir.tensor import ALIGNMENT
# Byte order of numbers written to program headers. Always little-endian
# regardless of the host system, since all commonly-used modern CPUs are little
# endian.
_HEADER_BYTEORDER: Literal["little"] = "little"


def _program_to_json(program: Program) -> str:
    """Returns the JSON representation of the given Program."""
    return json.dumps(program, cls=_DataclassEncoder)


def _json_to_program(program_json: bytes) -> Program:
    """Returns a Program deserialized from the given JSON string."""
    # construct program class recursively from dict
    return _json_to_dataclass(json.loads(program_json), cls=Program)


def _padding_required(offset: int, alignment: int) -> int:
    """Returns the padding required to align `offset` to `alignment`."""
    remainder: int = offset % alignment
    if remainder != 0:
        return alignment - remainder
    return 0


def _aligned_size(input_size: int, alignment: int) -> int:
    """Returns input_size padded up to the next whole multiple of alignment."""
    return input_size + _padding_required(input_size, alignment)


def _insert_flatbuffer_header(
    flatbuffer_data: bytes, magic_regex: str, header_data: bytes
) -> bytes:
    """Inserts a header just after the magic string of the provided flatbuffer data.

    Args:
        flatbuffer_data: The input data to modify.
        magic_regex: A regex pattern that must match the magic file_identifier
            characters of flatbuffer_data.
        header_data: The data to insert into flatbuffer_data. To ensure that
            flatbuffer internal alignment is preserved, the caller must
            guaranteed that its length is a power of 2 >= the largest
            force_align value in the schema.
    Returns:
        The modified flatbuffer_data with header_data inserted.
    Raises:
        ValueError: If flatbuffer_data is too short to be valid.
        ValueError: If the magic bytes of flatbuffer_data does not match
            magic_regex.
    """
    # The binary flatbuffer file should begin with:
    # - Offset in bytes to root table (4 bytes little endian)
    # - file_identifier string from the schema (4 bytes, string order)
    if len(flatbuffer_data) < 8:
        raise ValueError(f"Flatbuffer data length {len(flatbuffer_data)} < 8")

    # Ensure that the magic matches.
    actual_magic: str = flatbuffer_data[4:8].decode(errors="replace")
    if not re.match(magic_regex, actual_magic):
        raise ValueError(
            f"Flatbuffer data magic bytes {repr(actual_magic)} "
            + f"does not match pattern /{magic_regex}/"
        )

    # Avoid a potentially big allocation/copy if there's nothing to do.
    if len(header_data) == 0:
        return flatbuffer_data

    # We will need to adjust the root object offset after inserting the header.
    root_offset = int.from_bytes(flatbuffer_data[0:4], byteorder=_HEADER_BYTEORDER)

    return (
        # New root offset.
        (root_offset + len(header_data)).to_bytes(4, byteorder=_HEADER_BYTEORDER)
        # Existing magic bytes.
        + flatbuffer_data[4:8]
        # Provided header + padding.
        + header_data
        # Remainder of the file. Note that this can be O(10MB to 100MB), so it
        # can trigger a large allocation + copy.
        + flatbuffer_data[8:]
    )


@dataclass
class _ExtendedHeader:
    # Class constants

    # The magic bytes that should be at the beginning of the header.
    EXPECTED_MAGIC: ClassVar[bytes] = b"eh00"
    # The length of the header in bytes.
    EXPECTED_LENGTH: ClassVar[int] = (
        # Header magic
        4
        # Header length
        + 4
        # Flatbuffer data size
        + 8
        # Segment base offset
        + 8
    )

    # Instance attributes. @dataclass will turn these into ctor args.

    # The size of the serialized program data in bytes.
    program_size: int
    # Offset to the start of the first segment, or zero if there
    # are no segments.
    segment_base_offset: int

    # The magic bytes read from or to be written to the binary header.
    magic: bytes = EXPECTED_MAGIC
    # The header length, in bytes, read from or to be written to the binary
    # header.
    length: int = EXPECTED_LENGTH

    @staticmethod
    def from_bytes(data: bytes) -> "_ExtendedHeader":
        """Tries to read an extended header from the provided data.

        Does not validate that the header is well-formed. Callers should
        use is_valid().

        Args:
            data: The data to read from.
        Returns:
            The contents of the extended header.
        Raises:
            ValueError: If not enough data is provided.
        """
        if len(data) < _ExtendedHeader.EXPECTED_LENGTH:
            raise ValueError(
                f"Not enough data for extended header: {len(data)} "
                + f"< {_ExtendedHeader.EXPECTED_LENGTH}"
            )

        return _ExtendedHeader(
            magic=data[0:4],
            length=int.from_bytes(data[4:8], byteorder=_HEADER_BYTEORDER),
            program_size=int.from_bytes(data[8:16], byteorder=_HEADER_BYTEORDER),
            segment_base_offset=int.from_bytes(
                data[16:24], byteorder=_HEADER_BYTEORDER
            ),
        )

    def is_valid(self) -> bool:
        """Returns true if the extended header appears to be well-formed."""
        return (
            self.magic == _ExtendedHeader.EXPECTED_MAGIC
            and self.length >= _ExtendedHeader.EXPECTED_LENGTH
        )

    def to_bytes(self) -> bytes:
        """Returns the binary representation of the extended header.

        Note that this will ignore self.magic and self.length and will always
        write the proper magic/length.
        """
        data: bytes = (
            # Extended header magic. This lets consumers detect whether the
            # header was inserted or not. Always use the proper magic value
            # (i.e., ignore self.magic) since there's no reason to create an
            # invalid header.
            self.EXPECTED_MAGIC
            # uint32_t: Size of this header. This makes it easier to add new
            # fields to this header in the future. Always use the proper size
            # (i.e., ignore self.length) since there's no reason to create an
            # invalid header.
            + self.EXPECTED_LENGTH.to_bytes(4, byteorder=_HEADER_BYTEORDER)
            # uint64_t: Size of the flatbuffer data, including this header.
            + self.program_size.to_bytes(8, byteorder=_HEADER_BYTEORDER)
            # uint64_t: Offset to the start of the first segment, or zero if
            # there are no segments.
            + self.segment_base_offset.to_bytes(8, byteorder=_HEADER_BYTEORDER)
        )
        return data


def _pad_to(data: bytes, length: int) -> bytes:
    """Returns the input followed by enough zero bytes to become the requested length.

    Args:
        data: The data to pad.
        length: The length of the returned data.
    Returns:
        The padded data.
    Raises:
        ValueError: If the requested length is less than the input length.
    """
    if length < len(data):
        raise ValueError(f"Data length {len(data)} > padded length {length}")
    if length > len(data):
        data = data + b"\x00" * (length - len(data))
    assert len(data) == length
    return data


def _get_extended_header(program_data: bytes) -> Optional[_ExtendedHeader]:
    """Returns the extended header of the program data, if present and valid."""
    # The extended header, when present, sits just after the 8-byte
    # flatbuffer preamble (root offset + magic), hence the [8:] slice.
    try:
        eh = _ExtendedHeader.from_bytes(program_data[8:])
        if eh.is_valid():
            return eh
    except ValueError:
        pass
    return None


def _extract_delegate_segments(
    program: Program,
    segments: List[Cord],
) -> None:
    """Extracts the delegate segments inlined in the program into a list of buffers.
    The program is modified in-place to remove the delegate data.

    Args:
        program: The program to extract segments from. Modified in-place.
        segments: A list of buffers to append extracted segments to. Modified in-place.
    Raises:
        ValueError: If a delegate's data is not inline, if a delegate's index
            is out of range, or if some backend_delegate_data entries were
            never referenced by any delegate.
    """
    remaining_inline: List[BackendDelegateInlineData] = []
    inline_indices_seen: set[int] = set()
    for plan in program.execution_plan:
        for delegate in plan.delegates:
            if delegate.processed.location != DataLocation.INLINE:
                raise ValueError(
                    "Program must only contain inline delegate data, "
                    + f"saw {repr(delegate)}"
                )
            # TODO(T144120904): Don't extract small blobs into segments;
            # have a cutoff. Or callers could provide a callback that
            # returns true/false for a given BackendDelegate, letting them
            # use their own logic.
            try:
                inline: BackendDelegateInlineData = program.backend_delegate_data[
                    delegate.processed.index
                ]
            except IndexError:
                raise ValueError(
                    f"Delegate processed index {delegate.processed.index} "
                    + ">= len(Program.backend_delegate_data) "
                    + f"{len(program.backend_delegate_data)} "
                    + f"in {repr(delegate)}"
                )
            inline_indices_seen.add(delegate.processed.index)
            if inline.data:
                # Move the delegate data out of the program.
                segment_index = len(segments)
                segments.append(Cord(inline.data))
                delegate.processed = BackendDelegateDataReference(
                    location=DataLocation.SEGMENT,
                    index=segment_index,
                )
            else:
                # Not moving into a segment. Keep it inline, but update the
                # index.
                new_index = len(remaining_inline)
                remaining_inline.append(inline)
                delegate.processed.index = new_index

    # Make sure we visited all entries in backend_delegate_data, so that it's
    # safe to overwrite it.
    remaining_indices: set[int] = set(
        range(len(program.backend_delegate_data))
    ).difference(inline_indices_seen)
    if remaining_indices:
        raise ValueError(
            "Did not handle all elements of backend_delegate_data; "
            + f"remaining: {remaining_indices}"
        )

    # Preserve any entries that were not moved into segments.
    program.backend_delegate_data = remaining_inline


def _extract_constant_segment(
    constant_buffer: List[Buffer],
    tensor_alignment: Optional[int] = None,
) -> Tuple[Cord, List[int]]:
    """Copies the tensors from the provided list into a Cord and tracks the offsets
    of each tensor.

    Args:
        constant_buffer: list of Buffers from which to extract constants from. Not modified.
        tensor_alignment: Alignment in bytes. Each tensor in the cord will be padded to align
            with this value. Defaults to ALIGNMENT.

    Returns:
        A tuple of (constant segment, list of offsets for each tensor in the segment)
    """
    constant_segment_data: Cord = Cord()
    constant_segment_offsets: List[int] = []
    current_offset: int = 0
    for i in range(len(constant_buffer)):
        buffer = constant_buffer[i]
        constant_segment_data.append(buffer.storage)
        buffer_length = len(buffer.storage)
        pad_length = (
            _padding_required(buffer_length, tensor_alignment)
            if tensor_alignment is not None
            else 0
        )
        # No padding is emitted after the final buffer, but pad_length is
        # still folded into current_offset so offsets stay aligned.
        if i < len(constant_buffer) - 1:
            constant_segment_data.append(b"\x00" * pad_length)
        constant_segment_offsets.append(current_offset)
        current_offset += buffer_length + pad_length

    return constant_segment_data, constant_segment_offsets


def serialize_pte_binary(
    program: Program,
    *,
    mutable_data: Optional[List[Buffer]] = None,
    extract_delegate_segments: bool = False,
    segment_alignment: int = 128,
    constant_tensor_alignment: Optional[int] = None,
    delegate_alignment: Optional[int] = None,
) -> Cord:
    """Returns the runtime binary representation of the given Program.

    Args:
        program: The Program to serialize.
        mutable_data: If provided, buffers to serialize into a separate
            mutable-data segment. No alignment padding is applied to them
            (the data is copied at Method load time).
        extract_delegate_segments: Whether to move delegate data blobs from the
            Program into separate segments, rather than encoding those blobs
            in the flatbuffer data. When true, will also:
            - Add an extended header to the output, containing the program size
              and the starting segment offset.
            - Update the Program.segments field with the offsets and lengths
              of each segment.
        segment_alignment: Alignment in bytes. The starting offset of each
            segment will be aligned to this value in the output data.
        constant_tensor_alignment: The minimum alignment of tensor
            buffers in the program. Must be a power of 2. Defaults to ALIGNMENT.
        delegate_alignment: If provided, the minimum alignment of delegate data
            in the program. Must be a power of 2. If not provided, uses the
            value in the schema file.
    Returns:
        The serialized form of the Program, ready for execution by the runtime.
    """
    # Default tensor alignment.
    if constant_tensor_alignment is None:
        constant_tensor_alignment = ALIGNMENT

    # Don't modify the original program.
    # TODO(T144120904): Could avoid yet more huge copies with a more shallow
    # copy, reusing the actual data blobs.
    program = copy.deepcopy(program)

    # Store extracted segment data; this may be constant data or delegate data.
    segments: List[Cord] = []

    constant_segment_data, constant_segment_offsets = _extract_constant_segment(
        program.constant_buffer, tensor_alignment=constant_tensor_alignment
    )

    # If there are no constants, len(constant_segment_data) = 0. However, there may
    # be non-constants, in which case len(constant_segment_offsets) = 1, containing
    # the placeholder value 0. Ensure the placeholder value is put into
    # program.constant_segment.offsets.
    if len(constant_segment_offsets) > 0:
        # Update program.constant_segment with constant subsegment offset information.
        program.constant_segment = SubsegmentOffsets(
            segment_index=len(segments), offsets=constant_segment_offsets
        )
        # Clear the constant buffer, as constant data will be stored in segments.
        program.constant_buffer = []
        # Add to the aggregate segments cord.
        segments.append(constant_segment_data)

    if mutable_data is not None:
        mutable_segment_data, mutable_segment_offsets = _extract_constant_segment(
            mutable_data,
            tensor_alignment=None,  # data is copied at Method load so no need to align.
        )
        if len(mutable_segment_data) > 0:
            # Update program.mutable_segment_data with constant subsegment offset information.
            program.mutable_data_segments = [
                SubsegmentOffsets(
                    segment_index=len(segments), offsets=mutable_segment_offsets
                ),
            ]
            # Add to the aggregate segments cord.
            segments.append(mutable_segment_data)

    if extract_delegate_segments:
        _extract_delegate_segments(program, segments)

    # Append all segments into a single Cord, adding any necessary padding to ensure that
    # each segment begins at the required alignment.
    # Update program.segments with the offsets to each segment.
    segments_data = Cord()
    for data in segments:
        prev_end = (
            (program.segments[-1].offset + program.segments[-1].size)
            if program.segments
            else 0
        )
        program.segments.append(
            DataSegment(
                offset=_aligned_size(prev_end, segment_alignment), size=len(data)
            )
        )
        # Add to aggregate segments cord with padding.
        padding_length = _padding_required(len(segments_data), segment_alignment)
        if padding_length > 0:
            segments_data.append(b"\x00" * padding_length)
        segments_data.append(data)

    # Convert to a standard flatbuffer binary.
    result: _FlatbufferResult = _program_json_to_flatbuffer(
        _program_to_json(program),
        constant_tensor_alignment=constant_tensor_alignment,
        delegate_alignment=delegate_alignment,
    )

    # If there are no segments present, do not insert the extended header.
    if len(segments_data) == 0:
        return Cord(result.data)

    # Size of the header to insert. Its size is padded to the largest
    # force_align value present in the schema.
    padded_header_length: int = _aligned_size(
        input_size=_ExtendedHeader.EXPECTED_LENGTH,
        alignment=result.max_alignment,
    )
    # Size of the program with the header inserted.
    program_size: int = padded_header_length + len(result.data)
    # Offset to the first segment, or zero if there are no segments.
    segment_base_offset: int = (
        _aligned_size(input_size=program_size, alignment=segment_alignment)
        if len(segments_data) > 0
        else 0
    )

    # Construct and pad the extended header.
    header_data: bytes = _ExtendedHeader(
        program_size=program_size, segment_base_offset=segment_base_offset
    ).to_bytes()
    header_data = _pad_to(header_data, padded_header_length)

    # Insert the header into the flatbuffer data.
    program_data: bytes = _insert_flatbuffer_header(
        flatbuffer_data=result.data,
        magic_regex=r"ET[0-9a-zA-Z][0-9a-zA-Z]",
        header_data=header_data,
    )
    assert len(program_data) == program_size

    # Potentially large. Try to free it as soon as we can.
    del result.data

    # Double-check that the extended header is in the right place and has the
    # right contents.
    eh = _get_extended_header(program_data)
    assert eh is not None
    assert eh.program_size == program_size
    assert eh.segment_base_offset == segment_base_offset

    # Construct the final pte file containing:
    # - program data; written to offset 0.
    # - segments data (optional); aligned to segment_alignment.
    pte_data = Cord(program_data)
    if len(segments_data) > 0:
        padding_length = _padding_required(len(pte_data), segment_alignment)
        pte_data.append(b"\x00" * padding_length)
        # The first segment after program data should start at the segment base offset.
        assert (
            len(pte_data) == segment_base_offset
        ), f"Offset of first segment {len(pte_data)} != segment base offset {segment_base_offset}"
        pte_data.append(segments_data)
    return pte_data


def _restore_segments(program: Program, segment_data: bytes) -> Program:
    """Moves segments from `segment_data` into `program`.

    This should recreate the original Program that the segments were extracted
    from.

    Args:
        program: The Program to restore. `program.segments` must describe the
            segment locations.
        segment_data: The data containing the segments. Assumes that this data
            begins at `segment_base_offset` from the extended header: i.e.,
            the preceding data has been stripped off so that the first segment
            begins at offset zero.
    Returns:
        The Program with segments restored.
    Raises:
        ValueError: If a segment or delegate index is out of range.
    """
    # Extract the list of segment data blobs, which parallel program.segments.
    segments: List[bytes] = []
    for i, segment in enumerate(program.segments):
        if segment.offset + segment.size > len(segment_data):
            raise ValueError(
                f"Segment {i} {segment} overflows data length {len(segment_data)}"
            )
        segments.append(segment_data[segment.offset : segment.offset + segment.size])

    # Find and replace the Program's references to these segments, inlining the
    # data.
    for plan_index, plan in enumerate(program.execution_plan):
        for delegate_index, delegate in enumerate(plan.delegates):
            if delegate.processed.location == DataLocation.INLINE:
                continue
            assert delegate.processed.location == DataLocation.SEGMENT
            index = delegate.processed.index
            if index >= len(segments):
                raise ValueError(
                    f"Plan {plan_index} delegate {delegate_index} "
                    + f"segment index {index} >= num segments {len(segments)}"
                )

            data_index: int = len(program.backend_delegate_data)
            program.backend_delegate_data.append(
                BackendDelegateInlineData(data=segments[index])
            )
            delegate.processed = BackendDelegateDataReference(
                location=DataLocation.INLINE, index=data_index
            )

    # Replace constants from constant_segment into constant_buffer.
    if program.constant_segment and len(program.constant_segment.offsets) > 0:
        buffers: List[Buffer] = []
        constant_segment = segments[program.constant_segment.segment_index]
        for i in range(len(program.constant_segment.offsets)):
            start_offset = program.constant_segment.offsets[i]
            # Note: this is the original end offset plus any padding between
            # it and the next start offset.
            end_offset = (
                program.constant_segment.offsets[i + 1]
                if i < len(program.constant_segment.offsets) - 1
                else len(constant_segment)
            )
            buffers.append(Buffer(storage=constant_segment[start_offset:end_offset]))
        program.constant_buffer = buffers
        program.constant_segment.segment_index = 0
        program.constant_segment.offsets = []

    # Clear out the segments list since the original Program didn't have one.
    program.segments = []
    return program


def deserialize_pte_binary(program_data: bytes) -> Program:
    """Returns a Program deserialized from the given runtime binary data."""
    program_size = len(program_data)
    segment_base_offset = 0

    # Look for an extended header to see if segments follow the flatbuffer
    # data.
    eh: Optional[_ExtendedHeader] = _get_extended_header(program_data)
    if eh and eh.is_valid():
        program_size = eh.program_size
        segment_base_offset = eh.segment_base_offset

    # Parse the flatbuffer data.
    program: Program = _json_to_program(
        _program_flatbuffer_to_json(program_data[:program_size])
    )

    if segment_base_offset != 0:
        # Move segment data back into the Program.
        program = _restore_segments(
            program=program, segment_data=program_data[segment_base_offset:]
        )

    return program