1#!/usr/bin/env fbpython 2# Copyright (c) Meta Platforms, Inc. and affiliates. 3# All rights reserved. 4# 5# This source code is licensed under the BSD-style license found in the 6# LICENSE file in the root directory of this source tree. 7 8# pyre-unsafe 9 10import copy 11import difflib 12import json 13import unittest 14 15from typing import List, Sequence 16 17from executorch.exir._serialize._flatbuffer import _program_flatbuffer_to_json 18from executorch.exir._serialize._program import ( 19 _ExtendedHeader, 20 _get_extended_header, 21 _json_to_program, 22 _program_to_json, 23 deserialize_pte_binary, 24 serialize_pte_binary, 25) 26 27from executorch.exir.schema import ( 28 BackendDelegate, 29 BackendDelegateDataReference, 30 BackendDelegateInlineData, 31 Buffer, 32 ContainerMetadata, 33 DataLocation, 34 DataSegment, 35 ExecutionPlan, 36 Program, 37 SubsegmentOffsets, 38) 39from executorch.exir.tests.common import get_test_program 40 41SEGMENT_ALIGNMENT: int = 128 42 43CONSTANT_TENSOR_ALIGNMENT: int = 16 44 45 46def add_constant_data(program: Program, blobs: Sequence[bytes]) -> None: 47 """Adds the provided constant data blobs to the program.""" 48 for blob in blobs: 49 program.constant_buffer.append(Buffer(storage=blob)) 50 51 52def add_delegate_data( 53 program: Program, plan: ExecutionPlan, blobs: Sequence[bytes] 54) -> None: 55 """Adds the provided delegate data blobs to the execution plan.""" 56 di = len(plan.delegates) 57 for blob in blobs: 58 data_index: int = len(program.backend_delegate_data) 59 program.backend_delegate_data.append( 60 BackendDelegateInlineData( 61 data=blob, 62 ) 63 ) 64 delegate = BackendDelegate( 65 id=f"delegate{di}", 66 processed=BackendDelegateDataReference( 67 location=DataLocation.INLINE, 68 index=data_index, 69 ), 70 compile_specs=[], 71 ) 72 plan.delegates.append(delegate) 73 di += 1 74 75 76def canonicalize_delegate_indices(program: Program) -> Program: 77 """Returns a copy of the program with the backend delegate data list in 78 a predictable order. 79 """ 80 program = copy.deepcopy(program) 81 82 # Original index and its data. 83 delegate_entries: list[tuple[int, bytes]] = [ 84 (i, entry.data) for i, entry in enumerate(program.backend_delegate_data) 85 ] 86 87 # Sort by the contents of the data, which is the second entry in the tuple. 88 # NOTE: This is unstable if multiple entries have the same data contents. 89 delegate_entries.sort(key=lambda x: x[1]) 90 91 # Build up the sorted Program.backend_delegate_data list, and a mapping from 92 # the old index to the new index. 93 old_to_new_index: dict[int, int] = {} 94 program.backend_delegate_data = [] 95 for i, data in delegate_entries: 96 old_to_new_index[i] = len(program.backend_delegate_data) 97 print(f">>> Mapping [{i}]: {old_to_new_index[i]} '{data}'") 98 program.backend_delegate_data.append(BackendDelegateInlineData(data=data)) 99 100 # Patch up the index pointers from the BackendDelegate entries. 101 for plan in program.execution_plan: 102 for delegate in plan.delegates: 103 delegate.processed.index = old_to_new_index[delegate.processed.index] 104 105 return program 106 107 108class TestProgram(unittest.TestCase): 109 def assert_file_magic_present(self, program_data: bytes) -> None: 110 self.assertEqual(program_data[4:6], b"ET") 111 # Ignore the other bytes, which can change over time and are not 112 # important for this test. 113 114 def assert_programs_equal(self, program1: Program, program2: Program) -> None: 115 def prepare_json_string(j: str) -> List[str]: 116 """Formats the JSON and splits it into lines.""" 117 return json.dumps(json.loads(j), indent=2, sort_keys=True).splitlines( 118 keepends=True 119 ) 120 121 # This JSON comparison is fragile: some parts of the program do not care 122 # about order (like the operators list), so those are technically free 123 # to be reordered. If they become a problem, we can canonicalize them 124 # like we do for the backend delegate data list. 125 json1 = _program_to_json(canonicalize_delegate_indices(program1)) 126 json2 = _program_to_json(canonicalize_delegate_indices(program2)) 127 128 # Use unified_diff so it only prints the differences instead of the 129 # entire string. 130 diff: str = "".join( 131 difflib.unified_diff( 132 prepare_json_string(json1), 133 prepare_json_string(json2), 134 ) 135 ) 136 if diff: 137 self.fail(msg="Programs are not equal\n" + diff) 138 139 def get_and_validate_extended_header(self, pte_data: bytes) -> _ExtendedHeader: 140 """When an extended header is expected, check that it exists and is valid. 141 Does not check correctness of the contents.""" 142 eh = _get_extended_header(pte_data) 143 self.assertIsNotNone(eh) 144 self.assertTrue(eh.is_valid()) 145 self.assertLess(eh.program_size, len(pte_data)) 146 return eh 147 148 def constant_segment_with_tensor_alignment( 149 self, constant_tensor_alignment: int 150 ) -> None: 151 """Utility to test constant segment with varying alignment. 152 Args: 153 constant_tensor_alignment: Alignment of constant tensor data. 154 Must be a multiple of 2. 155 Must be > 8 for the purposes of the test, which checks +- 3 bytes on the edges of each tensor. 156 """ 157 # Create a program with some constant tensor data. 158 program = get_test_program() 159 blobs = ( 160 b"", # Empty tensor. 161 self.gen_blob_data(constant_tensor_alignment // 2, b"\x10\x11\x01"), 162 self.gen_blob_data(constant_tensor_alignment - 1, b"\x20\x22\x02"), 163 self.gen_blob_data(constant_tensor_alignment, b"\x30\x33\x03"), 164 self.gen_blob_data(constant_tensor_alignment + 1, b"\x40\x44\x04"), 165 ) 166 add_constant_data(program, blobs) 167 168 # Extract blobs into constant segment during serialization. 169 pte_data = bytes( 170 serialize_pte_binary( 171 program, 172 segment_alignment=SEGMENT_ALIGNMENT, 173 constant_tensor_alignment=constant_tensor_alignment, 174 ) 175 ) 176 177 # The input Program should not be modified. 178 self.assertEqual(program.segments, []) 179 180 # Extended header should be present in the serialized data. 181 eh = self.get_and_validate_extended_header(pte_data) 182 183 # Segment offset should be non-zero since there are segments. It 184 # should point past the end of the program data, but not beyond 185 # the end of the file. 186 self.assertGreaterEqual(eh.segment_base_offset, eh.program_size) 187 self.assertLess(eh.segment_base_offset, len(pte_data)) 188 189 # Peek inside the actual flatbuffer data to see the segments. 190 program_with_segments = _json_to_program(_program_flatbuffer_to_json(pte_data)) 191 192 # The constant tensor data should appear as the only segment. 193 self.assertEqual(len(program_with_segments.segments), 1) 194 195 # The constant buffer should appear now as a constant segment. 196 segment_table: List[DataSegment] = program_with_segments.segments 197 self.assertEqual(len(segment_table), 1) 198 # Tensor sizes 199 # - tensor[0]: 0 200 # - tensors[1,2,3]: constant_tensor_alignment 201 # - tensor[4]: constant_tensor_alignment + 1 (no padding on the last tensor) 202 self.assertEqual( 203 segment_table[0].size, 204 constant_tensor_alignment * 3 + (constant_tensor_alignment + 1), 205 ) 206 207 # Check constant_segment index and offsets. 208 subsegment_offsets: SubsegmentOffsets = program_with_segments.constant_segment 209 self.assertEqual(subsegment_offsets.segment_index, 0) 210 self.assertEqual( 211 subsegment_offsets.offsets, 212 [ 213 0, # Start at offset 0. 214 0, # tensor[0] is empty. 215 constant_tensor_alignment, # tensor[1] has size constant_tensor_alignment // 2. Round up. 216 constant_tensor_alignment 217 * 2, # tensor[2] has size constant_tensor_alignment - 1. Round up. 218 constant_tensor_alignment 219 * 3, # tensor[3] has size constant_tensor_alignment. No padding needed. 220 ], 221 ) 222 223 # Check constant_buffer is empty, because the data was moved into the segment. 224 self.assertEqual(len(program_with_segments.constant_buffer), 0) 225 226 # Check segment data. 227 offsets = subsegment_offsets.offsets 228 segment_data: bytes = pte_data[eh.segment_base_offset :] 229 230 # tensor[1]: padding. 231 self.assertEqual( 232 segment_data[offsets[1] : offsets[1] + 3], 233 # Tensor data. 234 b"\x10\x11\x11", 235 ) 236 self.assertEqual( 237 segment_data[ 238 offsets[1] 239 + constant_tensor_alignment // 2 : offsets[1] 240 + constant_tensor_alignment // 2 241 + 3 242 ], 243 # Padding. 244 b"\x00\x00\x00", 245 ) 246 247 # tensor[3]: no padding. 248 self.assertEqual( 249 segment_data[offsets[4] - 3 : offsets[4] + 3], 250 # End of tensor 3. 251 b"\x33\x33\x03" 252 # Start of tensor 4. 253 + b"\x40\x44\x44", 254 ) 255 256 # tensor[4]: no padding for last tensor. 257 self.assertEqual( 258 segment_data[ 259 offsets[4] 260 + constant_tensor_alignment 261 - 3 : offsets[4] 262 + constant_tensor_alignment 263 + 1 264 ], 265 b"\x44\x44\x44\x04", 266 ) 267 268 # The final segment should not point past the end of the file. 269 self.assertLessEqual( 270 segment_table[-1].offset + segment_table[-1].size, 271 len(pte_data), 272 f"{segment_table}", 273 ) 274 275 # Convert back. 276 program2 = deserialize_pte_binary(pte_data) 277 # Programs are the same besides constant_buffer, as deserialization 278 # does not preserve constant segment; padding may be added 279 # during serialization. 280 self.assertEqual(program2.execution_plan, program.execution_plan) 281 # Number of constant tensors should be the same. 282 self.assertEqual(len(program2.constant_buffer), len(program.constant_buffer)) 283 284 def test_canonicalize_delegate_indices(self) -> None: 285 def make_execution_plan( 286 name: str, delegates: List[BackendDelegate] 287 ) -> ExecutionPlan: 288 return ExecutionPlan( 289 name=name, 290 container_meta_type=ContainerMetadata( 291 encoded_inp_str="encoded_inp_str", 292 encoded_out_str="encoded_out_str", 293 ), 294 values=[], 295 inputs=[], 296 outputs=[], 297 chains=[], 298 operators=[], 299 delegates=delegates, 300 non_const_buffer_sizes=[], 301 ) 302 303 # A program with three delegates across two execution plans. To start 304 # with, the data indices in the delegates are in a non-canonical order. 305 program = Program( 306 version=0, 307 execution_plan=[ 308 make_execution_plan( 309 name="forward0", 310 delegates=[ 311 BackendDelegate( 312 id="delegate0", 313 processed=BackendDelegateDataReference( 314 location=DataLocation.INLINE, index=2 315 ), 316 compile_specs=[], 317 ), 318 BackendDelegate( 319 id="delegate1", 320 processed=BackendDelegateDataReference( 321 location=DataLocation.INLINE, index=1 322 ), 323 compile_specs=[], 324 ), 325 ], 326 ), 327 make_execution_plan( 328 name="forward1", 329 delegates=[ 330 BackendDelegate( 331 id="delegate2", 332 processed=BackendDelegateDataReference( 333 location=DataLocation.INLINE, index=0 334 ), 335 compile_specs=[], 336 ), 337 ], 338 ), 339 ], 340 constant_buffer=[], 341 backend_delegate_data=[ 342 # Data is in non-canonical (unsorted) order. 343 BackendDelegateInlineData(data=b"CC delegate [1,0] data"), 344 BackendDelegateInlineData(data=b"BB delegate [0,1] data"), 345 BackendDelegateInlineData(data=b"AA delegate [0,0] data"), 346 ], 347 segments=[], 348 constant_segment=SubsegmentOffsets(segment_index=0, offsets=[]), 349 ) 350 351 # Demonstrate which data each delegate points to. 352 self.assertEqual( 353 program.backend_delegate_data[ 354 program.execution_plan[0].delegates[0].processed.index 355 ].data, 356 b"AA delegate [0,0] data", 357 ) 358 self.assertEqual( 359 program.backend_delegate_data[ 360 program.execution_plan[0].delegates[1].processed.index 361 ].data, 362 b"BB delegate [0,1] data", 363 ) 364 self.assertEqual( 365 program.backend_delegate_data[ 366 program.execution_plan[1].delegates[0].processed.index 367 ].data, 368 b"CC delegate [1,0] data", 369 ) 370 371 # Canonicalize the program. 372 canonical_program: Program = canonicalize_delegate_indices(program) 373 374 # The delegate data list should be sorted by contents. 375 self.assertListEqual( 376 canonical_program.backend_delegate_data, 377 [ 378 # Should have been sorted. 379 BackendDelegateInlineData(data=b"AA delegate [0,0] data"), 380 BackendDelegateInlineData(data=b"BB delegate [0,1] data"), 381 BackendDelegateInlineData(data=b"CC delegate [1,0] data"), 382 ], 383 ) 384 385 # Demonstrate that the delegate entries still point to the correct data. 386 self.assertEqual( 387 canonical_program.backend_delegate_data[ 388 canonical_program.execution_plan[0].delegates[0].processed.index 389 ].data, 390 b"AA delegate [0,0] data", 391 ) 392 self.assertEqual( 393 canonical_program.backend_delegate_data[ 394 canonical_program.execution_plan[0].delegates[1].processed.index 395 ].data, 396 b"BB delegate [0,1] data", 397 ) 398 self.assertEqual( 399 canonical_program.backend_delegate_data[ 400 canonical_program.execution_plan[1].delegates[0].processed.index 401 ].data, 402 b"CC delegate [1,0] data", 403 ) 404 405 def test_round_trip_no_header_no_segments(self) -> None: 406 """Tests that a Program remains the same after serializing and 407 deserializing. 408 """ 409 program = get_test_program() 410 pte_data = bytes(serialize_pte_binary(program)) 411 self.assertGreater(len(pte_data), 16) 412 413 # File magic should be present at the expected offset. 414 self.assert_file_magic_present(pte_data) 415 416 # Extended header should not be present. 417 eh = _get_extended_header(pte_data) 418 self.assertIsNone(eh) 419 420 # Convert back. 421 program2 = deserialize_pte_binary(pte_data) 422 423 # Programs should be the same. 424 self.assert_programs_equal(program, program2) 425 426 def test_round_trip_large_buffer_sizes(self) -> None: 427 """Tests that when the non_const_buffer_sizes contains integers 428 overflowing a signed/unsigned 32 bit integer, we can still serialize the 429 model and get the same program by deserialization. 430 """ 431 program = get_test_program() 432 program.execution_plan[0].non_const_buffer_sizes = [0, 2**48] 433 flatbuffer_from_py = bytes(serialize_pte_binary(program)) 434 self.assert_programs_equal(program, deserialize_pte_binary(flatbuffer_from_py)) 435 436 def test_round_trip_no_segments_and_no_header(self) -> None: 437 """Tests that a Program serialized with extract_delegate_segments=True 438 when there are no segments does not contain an extended header, 439 constant segment, or delegate segments. Confirm that a Program remains 440 the same after serializing and deserializing. 441 """ 442 program = get_test_program() 443 pte_data = bytes(serialize_pte_binary(program, extract_delegate_segments=True)) 444 self.assertGreater(len(pte_data), 16) 445 446 # File magic should be present at the expected offset. 447 self.assert_file_magic_present(pte_data) 448 449 # Extended header should not be present when no segments are created. 450 eh = _get_extended_header(pte_data) 451 self.assertIsNone(eh) 452 453 # Peek inside the flatbuffer data to confirm that there are no segments. 454 program_with_segments = _json_to_program(_program_flatbuffer_to_json(pte_data)) 455 self.assertEqual(program_with_segments.segments, []) 456 457 # Convert back. 458 program2 = deserialize_pte_binary(pte_data) 459 460 # Programs should be the same. 461 self.assert_programs_equal(program, program2) 462 463 @staticmethod 464 def gen_blob_data(size: int, pattern: bytes) -> bytes: 465 """Generates a buffer with special first and last bytes, 466 repeating the middle byte of the pattern.""" 467 assert len(pattern) == 3 468 assert size >= 3 469 # Stretch out the middle byte to fill the space. 470 ret = pattern[0:1] + pattern[1:2] * (size - 2) + pattern[2:3] 471 assert len(ret) == size 472 return ret 473 474 def test_round_trip_with_segments(self) -> None: 475 # Create a program with some delegate data blobs. 476 program = get_test_program() 477 blobs = ( 478 self.gen_blob_data(SEGMENT_ALIGNMENT // 5, b"\x10\x11\x01"), 479 # Focus on blobs whose sizes fall close to the alignment. 480 self.gen_blob_data(SEGMENT_ALIGNMENT - 1, b"\x20\x22\x02"), 481 self.gen_blob_data(SEGMENT_ALIGNMENT, b"\x30\x33\x03"), 482 self.gen_blob_data(SEGMENT_ALIGNMENT + 1, b"\x40\x44\x04"), 483 b"", # Empty segment. 484 self.gen_blob_data(SEGMENT_ALIGNMENT // 10, b"\x50\x55\x05"), 485 ) 486 add_delegate_data(program, program.execution_plan[0], blobs) 487 488 # Extract the blobs into segments during serialization. 489 pte_data = bytes( 490 serialize_pte_binary( 491 program, 492 extract_delegate_segments=True, 493 segment_alignment=SEGMENT_ALIGNMENT, 494 ) 495 ) 496 497 # The input Program should not have been modified. 498 self.assertEqual(program.segments, []) 499 self.assertEqual( 500 program.execution_plan[0].delegates[0].processed.location, 501 DataLocation.INLINE, 502 ) 503 504 # Extended header should be present in the serialized data. 505 eh = self.get_and_validate_extended_header(pte_data) 506 # Segment offset should be non-zero since there are segments. It 507 # should point past the end of the program data, but not beyond 508 # the end of the file. 509 self.assertGreaterEqual(eh.segment_base_offset, eh.program_size) 510 self.assertLess(eh.segment_base_offset, len(pte_data)) 511 512 # Peek inside the actual flatbuffer data to see the segments. Note that 513 # this also implicity tests the case where we try parsing the entire 514 # file with segment data following it, demonstrating that the extra data 515 # doesn't upset the flatbuffer parsing path. 516 program_with_segments = _json_to_program(_program_flatbuffer_to_json(pte_data)) 517 518 # The delegate blobs we added to the program should appear as segments. 519 # The one empty blob should have been ignored, hence the `- 1`. 520 self.assertEqual(len(program_with_segments.segments), len(blobs) - 1) 521 segment_table: List[DataSegment] = program_with_segments.segments 522 523 # Check segment range invariants. 524 for i in range(len(segment_table)): 525 # All offsets should be a multiple of SEGMENT_ALIGNMENT. 526 self.assertTrue( 527 segment_table[i].offset % SEGMENT_ALIGNMENT == 0, 528 f"Segment {i} offset is not aligned: {segment_table[i]}", 529 ) 530 # There should be no empty segments. 531 self.assertGreater( 532 segment_table[i].size, 0, f"Segment {i}: {segment_table}" 533 ) 534 if i > 0: 535 # Segments should not overlap, and should be sorted from 536 # smallest offset to largest. 537 self.assertLessEqual( 538 segment_table[i - 1].offset + segment_table[i - 1].size, 539 segment_table[i].offset, 540 f"Segment {i} overlaps or is out of order: {segment_table}", 541 ) 542 # The first segment should begin at zero; i.e., at the segment base 543 # offset. 544 self.assertEqual(segment_table[0].offset, 0, f"{segment_table}") 545 # The final segment should not point past the end of the file. 546 self.assertLessEqual( 547 segment_table[-1].offset + segment_table[-1].size, 548 len(pte_data), 549 f"{segment_table}", 550 ) 551 552 # Check the segment base offset boundary. 553 segment_base_offset = eh.segment_base_offset 554 self.assertEqual( 555 pte_data[segment_base_offset - 2 : segment_base_offset + 3], 556 # The padding before the first segment. 557 b"\x00\x00" 558 # The first few bytes of the first segment. 559 + b"\x10\x11\x11", 560 ) 561 562 # Now that we've shown that the base offset is correct, slice off the 563 # front so that all segment offsets are relative to zero. 564 segment_data: bytes = pte_data[segment_base_offset:] 565 566 # End of the first segment. It's much smaller than the alignment, 567 # so we know that it's followed by zeros. 568 self.assertEqual( 569 segment_data[segment_table[0].size - 3 : segment_table[0].size + 2], 570 # The end of the segment. 571 b"\x11\x11\x01" 572 # The padding that follows it. 573 + b"\x00\x00", 574 ) 575 576 # Look at the end of segment[2], which is exactly the same size as 577 # the alignment. There should be no padding, running right into the 578 # next segment. 579 self.assertEqual( 580 segment_data[segment_table[3].offset - 3 : segment_table[3].offset + 3], 581 # The end of segment[2]. 582 b"\x33\x33\x03" 583 # The beginning of segment[3] 584 b"\x40\x44\x44", 585 ) 586 587 # Convert back; the programs should be the same after a round trip, 588 # meaning that the segments were moved back to inline. This also 589 # demonstrates that the contents of all segments survived, and weren't 590 # truncated or corrupted. 591 program2 = deserialize_pte_binary(pte_data) 592 self.assert_programs_equal(program, program2) 593 594 def test_no_constants(self) -> None: 595 program = get_test_program() 596 # Insert placeholder for non-const tensors. 597 add_constant_data(program, [b""]) 598 599 pte_data = bytes( 600 serialize_pte_binary( 601 program, 602 extract_delegate_segments=True, 603 segment_alignment=SEGMENT_ALIGNMENT, 604 constant_tensor_alignment=CONSTANT_TENSOR_ALIGNMENT, 605 ) 606 ) 607 # The input Program should not be modified. 608 self.assertEqual(program.segments, []) 609 610 # Peek inside the actual flatbuffer data to see the segments. 611 flatbuffer_program = _json_to_program(_program_flatbuffer_to_json(pte_data)) 612 613 # Constant buffer should be empty. 614 self.assertEqual(len(flatbuffer_program.constant_buffer), 0) 615 616 # Constant segment should contain the placeholder. 617 self.assertEqual(flatbuffer_program.constant_segment.segment_index, 0) 618 self.assertEqual(len(flatbuffer_program.constant_segment.offsets), 1) 619 self.assertEqual(flatbuffer_program.constant_segment.offsets[0], 0) 620 621 def test_unused_inline_delegate_blobs_with_segments(self) -> None: 622 # Create a program with some delegate data blobs. 623 program = get_test_program() 624 blobs = ( 625 self.gen_blob_data(16, b"\x10\x11\x01"), 626 self.gen_blob_data(32, b"\x20\x22\x02"), 627 ) 628 add_delegate_data(program, program.execution_plan[0], blobs) 629 630 # Extract the blobs into segments should succeeed. 631 pte_data = bytes( 632 serialize_pte_binary( 633 program, 634 extract_delegate_segments=True, 635 segment_alignment=SEGMENT_ALIGNMENT, 636 ) 637 ) 638 self.assertGreater(len(pte_data), 16) 639 640 # Add another inline blob that is not pointed to by a delegate. 641 program.backend_delegate_data.append( 642 BackendDelegateInlineData(data=self.gen_blob_data(16, b"\x30\x33\x03")) 643 ) 644 645 # Should cause serialization to fail. 646 with self.assertRaises(ValueError): 647 serialize_pte_binary( 648 program, 649 extract_delegate_segments=True, 650 segment_alignment=SEGMENT_ALIGNMENT, 651 ) 652 653 def test_constant_segment_tensor_alignment_16(self) -> None: 654 self.constant_segment_with_tensor_alignment(16) 655 656 def test_constant_segment_tensor_alignment_128(self) -> None: 657 self.constant_segment_with_tensor_alignment(128) 658 659 def test_constant_segment_tensor_alignment_non_power_of_2_fails(self) -> None: 660 # Create a program with some constant tensor data. 661 program = get_test_program() 662 program.constant_buffer.append(Buffer(storage=b"12345")) 663 664 constant_tensor_alignment: int = 14 665 # Extract blobs into constant segment during serialization. 666 # Expect failure as tensor alignment 14 is not a power of 2. 667 with self.assertRaises(ValueError): 668 serialize_pte_binary( 669 program, 670 segment_alignment=SEGMENT_ALIGNMENT, 671 constant_tensor_alignment=constant_tensor_alignment, 672 ) 673 674 def test_constant_segment_and_delegate_segment(self) -> None: 675 # Create a program with some constant tensor data and delegate data blobs. 676 program = get_test_program() 677 constant_blobs = ( 678 self.gen_blob_data(CONSTANT_TENSOR_ALIGNMENT // 2, b"\x10\x11\x01"), 679 self.gen_blob_data(CONSTANT_TENSOR_ALIGNMENT + 1, b"\x20\x22\x02"), 680 ) 681 delegate_blobs = ( 682 self.gen_blob_data(SEGMENT_ALIGNMENT // 2, b"\x30\x33\x03"), 683 self.gen_blob_data(SEGMENT_ALIGNMENT + 1, b"\x40\x44\x04"), 684 ) 685 686 add_constant_data(program, constant_blobs) 687 add_delegate_data(program, program.execution_plan[0], delegate_blobs) 688 689 # Extract the blobs into segments during serialization. 690 pte_data = bytes( 691 serialize_pte_binary( 692 program, 693 extract_delegate_segments=True, 694 segment_alignment=SEGMENT_ALIGNMENT, 695 constant_tensor_alignment=CONSTANT_TENSOR_ALIGNMENT, 696 ) 697 ) 698 699 # The input Program should not be modified. 700 self.assertEqual(program.segments, []) 701 self.assertEqual( 702 program.execution_plan[0].delegates[0].processed.location, 703 DataLocation.INLINE, 704 ) 705 706 # Extended header should be present in the serialized data. 707 eh = self.get_and_validate_extended_header(pte_data) 708 709 # Segment offset should be non-zero since there are segments. It 710 # should point past the end of the program data, but not beyond 711 # the end of the file. 712 self.assertGreaterEqual(eh.segment_base_offset, eh.program_size) 713 self.assertLess(eh.segment_base_offset, len(pte_data)) 714 715 # Peek inside the actual flatbuffer data to see the segments. 716 program_with_segments = _json_to_program(_program_flatbuffer_to_json(pte_data)) 717 718 # Segment table should contain a constant segment and the delegate blobs. 719 segment_table: List[DataSegment] = program_with_segments.segments 720 self.assertEqual(len(segment_table), len(delegate_blobs) + 1) 721 self.assertEqual(segment_table[0].offset, 0) 722 # segment_table[0] is the constant segment, which 723 # contains a couple of tensors with sizes: 724 # - tensor[0] = CONSTANT_TENSOR_ALIGNMENT 725 # - tensor[1] = CONSTANT_TENSOR_ALIGNMENT + 1 (no padding on last tensor) 726 self.assertEqual(segment_table[0].size, CONSTANT_TENSOR_ALIGNMENT * 2 + 1) 727 self.assertEqual(segment_table[1].offset, SEGMENT_ALIGNMENT) 728 self.assertEqual(segment_table[1].size, SEGMENT_ALIGNMENT // 2) 729 self.assertEqual(segment_table[2].offset, SEGMENT_ALIGNMENT * 2) 730 self.assertEqual(segment_table[2].size, SEGMENT_ALIGNMENT + 1) 731 732 # Check constant_segment index and offsets. 733 subsegment_offsets: SubsegmentOffsets = program_with_segments.constant_segment 734 self.assertEqual(subsegment_offsets.segment_index, 0) 735 self.assertEqual( 736 subsegment_offsets.offsets, 737 [ 738 0, # Start at offset 0. 739 16, # tensor[0] has size CONSTANT_TENSOR_ALIGNMENT. No padding required. 740 ], 741 ) 742 743 # Check constant_buffer is empty, because the data was moved into the segment. 744 self.assertEqual(len(program_with_segments.constant_buffer), 0) 745 746 # The first segment should begin at zero; i.e., at the segment base 747 # offset. 748 self.assertEqual(segment_table[0].offset, 0, f"{segment_table}") 749 # The final segment should not point past the end of the file. 750 self.assertLessEqual( 751 segment_table[-1].offset + segment_table[-1].size, 752 len(pte_data), 753 f"{segment_table}", 754 ) 755 756 # Check the segment base offset boundary. 757 segment_base_offset = eh.segment_base_offset 758 self.assertEqual( 759 pte_data[segment_base_offset - 2 : segment_base_offset + 3], 760 # Padding before the first segment. 761 b"\x00\x00" 762 # First few bytes of the first segment. 763 + b"\x10\x11\x11", 764 ) 765 766 # Now that we've shown that the base offset is correct, slice off the 767 # front so that all segment offsets are relative to zero. 768 segment_data: bytes = pte_data[segment_base_offset:] 769 770 # Check segment[0] for constants. 771 offsets = subsegment_offsets.offsets 772 # Check tensor[0]: padding at the end. 773 self.assertEqual( 774 segment_data[0 : offsets[1]], 775 # Tensor data. 776 b"\x10\x11\x11\x11\x11\x11\x11\x01" 777 # Padding. 778 + b"\x00\x00\x00\x00\x00\x00\x00\x00", 779 ) 780 781 # Check tensor[1]: padding at CONSTANT_TENSOR_ALIGNMENT. 782 self.assertEqual( 783 segment_data[ 784 offsets[1] 785 + CONSTANT_TENSOR_ALIGNMENT 786 - 3 : offsets[1] 787 + CONSTANT_TENSOR_ALIGNMENT 788 + 3 789 ], 790 # Tensor data. 791 b"\x22\x22\x22" 792 # Padding. 793 + b"\x02\x00\x00", 794 ) 795 796 # Check segment[0] and segment[1] border. 797 self.assertEqual( 798 segment_data[segment_table[1].offset - 3 : segment_table[1].offset + 3], 799 # Padding for segment[0]. 800 b"\x00\x00\x00" 801 # Start of segment[1]. 802 + b"\x30\x33\x33", 803 ) 804 805 # Check segment[1] and segment[2] border. 806 self.assertEqual( 807 segment_data[segment_table[2].offset - 3 : segment_table[2].offset + 3], 808 # Padding for segment[1]. 809 b"\x00\x00\x00" 810 # Start of segment[2]. 811 + b"\x40\x44\x44", 812 ) 813 814 # Convert back. 815 program2 = deserialize_pte_binary(pte_data) 816 # Programs are the same besides constant_buffer, as deserialization 817 # does not preserve constant segment; padding may be added 818 # during serialization. 819 self.assertEqual(program2.execution_plan, program.execution_plan) 820 # Number of constant tensors should be the same. 821 self.assertEqual(len(program2.constant_buffer), len(program.constant_buffer)) 822 823 824# Common data for extended header tests. The two example values should produce 825# the example data. 826EXAMPLE_PROGRAM_SIZE: int = 0x1122112233443344 827EXAMPLE_SEGMENT_BASE_OFFSET: int = 0x5566556677887788 828# This data is intentionally fragile. If the header layout or magic changes, 829# this test must change too. The layout of the header is a contract, not an 830# implementation detail. 831EXAMPLE_HEADER_DATA: bytes = ( 832 # Magic bytes 833 b"eh00" 834 # uint32_t header size (little endian) 835 + b"\x18\x00\x00\x00" 836 # uint64_t program size 837 + b"\x44\x33\x44\x33\x22\x11\x22\x11" 838 # uint64_t segment base offset 839 + b"\x88\x77\x88\x77\x66\x55\x66\x55" 840) 841 842 843class TestExtendedHeader(unittest.TestCase): 844 def test_to_bytes(self) -> None: 845 eh = _ExtendedHeader( 846 program_size=EXAMPLE_PROGRAM_SIZE, 847 segment_base_offset=EXAMPLE_SEGMENT_BASE_OFFSET, 848 ) 849 self.assertTrue(eh.is_valid()) 850 self.assertEqual(eh.to_bytes(), EXAMPLE_HEADER_DATA) 851 852 def test_to_bytes_with_non_defaults(self) -> None: 853 eh = _ExtendedHeader( 854 program_size=EXAMPLE_PROGRAM_SIZE, 855 segment_base_offset=EXAMPLE_SEGMENT_BASE_OFFSET, 856 # Override the default magic and length, to demonstrate that this 857 # does not affect the serialized header. 858 magic=b"ABCD", 859 length=0xAABBCCDD, 860 ) 861 # No longer counts as valid. 862 self.assertFalse(eh.is_valid()) 863 864 # But still produces a valid output header, since to_bytes() ignores 865 # magic and length. 866 self.assertEqual(eh.to_bytes(), EXAMPLE_HEADER_DATA) 867 868 def test_from_bytes_valid(self) -> None: 869 # Parse the serialized extended header. 870 eh = _ExtendedHeader.from_bytes(EXAMPLE_HEADER_DATA) 871 872 # This is a valid header: good magic and length. 873 self.assertTrue(eh.is_valid()) 874 875 self.assertEqual(eh.magic, _ExtendedHeader.EXPECTED_MAGIC) 876 self.assertEqual(eh.length, _ExtendedHeader.EXPECTED_LENGTH) 877 self.assertEqual(eh.program_size, EXAMPLE_PROGRAM_SIZE) 878 self.assertEqual(eh.segment_base_offset, EXAMPLE_SEGMENT_BASE_OFFSET) 879 880 def test_from_bytes_with_more_data_than_necessary(self) -> None: 881 # Pass in more data than necessary to parse the header. 882 header_data_with_suffix = EXAMPLE_HEADER_DATA + b"\x55" * 16 883 eh = _ExtendedHeader.from_bytes(header_data_with_suffix) 884 885 # This is a valid header: good magic and length. 886 self.assertTrue(eh.is_valid()) 887 888 self.assertEqual(eh.magic, _ExtendedHeader.EXPECTED_MAGIC) 889 self.assertEqual(eh.length, _ExtendedHeader.EXPECTED_LENGTH) 890 self.assertEqual(eh.program_size, EXAMPLE_PROGRAM_SIZE) 891 self.assertEqual(eh.segment_base_offset, EXAMPLE_SEGMENT_BASE_OFFSET) 892 893 def test_from_bytes_larger_than_needed_header_size_field(self) -> None: 894 # Simulate a backwards-compatibility situation. Parse a header 895 # with a larger-than expected size. This would typically mean that 896 # there are additional fields that we don't know about, but we will 897 # ignore them. 898 input_data: bytes = ( 899 # Magic bytes 900 b"eh00" 901 # uint32_t header size (little endian) 902 + b"\x1c\x00\x00\x00" # Longer than expected 903 # uint64_t program size 904 + b"\x44\x33\x44\x33\x22\x11\x22\x11" 905 # uint64_t segment base offset 906 + b"\x88\x77\x88\x77\x66\x55\x66\x55" 907 # uint32_t new field (ignored) 908 + b"\xff\xee\xff\xee" 909 ) 910 911 # Parse the serialized extended header. 912 eh = _ExtendedHeader.from_bytes(input_data) 913 914 # Header is valid despite having a larger than expected size. 915 self.assertTrue(eh.is_valid()) 916 917 self.assertEqual(eh.magic, _ExtendedHeader.EXPECTED_MAGIC) 918 self.assertEqual(eh.length, 28) 919 self.assertEqual(eh.program_size, EXAMPLE_PROGRAM_SIZE) 920 self.assertEqual(eh.segment_base_offset, EXAMPLE_SEGMENT_BASE_OFFSET) 921 922 def test_from_bytes_not_enough_data_fails(self) -> None: 923 # Parsing a truncated prefix should fail. 924 with self.assertRaises(ValueError): 925 _ExtendedHeader.from_bytes(EXAMPLE_HEADER_DATA[:16]) 926 927 def test_from_bytes_invalid_magic(self) -> None: 928 # An invalid serialized header 929 input_data: bytes = ( 930 # Magic bytes 931 b"ABCD" # Invalid 932 # uint32_t header size (little endian) 933 + b"\x18\x00\x00\x00" 934 # uint64_t program size 935 + b"\x44\x33\x44\x33\x22\x11\x22\x11" 936 # uint64_t segment base offset 937 + b"\x88\x77\x88\x77\x66\x55\x66\x55" 938 ) 939 940 # Parse the serialized extended header. 941 eh = _ExtendedHeader.from_bytes(input_data) 942 943 # Bad magic makes this invalid 944 self.assertFalse(eh.is_valid()) 945 946 # But it still parsed out the fields, so that callers can 947 # see what went wrong. 948 self.assertEqual(eh.magic, b"ABCD") 949 self.assertEqual(eh.length, _ExtendedHeader.EXPECTED_LENGTH) 950 self.assertEqual(eh.program_size, EXAMPLE_PROGRAM_SIZE) 951 self.assertEqual(eh.segment_base_offset, EXAMPLE_SEGMENT_BASE_OFFSET) 952 953 def test_from_bytes_invalid_length(self) -> None: 954 # An invalid serialized header 955 input_data: bytes = ( 956 # Magic bytes 957 b"eh00" 958 # uint32_t header size (little endian) 959 + b"\x10\x00\x00\x00" # Too short 960 # uint64_t program size 961 + b"\x44\x33\x44\x33\x22\x11\x22\x11" 962 # uint64_t segment base offset 963 + b"\x88\x77\x88\x77\x66\x55\x66\x55" 964 ) 965 966 # Parse the serialized extended header. 967 eh = _ExtendedHeader.from_bytes(input_data) 968 969 # Bad header size makes this invalid 970 self.assertFalse(eh.is_valid()) 971 972 # But it still parsed out the fields, so that callers can 973 # see what went wrong. 974 self.assertEqual(eh.magic, _ExtendedHeader.EXPECTED_MAGIC) 975 self.assertEqual(eh.length, 16) 976 self.assertEqual(eh.program_size, EXAMPLE_PROGRAM_SIZE) 977 self.assertEqual(eh.segment_base_offset, EXAMPLE_SEGMENT_BASE_OFFSET) 978