# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import copy
import json
import re

from dataclasses import dataclass
from typing import ClassVar, List, Literal, Optional, Tuple

from executorch.exir._serialize._cord import Cord
from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass
from executorch.exir._serialize._flatbuffer import (
    _FlatbufferResult,
    _program_flatbuffer_to_json,
    _program_json_to_flatbuffer,
)

from executorch.exir.schema import (
    BackendDelegateDataReference,
    BackendDelegateInlineData,
    Buffer,
    DataLocation,
    DataSegment,
    Program,
    SubsegmentOffsets,
)
from executorch.exir.tensor import ALIGNMENT


# Byte order of numbers written to program headers. Always little-endian
# regardless of the host system, since all commonly-used modern CPUs are little
# endian.
_HEADER_BYTEORDER: Literal["little"] = "little"


def _program_to_json(program: Program) -> str:
    """Returns the JSON representation of the given Program."""
    return json.dumps(program, cls=_DataclassEncoder)


def _json_to_program(program_json: bytes) -> Program:
    """Returns a Program deserialized from the given JSON string."""
    # construct program class recursively from dict
    return _json_to_dataclass(json.loads(program_json), cls=Program)


def _padding_required(offset: int, alignment: int) -> int:
    """Returns the padding required to align `offset` to `alignment`."""
    remainder: int = offset % alignment
    if remainder != 0:
        return alignment - remainder
    return 0


def _aligned_size(input_size: int, alignment: int) -> int:
    """Returns input_size padded up to the next whole multiple of alignment."""
    return input_size + _padding_required(input_size, alignment)


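# Worked example of the padding arithmetic above, using illustrative values:
#
#   >>> _padding_required(18, 16)
#   14
#   >>> _aligned_size(18, 16)
#   32
#   >>> _padding_required(32, 16)  # already aligned, so no padding is needed
#   0
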
def _insert_flatbuffer_header(
    flatbuffer_data: bytes, magic_regex: str, header_data: bytes
) -> bytes:
    """Inserts a header just after the magic string of the provided flatbuffer data.

    Args:
        flatbuffer_data: The input data to modify.
        magic_regex: A regex pattern that must match the magic file_identifier
            characters of flatbuffer_data.
        header_data: The data to insert into flatbuffer_data. To ensure that
            flatbuffer internal alignment is preserved, the caller must
            guarantee that its length is a power of 2 >= the largest
            force_align value in the schema.
    Returns:
        The modified flatbuffer_data with header_data inserted.
    Raises:
        ValueError: If flatbuffer_data is too short to be valid.
        ValueError: If the magic bytes of flatbuffer_data do not match
            magic_regex.
    """
    # The binary flatbuffer file should begin with:
    # - Offset in bytes to root table (4 bytes little endian)
    # - file_identifier string from the schema (4 bytes, string order)
    if len(flatbuffer_data) < 8:
        raise ValueError(f"Flatbuffer data length {len(flatbuffer_data)} < 8")

    # Ensure that the magic matches.
    actual_magic: str = flatbuffer_data[4:8].decode(errors="replace")
    if not re.match(magic_regex, actual_magic):
        raise ValueError(
            f"Flatbuffer data magic bytes {repr(actual_magic)} "
            + f"do not match pattern /{magic_regex}/"
        )

    # Avoid a potentially big allocation/copy if there's nothing to do.
    if len(header_data) == 0:
        return flatbuffer_data

    # We will need to adjust the root object offset after inserting the header.
    root_offset = int.from_bytes(flatbuffer_data[0:4], byteorder=_HEADER_BYTEORDER)

    return (
        # New root offset.
        (root_offset + len(header_data)).to_bytes(4, byteorder=_HEADER_BYTEORDER)
        # Existing magic bytes.
        + flatbuffer_data[4:8]
        # Provided header + padding.
        + header_data
        # Remainder of the file. Note that this can be O(10MB to 100MB), so it
        # can trigger a large allocation + copy.
        + flatbuffer_data[8:]
    )


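# Illustrative sketch of the transformation performed by
# _insert_flatbuffer_header(), assuming a flatbuffer whose root table offset is
# 0x20 and whose file_identifier is "ET12" (placeholder values, not taken from
# a real file), with an 8-byte header:
#
#   before: | 20 00 00 00 | 'E' 'T' '1' '2' | ...rest of flatbuffer... |
#   after:  | 28 00 00 00 | 'E' 'T' '1' '2' | 8-byte header | ...rest... |
#
# The root offset grows by len(header_data) because the root table moves that
# many bytes further into the file; the magic stays at bytes [4:8] so consumers
# can still identify the file type.
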
@dataclass
class _ExtendedHeader:
    # Class constants

    # The magic bytes that should be at the beginning of the header.
    EXPECTED_MAGIC: ClassVar[bytes] = b"eh00"
    # The length of the header in bytes.
    EXPECTED_LENGTH: ClassVar[int] = (
        # Header magic
        4
        # Header length
        + 4
        # Flatbuffer data size
        + 8
        # Segment base offset
        + 8
    )

    # Instance attributes. @dataclass will turn these into ctor args.

    # The size of the serialized program data in bytes.
    program_size: int
    # Offset to the start of the first segment, or zero if there
    # are no segments.
    segment_base_offset: int

    # The magic bytes read from or to be written to the binary header.
    magic: bytes = EXPECTED_MAGIC
    # The header length, in bytes, read from or to be written to the binary
    # header.
    length: int = EXPECTED_LENGTH

    @staticmethod
    def from_bytes(data: bytes) -> "_ExtendedHeader":
        """Tries to read an extended header from the provided data.

        Does not validate that the header is well-formed. Callers should
        use is_valid().

        Args:
            data: The data to read from.
        Returns:
            The contents of the extended header.
        Raises:
            ValueError: If not enough data is provided.
        """
        if len(data) < _ExtendedHeader.EXPECTED_LENGTH:
            raise ValueError(
                f"Not enough data for extended header: {len(data)} "
                + f"< {_ExtendedHeader.EXPECTED_LENGTH}"
            )

        return _ExtendedHeader(
            magic=data[0:4],
            length=int.from_bytes(data[4:8], byteorder=_HEADER_BYTEORDER),
            program_size=int.from_bytes(data[8:16], byteorder=_HEADER_BYTEORDER),
            segment_base_offset=int.from_bytes(
                data[16:24], byteorder=_HEADER_BYTEORDER
            ),
        )

    def is_valid(self) -> bool:
        """Returns true if the extended header appears to be well-formed."""
        return (
            self.magic == _ExtendedHeader.EXPECTED_MAGIC
            and self.length >= _ExtendedHeader.EXPECTED_LENGTH
        )

    def to_bytes(self) -> bytes:
        """Returns the binary representation of the extended header.

        Note that this will ignore self.magic and self.length and will always
        write the proper magic/length.
        """
        data: bytes = (
            # Extended header magic. This lets consumers detect whether the
            # header was inserted or not. Always use the proper magic value
            # (i.e., ignore self.magic) since there's no reason to create an
            # invalid header.
            self.EXPECTED_MAGIC
            # uint32_t: Size of this header. This makes it easier to add new
            # fields to this header in the future. Always use the proper size
            # (i.e., ignore self.length) since there's no reason to create an
            # invalid header.
            + self.EXPECTED_LENGTH.to_bytes(4, byteorder=_HEADER_BYTEORDER)
            # uint64_t: Size of the flatbuffer data, including this header.
            + self.program_size.to_bytes(8, byteorder=_HEADER_BYTEORDER)
            # uint64_t: Offset to the start of the first segment, or zero if
            # there are no segments.
            + self.segment_base_offset.to_bytes(8, byteorder=_HEADER_BYTEORDER)
        )
        return data


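# Byte layout of the header emitted by _ExtendedHeader.to_bytes(), with all
# integers little-endian, plus a small round-trip sketch using made-up sizes:
#
#   bytes [0:4]    magic                b"eh00"
#   bytes [4:8]    header length        uint32 (24)
#   bytes [8:16]   program size         uint64
#   bytes [16:24]  segment base offset  uint64
#
#   >>> raw = _ExtendedHeader(program_size=4096, segment_base_offset=8192).to_bytes()
#   >>> len(raw) == _ExtendedHeader.EXPECTED_LENGTH
#   True
#   >>> _ExtendedHeader.from_bytes(raw).segment_base_offset
#   8192
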
def _pad_to(data: bytes, length: int) -> bytes:
    """Returns the input followed by enough zero bytes to become the requested length.

    Args:
        data: The data to pad.
        length: The length of the returned data.
    Returns:
        The padded data.
    Raises:
        ValueError: If the requested length is less than the input length.
    """
    if length < len(data):
        raise ValueError(f"Data length {len(data)} > padded length {length}")
    if length > len(data):
        data = data + b"\x00" * (length - len(data))
    assert len(data) == length
    return data


def _get_extended_header(program_data: bytes) -> Optional[_ExtendedHeader]:
    """Returns the extended header of the program data, if present and valid."""
    try:
        eh = _ExtendedHeader.from_bytes(program_data[8:])
        if eh.is_valid():
            return eh
    except ValueError:
        pass
    return None


def _extract_delegate_segments(
    program: Program,
    segments: List[Cord],
) -> None:
    """Extracts the delegate segments inlined in the program into a list of buffers.

    The program is modified in-place to remove the delegate data.

    Args:
        program: The program to extract segments from. Modified in-place.
        segments: A list of buffers to append extracted segments to. Modified in-place.
    """
    remaining_inline: List[BackendDelegateInlineData] = []
    inline_indices_seen: set[int] = set()
    for plan in program.execution_plan:
        for delegate in plan.delegates:
            if delegate.processed.location != DataLocation.INLINE:
                raise ValueError(
                    "Program must only contain inline delegate data, "
                    + f"saw {repr(delegate)}"
                )
            # TODO(T144120904): Don't extract small blobs into segments;
            # have a cutoff. Or callers could provide a callback that
            # returns true/false for a given BackendDelegate, letting them
            # use their own logic.
            try:
                inline: BackendDelegateInlineData = program.backend_delegate_data[
                    delegate.processed.index
                ]
            except IndexError:
                raise ValueError(
                    f"Delegate processed index {delegate.processed.index} "
                    + ">= len(Program.backend_delegate_data) "
                    + f"{len(program.backend_delegate_data)} "
                    + f"in {repr(delegate)}"
                )
            inline_indices_seen.add(delegate.processed.index)
            if inline.data:
                # Move the delegate data out of the program.
                segment_index = len(segments)
                segments.append(Cord(inline.data))
                delegate.processed = BackendDelegateDataReference(
                    location=DataLocation.SEGMENT,
                    index=segment_index,
                )
            else:
                # Not moving into a segment. Keep it inline, but update the
                # index.
                new_index = len(remaining_inline)
                remaining_inline.append(inline)
                delegate.processed.index = new_index

    # Make sure we visited all entries in backend_delegate_data, so that it's
    # safe to overwrite it.
    remaining_indices: set[int] = set(
        range(len(program.backend_delegate_data))
    ).difference(inline_indices_seen)
    if remaining_indices:
        raise ValueError(
            "Did not handle all elements of backend_delegate_data; "
            + f"remaining: {remaining_indices}"
        )

    # Preserve any entries that were not moved into segments.
    program.backend_delegate_data = remaining_inline


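# Sketch of the in-place rewrite performed by _extract_delegate_segments(),
# assuming a toy program with two inline delegate entries where only the first
# has data (values are illustrative):
#
#   before: backend_delegate_data = [blob0, <empty>]
#           delegates[0].processed -> (INLINE, index=0)
#           delegates[1].processed -> (INLINE, index=1)
#
#   after:  segments gains one Cord containing blob0
#           backend_delegate_data = [<empty>]            # only the empty entry stays inline
#           delegates[0].processed -> (SEGMENT, index of the new segment)
#           delegates[1].processed -> (INLINE, index=0)  # re-indexed into the shrunken list
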
def _extract_constant_segment(
    constant_buffer: List[Buffer],
    tensor_alignment: Optional[int] = None,
) -> Tuple[Cord, List[int]]:
    """Copies the tensors from the provided list into a Cord and tracks the offsets
    of each tensor.

    Args:
        constant_buffer: List of Buffers from which to extract constants. Not modified.
        tensor_alignment: Alignment in bytes. Each tensor in the cord will be padded to
            align with this value. If None, no padding is added.

    Returns:
        A tuple of (constant segment, list of offsets for each tensor in the segment)
    """
    constant_segment_data: Cord = Cord()
    constant_segment_offsets: List[int] = []
    current_offset: int = 0
    for i in range(len(constant_buffer)):
        buffer = constant_buffer[i]
        constant_segment_data.append(buffer.storage)
        buffer_length = len(buffer.storage)
        pad_length = (
            _padding_required(buffer_length, tensor_alignment)
            if tensor_alignment is not None
            else 0
        )
        if i < len(constant_buffer) - 1:
            constant_segment_data.append(b"\x00" * pad_length)
        constant_segment_offsets.append(current_offset)
        current_offset += buffer_length + pad_length

    return constant_segment_data, constant_segment_offsets


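# Worked example for _extract_constant_segment(), using two made-up buffers of
# lengths 10 and 4 with tensor_alignment=16:
#
#   >>> data, offsets = _extract_constant_segment(
#   ...     [Buffer(storage=b"\x01" * 10), Buffer(storage=b"\x02" * 4)],
#   ...     tensor_alignment=16,
#   ... )
#   >>> offsets
#   [0, 16]
#   >>> len(data)  # 10 bytes + 6 bytes of padding + 4 bytes; the last tensor is not padded
#   20
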
def serialize_pte_binary(
    program: Program,
    *,
    mutable_data: Optional[List[Buffer]] = None,
    extract_delegate_segments: bool = False,
    segment_alignment: int = 128,
    constant_tensor_alignment: Optional[int] = None,
    delegate_alignment: Optional[int] = None,
) -> Cord:
    """Returns the runtime binary representation of the given Program.

    Args:
        program: The Program to serialize.
        mutable_data: If provided, buffers of mutable tensor data to serialize
            into their own segment.
        extract_delegate_segments: Whether to move delegate data blobs from the
            Program into separate segments, rather than encoding those blobs
            in the flatbuffer data. When true, will also:
            - Add an extended header to the output, containing the program size
              and the starting segment offset.
            - Update the Program.segments field with the offsets and lengths
              of each segment.
        segment_alignment: Alignment in bytes. The starting offset of each
            segment will be aligned to this value in the output data.
        constant_tensor_alignment: The minimum alignment of tensor
            buffers in the program. Must be a power of 2. Defaults to ALIGNMENT.
        delegate_alignment: If provided, the minimum alignment of delegate data
            in the program. Must be a power of 2. If not provided, uses the
            value in the schema file.
    Returns:
        The serialized form of the Program, ready for execution by the runtime.
    """
    # Default tensor alignment.
    if constant_tensor_alignment is None:
        constant_tensor_alignment = ALIGNMENT

    # Don't modify the original program.
    # TODO(T144120904): Could avoid yet more huge copies with a more shallow
    # copy, reusing the actual data blobs.
    program = copy.deepcopy(program)

    # Store extracted segment data; this may be constant data or delegate data.
    segments: List[Cord] = []

    constant_segment_data, constant_segment_offsets = _extract_constant_segment(
        program.constant_buffer, tensor_alignment=constant_tensor_alignment
    )

    # If there are no constants, len(constant_segment_data) = 0. However, there may
    # be non-constants, in which case len(constant_segment_offsets) = 1, containing
    # the placeholder value 0. Ensure the placeholder value is put into
    # program.constant_segment.offsets.
    if len(constant_segment_offsets) > 0:
        # Update program.constant_segment with constant subsegment offset information.
        program.constant_segment = SubsegmentOffsets(
            segment_index=len(segments), offsets=constant_segment_offsets
        )
        # Clear the constant buffer, as constant data will be stored in segments.
        program.constant_buffer = []
        # Add to the aggregate segments cord.
        segments.append(constant_segment_data)

    if mutable_data is not None:
        mutable_segment_data, mutable_segment_offsets = _extract_constant_segment(
            mutable_data,
            tensor_alignment=None,  # Data is copied at Method load, so no need to align.
        )
        if len(mutable_segment_data) > 0:
            # Update program.mutable_data_segments with mutable subsegment offset
            # information.
            program.mutable_data_segments = [
                SubsegmentOffsets(
                    segment_index=len(segments), offsets=mutable_segment_offsets
                ),
            ]
            # Add to the aggregate segments cord.
            segments.append(mutable_segment_data)

    if extract_delegate_segments:
        _extract_delegate_segments(program, segments)

    # Append all segments into a single Cord, adding any necessary padding to ensure that
    # each segment begins at the required alignment.
    # Update program.segments with the offsets to each segment.
    segments_data = Cord()
    for data in segments:
        prev_end = (
            (program.segments[-1].offset + program.segments[-1].size)
            if program.segments
            else 0
        )
        program.segments.append(
            DataSegment(
                offset=_aligned_size(prev_end, segment_alignment), size=len(data)
            )
        )
        # Add to aggregate segments cord with padding.
        padding_length = _padding_required(len(segments_data), segment_alignment)
        if padding_length > 0:
            segments_data.append(b"\x00" * padding_length)
        segments_data.append(data)

    # Convert to a standard flatbuffer binary.
    result: _FlatbufferResult = _program_json_to_flatbuffer(
        _program_to_json(program),
        constant_tensor_alignment=constant_tensor_alignment,
        delegate_alignment=delegate_alignment,
    )

    # If there are no segments present, do not insert the extended header.
    if len(segments_data) == 0:
        return Cord(result.data)

    # Size of the header to insert. Its size is padded to the largest
    # force_align value present in the schema.
    padded_header_length: int = _aligned_size(
        input_size=_ExtendedHeader.EXPECTED_LENGTH,
        alignment=result.max_alignment,
    )
    # Size of the program with the header inserted.
    program_size: int = padded_header_length + len(result.data)
    # Offset to the first segment, or zero if there are no segments.
    segment_base_offset: int = (
        _aligned_size(input_size=program_size, alignment=segment_alignment)
        if len(segments_data) > 0
        else 0
    )

    # Construct and pad the extended header.
    header_data: bytes = _ExtendedHeader(
        program_size=program_size, segment_base_offset=segment_base_offset
    ).to_bytes()
    header_data = _pad_to(header_data, padded_header_length)

    # Insert the header into the flatbuffer data.
    program_data: bytes = _insert_flatbuffer_header(
        flatbuffer_data=result.data,
        magic_regex=r"ET[0-9a-zA-Z][0-9a-zA-Z]",
        header_data=header_data,
    )
    assert len(program_data) == program_size

    # Potentially large. Try to free it as soon as we can.
    del result.data

    # Double-check that the extended header is in the right place and has the
    # right contents.
    eh = _get_extended_header(program_data)
    assert eh is not None
    assert eh.program_size == program_size
    assert eh.segment_base_offset == segment_base_offset

    # Construct the final pte file containing:
    # - program data; written to offset 0.
    # - segments data (optional); aligned to segment_alignment.
    pte_data = Cord(program_data)
    if len(segments_data) > 0:
        padding_length = _padding_required(len(pte_data), segment_alignment)
        pte_data.append(b"\x00" * padding_length)
        # The first segment after program data should start at the segment base offset.
        assert (
            len(pte_data) == segment_base_offset
        ), f"Offset of first segment {len(pte_data)} != segment base offset {segment_base_offset}"
        pte_data.append(segments_data)
    return pte_data


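# Rough layout of the .pte data produced above when at least one segment is
# present (widths not to scale; "pad" regions are zero bytes added to reach the
# stated alignment):
#
#   offset 0              flatbuffer root offset + file_identifier (8 bytes)
#   offset 8              extended header, padded to result.max_alignment
#   ...                   remainder of the flatbuffer-encoded Program
#   up to base offset     pad to segment_alignment
#   segment_base_offset   segment 0, pad, segment 1, pad, segment 2, ...
#
# With no segments, the output is just the plain flatbuffer and no extended
# header is inserted.
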
def _restore_segments(program: Program, segment_data: bytes) -> Program:
    """Moves segments from `segment_data` into `program`.

    This should recreate the original Program that the segments were extracted
    from.

    Args:
        program: The Program to restore. `program.segments` must describe the
            segment locations.
        segment_data: The data containing the segments. Assumes that this data
            begins at `segment_base_offset` from the extended header: i.e.,
            the preceding data has been stripped off so that the first segment
            begins at offset zero.
    Returns:
        The Program with segments restored.
    """
    # Extract the list of segment data blobs, which parallel program.segments.
    segments: List[bytes] = []
    for i, segment in enumerate(program.segments):
        if segment.offset + segment.size > len(segment_data):
            raise ValueError(
                f"Segment {i} {segment} overflows data length {len(segment_data)}"
            )
        segments.append(segment_data[segment.offset : segment.offset + segment.size])

    # Find and replace the Program's references to these segments, inlining the
    # data.
    for plan_index, plan in enumerate(program.execution_plan):
        for delegate_index, delegate in enumerate(plan.delegates):
            if delegate.processed.location == DataLocation.INLINE:
                continue
            assert delegate.processed.location == DataLocation.SEGMENT
            index = delegate.processed.index
            if index >= len(segments):
                raise ValueError(
                    f"Plan {plan_index} delegate {delegate_index} "
                    + f"segment index {index} >= num segments {len(segments)}"
                )

            data_index: int = len(program.backend_delegate_data)
            program.backend_delegate_data.append(
                BackendDelegateInlineData(data=segments[index])
            )
            delegate.processed = BackendDelegateDataReference(
                location=DataLocation.INLINE, index=data_index
            )

    # Replace constants from constant_segment into constant_buffer.
    if program.constant_segment and len(program.constant_segment.offsets) > 0:
        buffers: List[Buffer] = []
        constant_segment = segments[program.constant_segment.segment_index]
        for i in range(len(program.constant_segment.offsets)):
            start_offset = program.constant_segment.offsets[i]
            # Note: this is the original end offset plus any padding between
            # it and the next start offset.
            end_offset = (
                program.constant_segment.offsets[i + 1]
                if i < len(program.constant_segment.offsets) - 1
                else len(constant_segment)
            )
            buffers.append(Buffer(storage=constant_segment[start_offset:end_offset]))
        program.constant_buffer = buffers
        program.constant_segment.segment_index = 0
        program.constant_segment.offsets = []

    # Clear out the segments list since the original Program didn't have one.
    program.segments = []
    return program


def deserialize_pte_binary(program_data: bytes) -> Program:
    """Returns a Program deserialized from the given runtime binary data."""
    program_size = len(program_data)
    segment_base_offset = 0

    # Look for an extended header to see if segments follow the flatbuffer
    # data.
    eh: Optional[_ExtendedHeader] = _get_extended_header(program_data)
    if eh and eh.is_valid():
        program_size = eh.program_size
        segment_base_offset = eh.segment_base_offset

    # Parse the flatbuffer data.
    program: Program = _json_to_program(
        _program_flatbuffer_to_json(program_data[:program_size])
    )

    if segment_base_offset != 0:
        # Move segment data back into the Program.
        program = _restore_segments(
            program=program, segment_data=program_data[segment_base_offset:]
        )

    return program
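

# Minimal round-trip sketch, assuming `program` is an existing
# executorch.exir.schema.Program instance and that the resulting Cord can be
# flattened with bytes() (the names and the bytes() conversion are assumptions
# made for this illustration):
#
#   >>> pte: Cord = serialize_pte_binary(program, extract_delegate_segments=True)
#   >>> restored: Program = deserialize_pte_binary(bytes(pte))
#
# Any delegate or constant data that was moved into segments during
# serialization is inlined back into `restored` by _restore_segments().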
603