xref: /aosp_15_r20/external/executorch/exir/_serialize/test/test_program.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1#!/usr/bin/env fbpython
2# Copyright (c) Meta Platforms, Inc. and affiliates.
3# All rights reserved.
4#
5# This source code is licensed under the BSD-style license found in the
6# LICENSE file in the root directory of this source tree.
7
8# pyre-unsafe
9
10import copy
11import difflib
12import json
13import unittest
14
15from typing import List, Sequence
16
17from executorch.exir._serialize._flatbuffer import _program_flatbuffer_to_json
18from executorch.exir._serialize._program import (
19    _ExtendedHeader,
20    _get_extended_header,
21    _json_to_program,
22    _program_to_json,
23    deserialize_pte_binary,
24    serialize_pte_binary,
25)
26
27from executorch.exir.schema import (
28    BackendDelegate,
29    BackendDelegateDataReference,
30    BackendDelegateInlineData,
31    Buffer,
32    ContainerMetadata,
33    DataLocation,
34    DataSegment,
35    ExecutionPlan,
36    Program,
37    SubsegmentOffsets,
38)
39from executorch.exir.tests.common import get_test_program
40
41SEGMENT_ALIGNMENT: int = 128
42
43CONSTANT_TENSOR_ALIGNMENT: int = 16
44
45
def add_constant_data(program: Program, blobs: Sequence[bytes]) -> None:
    """Appends each blob in `blobs` to the program's constant buffer."""
    program.constant_buffer.extend(Buffer(storage=blob) for blob in blobs)
50
51
def add_delegate_data(
    program: Program, plan: ExecutionPlan, blobs: Sequence[bytes]
) -> None:
    """Adds the provided delegate data blobs to the execution plan."""
    # Delegate ids continue from however many delegates the plan already has.
    first_delegate_index = len(plan.delegates)
    for offset, blob in enumerate(blobs):
        # Store the blob inline in the program and remember its slot.
        data_index = len(program.backend_delegate_data)
        program.backend_delegate_data.append(BackendDelegateInlineData(data=blob))
        # Point a new delegate entry at the inline data we just added.
        plan.delegates.append(
            BackendDelegate(
                id=f"delegate{first_delegate_index + offset}",
                processed=BackendDelegateDataReference(
                    location=DataLocation.INLINE,
                    index=data_index,
                ),
                compile_specs=[],
            )
        )
74
75
def canonicalize_delegate_indices(program: Program) -> Program:
    """Returns a copy of the program with the backend delegate data list in
    a predictable order.

    The copy's backend_delegate_data is sorted by blob contents, and every
    BackendDelegate.processed.index is remapped so each delegate still points
    at the same data as before.

    NOTE: The sort is unstable if multiple entries have identical data
    contents; such entries may land in either of the duplicate slots.
    """
    # Deep-copy so the caller's program is never mutated.
    program = copy.deepcopy(program)

    # Pair each entry's original index with its data so we can build the
    # old-index -> new-index mapping after sorting.
    delegate_entries: list[tuple[int, bytes]] = [
        (i, entry.data) for i, entry in enumerate(program.backend_delegate_data)
    ]

    # Sort by the contents of the data, which is the second entry in the tuple.
    delegate_entries.sort(key=lambda x: x[1])

    # Build up the sorted Program.backend_delegate_data list, and a mapping from
    # the old index to the new index.
    old_to_new_index: dict[int, int] = {}
    program.backend_delegate_data = []
    for new_index, (old_index, data) in enumerate(delegate_entries):
        old_to_new_index[old_index] = new_index
        program.backend_delegate_data.append(BackendDelegateInlineData(data=data))

    # Patch up the index pointers from the BackendDelegate entries.
    for plan in program.execution_plan:
        for delegate in plan.delegates:
            delegate.processed.index = old_to_new_index[delegate.processed.index]

    return program
106
107
108class TestProgram(unittest.TestCase):
109    def assert_file_magic_present(self, program_data: bytes) -> None:
110        self.assertEqual(program_data[4:6], b"ET")
111        # Ignore the other bytes, which can change over time and are not
112        # important for this test.
113
114    def assert_programs_equal(self, program1: Program, program2: Program) -> None:
115        def prepare_json_string(j: str) -> List[str]:
116            """Formats the JSON and splits it into lines."""
117            return json.dumps(json.loads(j), indent=2, sort_keys=True).splitlines(
118                keepends=True
119            )
120
121        # This JSON comparison is fragile: some parts of the program do not care
122        # about order (like the operators list), so those are technically free
123        # to be reordered. If they become a problem, we can canonicalize them
124        # like we do for the backend delegate data list.
125        json1 = _program_to_json(canonicalize_delegate_indices(program1))
126        json2 = _program_to_json(canonicalize_delegate_indices(program2))
127
128        # Use unified_diff so it only prints the differences instead of the
129        # entire string.
130        diff: str = "".join(
131            difflib.unified_diff(
132                prepare_json_string(json1),
133                prepare_json_string(json2),
134            )
135        )
136        if diff:
137            self.fail(msg="Programs are not equal\n" + diff)
138
139    def get_and_validate_extended_header(self, pte_data: bytes) -> _ExtendedHeader:
140        """When an extended header is expected, check that it exists and is valid.
141        Does not check correctness of the contents."""
142        eh = _get_extended_header(pte_data)
143        self.assertIsNotNone(eh)
144        self.assertTrue(eh.is_valid())
145        self.assertLess(eh.program_size, len(pte_data))
146        return eh
147
    def constant_segment_with_tensor_alignment(
        self, constant_tensor_alignment: int
    ) -> None:
        """Utility to test constant segment with varying alignment.

        Serializes a program with several constant tensors whose sizes fall
        just below, at, and just above the alignment, then verifies the
        segment table, the per-tensor offsets, the padding bytes, and the
        round trip back to a Program.

        Args:
            constant_tensor_alignment: Alignment of constant tensor data.
                Must be a multiple of 2.
                Must be > 8 for the purposes of the test, which checks +- 3 bytes on the edges of each tensor.
        """
        # Create a program with some constant tensor data.
        program = get_test_program()
        # gen_blob_data marks the first and last byte of each blob so the
        # boundary checks below can tell tensors and padding apart.
        blobs = (
            b"",  # Empty tensor.
            self.gen_blob_data(constant_tensor_alignment // 2, b"\x10\x11\x01"),
            self.gen_blob_data(constant_tensor_alignment - 1, b"\x20\x22\x02"),
            self.gen_blob_data(constant_tensor_alignment, b"\x30\x33\x03"),
            self.gen_blob_data(constant_tensor_alignment + 1, b"\x40\x44\x04"),
        )
        add_constant_data(program, blobs)

        # Extract blobs into constant segment during serialization.
        pte_data = bytes(
            serialize_pte_binary(
                program,
                segment_alignment=SEGMENT_ALIGNMENT,
                constant_tensor_alignment=constant_tensor_alignment,
            )
        )

        # The input Program should not be modified.
        self.assertEqual(program.segments, [])

        # Extended header should be present in the serialized data.
        eh = self.get_and_validate_extended_header(pte_data)

        # Segment offset should be non-zero since there are segments. It
        # should point past the end of the program data, but not beyond
        # the end of the file.
        self.assertGreaterEqual(eh.segment_base_offset, eh.program_size)
        self.assertLess(eh.segment_base_offset, len(pte_data))

        # Peek inside the actual flatbuffer data to see the segments.
        program_with_segments = _json_to_program(_program_flatbuffer_to_json(pte_data))

        # The constant tensor data should appear as the only segment.
        self.assertEqual(len(program_with_segments.segments), 1)

        # The constant buffer should appear now as a constant segment.
        segment_table: List[DataSegment] = program_with_segments.segments
        self.assertEqual(len(segment_table), 1)
        # Tensor sizes
        # - tensor[0]: 0
        # - tensors[1,2,3]: constant_tensor_alignment
        # - tensor[4]: constant_tensor_alignment + 1 (no padding on the last tensor)
        self.assertEqual(
            segment_table[0].size,
            constant_tensor_alignment * 3 + (constant_tensor_alignment + 1),
        )

        # Check constant_segment index and offsets.
        subsegment_offsets: SubsegmentOffsets = program_with_segments.constant_segment
        self.assertEqual(subsegment_offsets.segment_index, 0)
        self.assertEqual(
            subsegment_offsets.offsets,
            [
                0,  # Start at offset 0.
                0,  # tensor[0] is empty.
                constant_tensor_alignment,  # tensor[1] has size constant_tensor_alignment // 2. Round up.
                constant_tensor_alignment
                * 2,  # tensor[2] has size constant_tensor_alignment - 1. Round up.
                constant_tensor_alignment
                * 3,  # tensor[3] has size constant_tensor_alignment. No padding needed.
            ],
        )

        # Check constant_buffer is empty, because the data was moved into the segment.
        self.assertEqual(len(program_with_segments.constant_buffer), 0)

        # Check segment data. Offsets below are relative to the segment base,
        # so slice off everything before it.
        offsets = subsegment_offsets.offsets
        segment_data: bytes = pte_data[eh.segment_base_offset :]

        # tensor[1]: padding.
        self.assertEqual(
            segment_data[offsets[1] : offsets[1] + 3],
            # Tensor data.
            b"\x10\x11\x11",
        )
        # The bytes just past the end of tensor[1] must be zero padding.
        self.assertEqual(
            segment_data[
                offsets[1]
                + constant_tensor_alignment // 2 : offsets[1]
                + constant_tensor_alignment // 2
                + 3
            ],
            # Padding.
            b"\x00\x00\x00",
        )

        # tensor[3]: no padding.
        self.assertEqual(
            segment_data[offsets[4] - 3 : offsets[4] + 3],
            # End of tensor 3.
            b"\x33\x33\x03"
            # Start of tensor 4.
            + b"\x40\x44\x44",
        )

        # tensor[4]: no padding for last tensor.
        self.assertEqual(
            segment_data[
                offsets[4]
                + constant_tensor_alignment
                - 3 : offsets[4]
                + constant_tensor_alignment
                + 1
            ],
            b"\x44\x44\x44\x04",
        )

        # The final segment should not point past the end of the file.
        self.assertLessEqual(
            segment_table[-1].offset + segment_table[-1].size,
            len(pte_data),
            f"{segment_table}",
        )

        # Convert back.
        program2 = deserialize_pte_binary(pte_data)
        # Programs are the same besides constant_buffer, as deserialization
        # does not preserve constant segment; padding may be added
        # during serialization.
        self.assertEqual(program2.execution_plan, program.execution_plan)
        # Number of constant tensors should be the same.
        self.assertEqual(len(program2.constant_buffer), len(program.constant_buffer))
283
    def test_canonicalize_delegate_indices(self) -> None:
        """Checks that canonicalize_delegate_indices() sorts the delegate data
        list by contents and remaps every delegate's data index so each
        delegate still points at the same bytes."""

        def make_execution_plan(
            name: str, delegates: List[BackendDelegate]
        ) -> ExecutionPlan:
            # Minimal plan: everything except the name and delegates is empty.
            return ExecutionPlan(
                name=name,
                container_meta_type=ContainerMetadata(
                    encoded_inp_str="encoded_inp_str",
                    encoded_out_str="encoded_out_str",
                ),
                values=[],
                inputs=[],
                outputs=[],
                chains=[],
                operators=[],
                delegates=delegates,
                non_const_buffer_sizes=[],
            )

        # A program with three delegates across two execution plans. To start
        # with, the data indices in the delegates are in a non-canonical order.
        program = Program(
            version=0,
            execution_plan=[
                make_execution_plan(
                    name="forward0",
                    delegates=[
                        BackendDelegate(
                            id="delegate0",
                            processed=BackendDelegateDataReference(
                                location=DataLocation.INLINE, index=2
                            ),
                            compile_specs=[],
                        ),
                        BackendDelegate(
                            id="delegate1",
                            processed=BackendDelegateDataReference(
                                location=DataLocation.INLINE, index=1
                            ),
                            compile_specs=[],
                        ),
                    ],
                ),
                make_execution_plan(
                    name="forward1",
                    delegates=[
                        BackendDelegate(
                            id="delegate2",
                            processed=BackendDelegateDataReference(
                                location=DataLocation.INLINE, index=0
                            ),
                            compile_specs=[],
                        ),
                    ],
                ),
            ],
            constant_buffer=[],
            backend_delegate_data=[
                # Data is in non-canonical (unsorted) order. The "[p,d]" in
                # each blob names the (plan, delegate) that points at it.
                BackendDelegateInlineData(data=b"CC delegate [1,0] data"),
                BackendDelegateInlineData(data=b"BB delegate [0,1] data"),
                BackendDelegateInlineData(data=b"AA delegate [0,0] data"),
            ],
            segments=[],
            constant_segment=SubsegmentOffsets(segment_index=0, offsets=[]),
        )

        # Demonstrate which data each delegate points to.
        self.assertEqual(
            program.backend_delegate_data[
                program.execution_plan[0].delegates[0].processed.index
            ].data,
            b"AA delegate [0,0] data",
        )
        self.assertEqual(
            program.backend_delegate_data[
                program.execution_plan[0].delegates[1].processed.index
            ].data,
            b"BB delegate [0,1] data",
        )
        self.assertEqual(
            program.backend_delegate_data[
                program.execution_plan[1].delegates[0].processed.index
            ].data,
            b"CC delegate [1,0] data",
        )

        # Canonicalize the program.
        canonical_program: Program = canonicalize_delegate_indices(program)

        # The delegate data list should be sorted by contents.
        self.assertListEqual(
            canonical_program.backend_delegate_data,
            [
                # Should have been sorted.
                BackendDelegateInlineData(data=b"AA delegate [0,0] data"),
                BackendDelegateInlineData(data=b"BB delegate [0,1] data"),
                BackendDelegateInlineData(data=b"CC delegate [1,0] data"),
            ],
        )

        # Demonstrate that the delegate entries still point to the correct data.
        self.assertEqual(
            canonical_program.backend_delegate_data[
                canonical_program.execution_plan[0].delegates[0].processed.index
            ].data,
            b"AA delegate [0,0] data",
        )
        self.assertEqual(
            canonical_program.backend_delegate_data[
                canonical_program.execution_plan[0].delegates[1].processed.index
            ].data,
            b"BB delegate [0,1] data",
        )
        self.assertEqual(
            canonical_program.backend_delegate_data[
                canonical_program.execution_plan[1].delegates[0].processed.index
            ].data,
            b"CC delegate [1,0] data",
        )
404
405    def test_round_trip_no_header_no_segments(self) -> None:
406        """Tests that a Program remains the same after serializing and
407        deserializing.
408        """
409        program = get_test_program()
410        pte_data = bytes(serialize_pte_binary(program))
411        self.assertGreater(len(pte_data), 16)
412
413        # File magic should be present at the expected offset.
414        self.assert_file_magic_present(pte_data)
415
416        # Extended header should not be present.
417        eh = _get_extended_header(pte_data)
418        self.assertIsNone(eh)
419
420        # Convert back.
421        program2 = deserialize_pte_binary(pte_data)
422
423        # Programs should be the same.
424        self.assert_programs_equal(program, program2)
425
426    def test_round_trip_large_buffer_sizes(self) -> None:
427        """Tests that when the non_const_buffer_sizes contains integers
428        overflowing a signed/unsigned 32 bit integer, we can still serialize the
429        model and get the same program by deserialization.
430        """
431        program = get_test_program()
432        program.execution_plan[0].non_const_buffer_sizes = [0, 2**48]
433        flatbuffer_from_py = bytes(serialize_pte_binary(program))
434        self.assert_programs_equal(program, deserialize_pte_binary(flatbuffer_from_py))
435
436    def test_round_trip_no_segments_and_no_header(self) -> None:
437        """Tests that a Program serialized with extract_delegate_segments=True
438        when there are no segments does not contain an extended header,
439        constant segment, or delegate segments. Confirm that a Program remains
440        the same after serializing and deserializing.
441        """
442        program = get_test_program()
443        pte_data = bytes(serialize_pte_binary(program, extract_delegate_segments=True))
444        self.assertGreater(len(pte_data), 16)
445
446        # File magic should be present at the expected offset.
447        self.assert_file_magic_present(pte_data)
448
449        # Extended header should not be present when no segments are created.
450        eh = _get_extended_header(pte_data)
451        self.assertIsNone(eh)
452
453        # Peek inside the flatbuffer data to confirm that there are no segments.
454        program_with_segments = _json_to_program(_program_flatbuffer_to_json(pte_data))
455        self.assertEqual(program_with_segments.segments, [])
456
457        # Convert back.
458        program2 = deserialize_pte_binary(pte_data)
459
460        # Programs should be the same.
461        self.assert_programs_equal(program, program2)
462
463    @staticmethod
464    def gen_blob_data(size: int, pattern: bytes) -> bytes:
465        """Generates a buffer with special first and last bytes,
466        repeating the middle byte of the pattern."""
467        assert len(pattern) == 3
468        assert size >= 3
469        # Stretch out the middle byte to fill the space.
470        ret = pattern[0:1] + pattern[1:2] * (size - 2) + pattern[2:3]
471        assert len(ret) == size
472        return ret
473
    def test_round_trip_with_segments(self) -> None:
        """Serializes delegate blobs into segments, checks the segment table
        invariants and the exact bytes at segment boundaries, then confirms
        the program round-trips back with the data inlined again."""
        # Create a program with some delegate data blobs. gen_blob_data marks
        # the first/last byte of each blob so boundary bytes are identifiable.
        program = get_test_program()
        blobs = (
            self.gen_blob_data(SEGMENT_ALIGNMENT // 5, b"\x10\x11\x01"),
            # Focus on blobs whose sizes fall close to the alignment.
            self.gen_blob_data(SEGMENT_ALIGNMENT - 1, b"\x20\x22\x02"),
            self.gen_blob_data(SEGMENT_ALIGNMENT, b"\x30\x33\x03"),
            self.gen_blob_data(SEGMENT_ALIGNMENT + 1, b"\x40\x44\x04"),
            b"",  # Empty segment.
            self.gen_blob_data(SEGMENT_ALIGNMENT // 10, b"\x50\x55\x05"),
        )
        add_delegate_data(program, program.execution_plan[0], blobs)

        # Extract the blobs into segments during serialization.
        pte_data = bytes(
            serialize_pte_binary(
                program,
                extract_delegate_segments=True,
                segment_alignment=SEGMENT_ALIGNMENT,
            )
        )

        # The input Program should not have been modified.
        self.assertEqual(program.segments, [])
        self.assertEqual(
            program.execution_plan[0].delegates[0].processed.location,
            DataLocation.INLINE,
        )

        # Extended header should be present in the serialized data.
        eh = self.get_and_validate_extended_header(pte_data)
        # Segment offset should be non-zero since there are segments. It
        # should point past the end of the program data, but not beyond
        # the end of the file.
        self.assertGreaterEqual(eh.segment_base_offset, eh.program_size)
        self.assertLess(eh.segment_base_offset, len(pte_data))

        # Peek inside the actual flatbuffer data to see the segments. Note that
        # this also implicitly tests the case where we try parsing the entire
        # file with segment data following it, demonstrating that the extra data
        # doesn't upset the flatbuffer parsing path.
        program_with_segments = _json_to_program(_program_flatbuffer_to_json(pte_data))

        # The delegate blobs we added to the program should appear as segments.
        # The one empty blob should have been ignored, hence the `- 1`.
        self.assertEqual(len(program_with_segments.segments), len(blobs) - 1)
        segment_table: List[DataSegment] = program_with_segments.segments

        # Check segment range invariants.
        for i in range(len(segment_table)):
            # All offsets should be a multiple of SEGMENT_ALIGNMENT.
            self.assertTrue(
                segment_table[i].offset % SEGMENT_ALIGNMENT == 0,
                f"Segment {i} offset is not aligned: {segment_table[i]}",
            )
            # There should be no empty segments.
            self.assertGreater(
                segment_table[i].size, 0, f"Segment {i}: {segment_table}"
            )
            if i > 0:
                # Segments should not overlap, and should be sorted from
                # smallest offset to largest.
                self.assertLessEqual(
                    segment_table[i - 1].offset + segment_table[i - 1].size,
                    segment_table[i].offset,
                    f"Segment {i} overlaps or is out of order: {segment_table}",
                )
        # The first segment should begin at zero; i.e., at the segment base
        # offset.
        self.assertEqual(segment_table[0].offset, 0, f"{segment_table}")
        # The final segment should not point past the end of the file.
        self.assertLessEqual(
            segment_table[-1].offset + segment_table[-1].size,
            len(pte_data),
            f"{segment_table}",
        )

        # Check the segment base offset boundary.
        segment_base_offset = eh.segment_base_offset
        self.assertEqual(
            pte_data[segment_base_offset - 2 : segment_base_offset + 3],
            # The padding before the first segment.
            b"\x00\x00"
            # The first few bytes of the first segment.
            + b"\x10\x11\x11",
        )

        # Now that we've shown that the base offset is correct, slice off the
        # front so that all segment offsets are relative to zero.
        segment_data: bytes = pte_data[segment_base_offset:]

        # End of the first segment. It's much smaller than the alignment,
        # so we know that it's followed by zeros.
        self.assertEqual(
            segment_data[segment_table[0].size - 3 : segment_table[0].size + 2],
            # The end of the segment.
            b"\x11\x11\x01"
            # The padding that follows it.
            + b"\x00\x00",
        )

        # Look at the end of segment[2], which is exactly the same size as
        # the alignment. There should be no padding, running right into the
        # next segment.
        self.assertEqual(
            segment_data[segment_table[3].offset - 3 : segment_table[3].offset + 3],
            # The end of segment[2].
            b"\x33\x33\x03"
            # The beginning of segment[3]
            b"\x40\x44\x44",
        )

        # Convert back; the programs should be the same after a round trip,
        # meaning that the segments were moved back to inline. This also
        # demonstrates that the contents of all segments survived, and weren't
        # truncated or corrupted.
        program2 = deserialize_pte_binary(pte_data)
        self.assert_programs_equal(program, program2)
593
594    def test_no_constants(self) -> None:
595        program = get_test_program()
596        # Insert placeholder for non-const tensors.
597        add_constant_data(program, [b""])
598
599        pte_data = bytes(
600            serialize_pte_binary(
601                program,
602                extract_delegate_segments=True,
603                segment_alignment=SEGMENT_ALIGNMENT,
604                constant_tensor_alignment=CONSTANT_TENSOR_ALIGNMENT,
605            )
606        )
607        # The input Program should not be modified.
608        self.assertEqual(program.segments, [])
609
610        # Peek inside the actual flatbuffer data to see the segments.
611        flatbuffer_program = _json_to_program(_program_flatbuffer_to_json(pte_data))
612
613        # Constant buffer should be empty.
614        self.assertEqual(len(flatbuffer_program.constant_buffer), 0)
615
616        # Constant segment should contain the placeholder.
617        self.assertEqual(flatbuffer_program.constant_segment.segment_index, 0)
618        self.assertEqual(len(flatbuffer_program.constant_segment.offsets), 1)
619        self.assertEqual(flatbuffer_program.constant_segment.offsets[0], 0)
620
621    def test_unused_inline_delegate_blobs_with_segments(self) -> None:
622        # Create a program with some delegate data blobs.
623        program = get_test_program()
624        blobs = (
625            self.gen_blob_data(16, b"\x10\x11\x01"),
626            self.gen_blob_data(32, b"\x20\x22\x02"),
627        )
628        add_delegate_data(program, program.execution_plan[0], blobs)
629
630        # Extract the blobs into segments should succeeed.
631        pte_data = bytes(
632            serialize_pte_binary(
633                program,
634                extract_delegate_segments=True,
635                segment_alignment=SEGMENT_ALIGNMENT,
636            )
637        )
638        self.assertGreater(len(pte_data), 16)
639
640        # Add another inline blob that is not pointed to by a delegate.
641        program.backend_delegate_data.append(
642            BackendDelegateInlineData(data=self.gen_blob_data(16, b"\x30\x33\x03"))
643        )
644
645        # Should cause serialization to fail.
646        with self.assertRaises(ValueError):
647            serialize_pte_binary(
648                program,
649                extract_delegate_segments=True,
650                segment_alignment=SEGMENT_ALIGNMENT,
651            )
652
    def test_constant_segment_tensor_alignment_16(self) -> None:
        # Exercise the shared helper with a 16-byte tensor alignment.
        self.constant_segment_with_tensor_alignment(16)
655
    def test_constant_segment_tensor_alignment_128(self) -> None:
        # Exercise the shared helper with a 128-byte tensor alignment.
        self.constant_segment_with_tensor_alignment(128)
658
659    def test_constant_segment_tensor_alignment_non_power_of_2_fails(self) -> None:
660        # Create a program with some constant tensor data.
661        program = get_test_program()
662        program.constant_buffer.append(Buffer(storage=b"12345"))
663
664        constant_tensor_alignment: int = 14
665        # Extract blobs into constant segment during serialization.
666        # Expect failure as tensor alignment 14 is not a power of 2.
667        with self.assertRaises(ValueError):
668            serialize_pte_binary(
669                program,
670                segment_alignment=SEGMENT_ALIGNMENT,
671                constant_tensor_alignment=constant_tensor_alignment,
672            )
673
674    def test_constant_segment_and_delegate_segment(self) -> None:
675        # Create a program with some constant tensor data and delegate data blobs.
676        program = get_test_program()
677        constant_blobs = (
678            self.gen_blob_data(CONSTANT_TENSOR_ALIGNMENT // 2, b"\x10\x11\x01"),
679            self.gen_blob_data(CONSTANT_TENSOR_ALIGNMENT + 1, b"\x20\x22\x02"),
680        )
681        delegate_blobs = (
682            self.gen_blob_data(SEGMENT_ALIGNMENT // 2, b"\x30\x33\x03"),
683            self.gen_blob_data(SEGMENT_ALIGNMENT + 1, b"\x40\x44\x04"),
684        )
685
686        add_constant_data(program, constant_blobs)
687        add_delegate_data(program, program.execution_plan[0], delegate_blobs)
688
689        # Extract the blobs into segments during serialization.
690        pte_data = bytes(
691            serialize_pte_binary(
692                program,
693                extract_delegate_segments=True,
694                segment_alignment=SEGMENT_ALIGNMENT,
695                constant_tensor_alignment=CONSTANT_TENSOR_ALIGNMENT,
696            )
697        )
698
699        # The input Program should not be modified.
700        self.assertEqual(program.segments, [])
701        self.assertEqual(
702            program.execution_plan[0].delegates[0].processed.location,
703            DataLocation.INLINE,
704        )
705
706        # Extended header should be present in the serialized data.
707        eh = self.get_and_validate_extended_header(pte_data)
708
709        # Segment offset should be non-zero since there are segments. It
710        # should point past the end of the program data, but not beyond
711        # the end of the file.
712        self.assertGreaterEqual(eh.segment_base_offset, eh.program_size)
713        self.assertLess(eh.segment_base_offset, len(pte_data))
714
715        # Peek inside the actual flatbuffer data to see the segments.
716        program_with_segments = _json_to_program(_program_flatbuffer_to_json(pte_data))
717
718        # Segment table should contain a constant segment and the delegate blobs.
719        segment_table: List[DataSegment] = program_with_segments.segments
720        self.assertEqual(len(segment_table), len(delegate_blobs) + 1)
721        self.assertEqual(segment_table[0].offset, 0)
722        # segment_table[0] is the constant segment, which
723        # contains a couple of tensors with sizes:
724        # - tensor[0] = CONSTANT_TENSOR_ALIGNMENT
725        # - tensor[1] = CONSTANT_TENSOR_ALIGNMENT + 1 (no padding on last tensor)
726        self.assertEqual(segment_table[0].size, CONSTANT_TENSOR_ALIGNMENT * 2 + 1)
727        self.assertEqual(segment_table[1].offset, SEGMENT_ALIGNMENT)
728        self.assertEqual(segment_table[1].size, SEGMENT_ALIGNMENT // 2)
729        self.assertEqual(segment_table[2].offset, SEGMENT_ALIGNMENT * 2)
730        self.assertEqual(segment_table[2].size, SEGMENT_ALIGNMENT + 1)
731
732        # Check constant_segment index and offsets.
733        subsegment_offsets: SubsegmentOffsets = program_with_segments.constant_segment
734        self.assertEqual(subsegment_offsets.segment_index, 0)
735        self.assertEqual(
736            subsegment_offsets.offsets,
737            [
738                0,  # Start at offset 0.
739                16,  # tensor[0] has size CONSTANT_TENSOR_ALIGNMENT. No padding required.
740            ],
741        )
742
743        # Check constant_buffer is empty, because the data was moved into the segment.
744        self.assertEqual(len(program_with_segments.constant_buffer), 0)
745
746        # The first segment should begin at zero; i.e., at the segment base
747        # offset.
748        self.assertEqual(segment_table[0].offset, 0, f"{segment_table}")
749        # The final segment should not point past the end of the file.
750        self.assertLessEqual(
751            segment_table[-1].offset + segment_table[-1].size,
752            len(pte_data),
753            f"{segment_table}",
754        )
755
756        # Check the segment base offset boundary.
757        segment_base_offset = eh.segment_base_offset
758        self.assertEqual(
759            pte_data[segment_base_offset - 2 : segment_base_offset + 3],
760            # Padding before the first segment.
761            b"\x00\x00"
762            # First few bytes of the first segment.
763            + b"\x10\x11\x11",
764        )
765
766        # Now that we've shown that the base offset is correct, slice off the
767        # front so that all segment offsets are relative to zero.
768        segment_data: bytes = pte_data[segment_base_offset:]
769
770        # Check segment[0] for constants.
771        offsets = subsegment_offsets.offsets
772        # Check tensor[0]: padding at the end.
773        self.assertEqual(
774            segment_data[0 : offsets[1]],
775            # Tensor data.
776            b"\x10\x11\x11\x11\x11\x11\x11\x01"
777            # Padding.
778            + b"\x00\x00\x00\x00\x00\x00\x00\x00",
779        )
780
781        # Check tensor[1]: padding at CONSTANT_TENSOR_ALIGNMENT.
782        self.assertEqual(
783            segment_data[
784                offsets[1]
785                + CONSTANT_TENSOR_ALIGNMENT
786                - 3 : offsets[1]
787                + CONSTANT_TENSOR_ALIGNMENT
788                + 3
789            ],
790            # Tensor data.
791            b"\x22\x22\x22"
792            # Padding.
793            + b"\x02\x00\x00",
794        )
795
796        # Check segment[0] and segment[1] border.
797        self.assertEqual(
798            segment_data[segment_table[1].offset - 3 : segment_table[1].offset + 3],
799            # Padding for segment[0].
800            b"\x00\x00\x00"
801            # Start of segment[1].
802            + b"\x30\x33\x33",
803        )
804
805        # Check segment[1] and segment[2] border.
806        self.assertEqual(
807            segment_data[segment_table[2].offset - 3 : segment_table[2].offset + 3],
808            # Padding for segment[1].
809            b"\x00\x00\x00"
810            # Start of segment[2].
811            + b"\x40\x44\x44",
812        )
813
814        # Convert back.
815        program2 = deserialize_pte_binary(pte_data)
816        # Programs are the same besides constant_buffer, as deserialization
817        # does not preserve constant segment; padding may be added
818        # during serialization.
819        self.assertEqual(program2.execution_plan, program.execution_plan)
820        # Number of constant tensors should be the same.
821        self.assertEqual(len(program2.constant_buffer), len(program.constant_buffer))
822
823
# Shared fixtures for the extended-header tests below. Serializing a header
# built from the two example values must yield exactly EXAMPLE_HEADER_DATA,
# and parsing EXAMPLE_HEADER_DATA must yield the two example values.
EXAMPLE_PROGRAM_SIZE: int = 0x1122112233443344
EXAMPLE_SEGMENT_BASE_OFFSET: int = 0x5566556677887788
# NOTE: This byte string is intentionally fragile. If the header layout or
# magic ever changes, these tests must change too: the header layout is a
# contract with readers of the file format, not an implementation detail.
EXAMPLE_HEADER_DATA: bytes = (
    b"eh00"  # Magic bytes.
    b"\x18\x00\x00\x00"  # uint32_t header length, little-endian (24).
    b"\x44\x33\x44\x33\x22\x11\x22\x11"  # uint64_t program size, little-endian.
    b"\x88\x77\x88\x77\x66\x55\x66\x55"  # uint64_t segment base offset, little-endian.
)
841
842
class TestExtendedHeader(unittest.TestCase):
    """Covers serialization (to_bytes) and parsing (from_bytes) of
    _ExtendedHeader, including validity checks on magic and length."""

    def _check_example_sizes(self, header: "_ExtendedHeader") -> None:
        # Shared assertion: the two size fields match the example constants.
        self.assertEqual(header.program_size, EXAMPLE_PROGRAM_SIZE)
        self.assertEqual(header.segment_base_offset, EXAMPLE_SEGMENT_BASE_OFFSET)

    def test_to_bytes(self) -> None:
        header = _ExtendedHeader(
            program_size=EXAMPLE_PROGRAM_SIZE,
            segment_base_offset=EXAMPLE_SEGMENT_BASE_OFFSET,
        )
        # With default magic/length the header is valid, and serializing it
        # must produce exactly the canonical example bytes.
        self.assertTrue(header.is_valid())
        self.assertEqual(header.to_bytes(), EXAMPLE_HEADER_DATA)

    def test_to_bytes_with_non_defaults(self) -> None:
        # Override the default magic and length to demonstrate that
        # to_bytes() ignores those fields when serializing.
        header = _ExtendedHeader(
            program_size=EXAMPLE_PROGRAM_SIZE,
            segment_base_offset=EXAMPLE_SEGMENT_BASE_OFFSET,
            magic=b"ABCD",
            length=0xAABBCCDD,
        )
        # The overridden values make the in-memory header invalid...
        self.assertFalse(header.is_valid())
        # ...but serialization still emits the canonical header.
        self.assertEqual(header.to_bytes(), EXAMPLE_HEADER_DATA)

    def test_from_bytes_valid(self) -> None:
        header = _ExtendedHeader.from_bytes(EXAMPLE_HEADER_DATA)
        # Good magic and good length: valid.
        self.assertTrue(header.is_valid())
        self.assertEqual(header.magic, _ExtendedHeader.EXPECTED_MAGIC)
        self.assertEqual(header.length, _ExtendedHeader.EXPECTED_LENGTH)
        self._check_example_sizes(header)

    def test_from_bytes_with_more_data_than_necessary(self) -> None:
        # Bytes trailing the header must be ignored by the parser.
        padded = EXAMPLE_HEADER_DATA + b"\x55" * 16
        header = _ExtendedHeader.from_bytes(padded)
        # Still a valid header: good magic and length.
        self.assertTrue(header.is_valid())
        self.assertEqual(header.magic, _ExtendedHeader.EXPECTED_MAGIC)
        self.assertEqual(header.length, _ExtendedHeader.EXPECTED_LENGTH)
        self._check_example_sizes(header)

    def test_from_bytes_larger_than_needed_header_size_field(self) -> None:
        # Forward-compatibility: a header whose size field is larger than
        # expected typically carries fields this code doesn't know about.
        # Those unknown fields must be ignored, not rejected.
        raw: bytes = (
            b"eh00"  # Magic bytes.
            b"\x1c\x00\x00\x00"  # uint32_t header size (LE): 28, longer than expected.
            b"\x44\x33\x44\x33\x22\x11\x22\x11"  # uint64_t program size.
            b"\x88\x77\x88\x77\x66\x55\x66\x55"  # uint64_t segment base offset.
            b"\xff\xee\xff\xee"  # uint32_t new field (ignored).
        )
        header = _ExtendedHeader.from_bytes(raw)
        # Valid despite the larger-than-expected size.
        self.assertTrue(header.is_valid())
        self.assertEqual(header.magic, _ExtendedHeader.EXPECTED_MAGIC)
        self.assertEqual(header.length, 28)
        self._check_example_sizes(header)

    def test_from_bytes_not_enough_data_fails(self) -> None:
        # A truncated prefix cannot be parsed at all.
        truncated = EXAMPLE_HEADER_DATA[:16]
        with self.assertRaises(ValueError):
            _ExtendedHeader.from_bytes(truncated)

    def test_from_bytes_invalid_magic(self) -> None:
        # Well-formed header data, except for the magic.
        raw: bytes = (
            b"ABCD"  # Magic bytes: intentionally wrong.
            b"\x18\x00\x00\x00"  # uint32_t header size (little endian).
            b"\x44\x33\x44\x33\x22\x11\x22\x11"  # uint64_t program size.
            b"\x88\x77\x88\x77\x66\x55\x66\x55"  # uint64_t segment base offset.
        )
        header = _ExtendedHeader.from_bytes(raw)
        # Bad magic makes the header invalid...
        self.assertFalse(header.is_valid())
        # ...but every field is still parsed out, so callers can see
        # what went wrong.
        self.assertEqual(header.magic, b"ABCD")
        self.assertEqual(header.length, _ExtendedHeader.EXPECTED_LENGTH)
        self._check_example_sizes(header)

    def test_from_bytes_invalid_length(self) -> None:
        # Well-formed header data, except for the length.
        raw: bytes = (
            b"eh00"  # Magic bytes.
            b"\x10\x00\x00\x00"  # uint32_t header size (LE): 16, too short.
            b"\x44\x33\x44\x33\x22\x11\x22\x11"  # uint64_t program size.
            b"\x88\x77\x88\x77\x66\x55\x66\x55"  # uint64_t segment base offset.
        )
        header = _ExtendedHeader.from_bytes(raw)
        # Bad header size makes the header invalid...
        self.assertFalse(header.is_valid())
        # ...but every field is still parsed out, so callers can see
        # what went wrong.
        self.assertEqual(header.magic, _ExtendedHeader.EXPECTED_MAGIC)
        self.assertEqual(header.length, 16)
        self._check_example_sizes(header)
978