xref: /aosp_15_r20/external/executorch/devtools/inspector/tests/event_blocks_test.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8import unittest
9from typing import List, Optional, Tuple, Union
10
11import executorch.devtools.etdump.schema_flatcc as flatcc
12from executorch.devtools.etdump.schema_flatcc import ETDumpFlatCC, ProfileEvent
13from executorch.devtools.inspector import Event, EventBlock, PerfData
14from executorch.devtools.inspector._inspector import (
15    DelegateMetadata,
16    EventSignature,
17    InstructionEvent,
18    InstructionEventSignature,
19    ProfileEventSignature,
20)
21
22
class TestEventBlock(unittest.TestCase):
    """
    Tests for generating EventBlocks / Events from ETDumpFlatCC data and for
    resolving debug handles onto the generated Events.
    """

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Test Helpers ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    @staticmethod
    def _split_delegate_debug_id(
        delegate_debug_id: Optional[Union[int, str]],
    ) -> Tuple[int, str]:
        """
        Split a delegate debug id into the (int, str) field pair used by the
        flatcc event schemas.

        - an int id yields (id, "")
        - a str id yields (-1, id)
        - None yields (-1, "")
        """
        delegate_debug_id_int = (
            delegate_debug_id if isinstance(delegate_debug_id, int) else -1
        )
        delegate_debug_id_str = (
            delegate_debug_id if isinstance(delegate_debug_id, str) else ""
        )
        return delegate_debug_id_int, delegate_debug_id_str

    @staticmethod
    def _gen_sample_profile_event(
        name: str,
        instruction_id: int,
        time: Tuple[int, int],
        delegate_debug_id: Optional[Union[int, str]] = None,
        delegate_debug_metadata: Optional[str] = None,
    ) -> flatcc.ProfileEvent:
        """
        Helper for generating test ProfileEvents

        Notably:
        - the timestamp is specified as a (start_time, end_time) tuple of ints
        - delegate_debug_id takes either the str or int representation
        - chain_index is auto-populated to 0
        """
        delegate_debug_id_int, delegate_debug_id_str = (
            TestEventBlock._split_delegate_debug_id(delegate_debug_id)
        )
        return flatcc.ProfileEvent(
            name,
            0,
            instruction_id,
            delegate_debug_id_int,
            delegate_debug_id_str,
            # pyre-fixme[6]: For 6th argument expected `Optional[bytes]` but got
            #  `Optional[str]`.
            delegate_debug_metadata,
            start_time=time[0],
            end_time=time[1],
        )

    @staticmethod
    def _gen_sample_debug_event(
        instruction_id: int,
        delegate_debug_id: Optional[Union[int, str]] = None,
        name: str = "test_debug_event",
    ) -> flatcc.DebugEvent:
        """
        Helper for generating test DebugEvents

        Notably:
        - delegate_debug_id takes either the str or int representation
        - chain_index is auto-populated to 0
        - debug_entry is a fixed TENSOR value (INT, sizes=[1], strides=[1])
        """
        delegate_debug_id_int, delegate_debug_id_str = (
            TestEventBlock._split_delegate_debug_id(delegate_debug_id)
        )

        def _gen_sample_tensor() -> flatcc.Tensor:
            # Build a fresh instance on every call: callers may mutate the
            # event's `tensor` (e.g. to make two debug events differ), so it
            # must not share an object with the `tensor_list` entry.
            return flatcc.Tensor(
                scalar_type=flatcc.ScalarType.INT,
                sizes=[1],
                strides=[1],
                offset=12345,
            )

        return flatcc.DebugEvent(
            name=name,
            chain_index=0,
            instruction_id=instruction_id,
            delegate_debug_id_int=delegate_debug_id_int,
            delegate_debug_id_str=delegate_debug_id_str,
            debug_entry=flatcc.Value(
                val=flatcc.ValueType.TENSOR.value,
                tensor=_gen_sample_tensor(),
                tensor_list=flatcc.TensorList([_gen_sample_tensor()]),
                int_value=flatcc.Int(1),
                float_value=flatcc.Float(1.0),
                double_value=flatcc.Double(1.0),
                bool_value=flatcc.Bool(False),
                output=None,
            ),
        )

    @staticmethod
    def _gen_sample_event(
        profile_event: Optional[flatcc.ProfileEvent] = None,
        debug_event: Optional[flatcc.DebugEvent] = None,
    ) -> flatcc.Event:
        """
        Wrap a profile and/or debug event in a flatcc.Event with no
        allocation event.
        """
        return flatcc.Event(
            allocation_event=None,
            debug_event=debug_event,
            profile_event=profile_event,
        )

    @staticmethod
    def _gen_sample_run_data(
        name: str,
        events: List[flatcc.Event],
    ) -> flatcc.RunData:
        """
        Build a RunData with the given name and events, no bundled input
        (index -1) and no allocators.
        """
        return flatcc.RunData(
            name=name,
            bundled_input_index=-1,
            allocators=[],
            events=events,
        )

    @staticmethod
    def _get_sample_etdump_flatcc() -> flatcc.ETDumpFlatCC:
        """
        Helper for getting a sample ETDumpFlatCC object with 3 RunData:
        - run_data_1 has signature_a with just profile_1
        - run_data_2 has the same signature as run_data_1, but different times
        - run_data_3 has signature_b with both (profile_1, profile_2)
        """
        profile_event_1 = TestEventBlock._gen_sample_profile_event(
            name="profile_1", instruction_id=1, time=(0, 1), delegate_debug_id=100
        )
        run_data_1 = TestEventBlock._gen_sample_run_data(
            "signature_a",
            [TestEventBlock._gen_sample_event(profile_event=profile_event_1)],
        )

        profile_event_2 = TestEventBlock._gen_sample_profile_event(
            name="profile_1", instruction_id=1, time=(2, 4), delegate_debug_id=100
        )
        run_data_2 = TestEventBlock._gen_sample_run_data(
            "signature_a",
            [TestEventBlock._gen_sample_event(profile_event=profile_event_2)],
        )

        profile_event_3 = TestEventBlock._gen_sample_profile_event(
            name="profile_1", instruction_id=1, time=(5, 6), delegate_debug_id=100
        )
        profile_event_4 = TestEventBlock._gen_sample_profile_event(
            name="profile_2", instruction_id=2, time=(7, 8), delegate_debug_id=100
        )
        run_data_3 = TestEventBlock._gen_sample_run_data(
            "signature_b",
            [
                TestEventBlock._gen_sample_event(profile_event=profile_event_3),
                TestEventBlock._gen_sample_event(profile_event=profile_event_4),
            ],
        )

        return ETDumpFlatCC(version=0, run_data=[run_data_1, run_data_2, run_data_3])

    @staticmethod
    def _get_sample_etdump_flatcc_inconsistent_debug_data() -> flatcc.ETDumpFlatCC:
        """
        Helper for getting a sample ETDumpFlatCC object with 2 RunData that
        share signature_a but whose single debug event differs: run_data_2's
        debug tensor has sizes [2] instead of [1].
        """
        debug_event_1 = TestEventBlock._gen_sample_debug_event(
            instruction_id=1, delegate_debug_id=100
        )
        run_data_1 = TestEventBlock._gen_sample_run_data(
            "signature_a",
            [TestEventBlock._gen_sample_event(debug_event=debug_event_1)],
        )

        debug_event_2 = TestEventBlock._gen_sample_debug_event(
            instruction_id=1, delegate_debug_id=100
        )
        # Modify this debug event so it's different from debug_event_1
        debug_event_2.debug_entry.tensor.sizes = [2]  # pyre-ignore
        run_data_2 = TestEventBlock._gen_sample_run_data(
            "signature_a",
            [TestEventBlock._gen_sample_event(debug_event=debug_event_2)],
        )
        return ETDumpFlatCC(version=0, run_data=[run_data_1, run_data_2])

    @staticmethod
    def _get_sample_etdump_flatcc_profiling_and_debugging() -> flatcc.ETDumpFlatCC:
        """
        Helper for getting a sample ETDumpFlatCC object with 3 RunData:
        - run_data_1 has signature_a with (debug_event_1, profile_event_1)
        - run_data_2 has the same signature as run_data_1 and the same debug
          event, but different profiling times
        - run_data_3 has signature_b with (debug_event_3, profile_event_3) and
          (no debug event, profile_event_4)
        """
        profile_event_1 = TestEventBlock._gen_sample_profile_event(
            name="profile_1", instruction_id=1, time=(0, 1), delegate_debug_id=100
        )
        debug_event_1 = TestEventBlock._gen_sample_debug_event(
            instruction_id=1, delegate_debug_id=100
        )
        run_data_1 = TestEventBlock._gen_sample_run_data(
            "signature_a",
            [
                TestEventBlock._gen_sample_event(profile_event=profile_event_1),
                TestEventBlock._gen_sample_event(debug_event=debug_event_1),
            ],
        )

        profile_event_2 = TestEventBlock._gen_sample_profile_event(
            name="profile_1", instruction_id=1, time=(2, 4), delegate_debug_id=100
        )
        debug_event_2 = TestEventBlock._gen_sample_debug_event(
            instruction_id=1, delegate_debug_id=100
        )
        run_data_2 = TestEventBlock._gen_sample_run_data(
            "signature_a",
            [
                TestEventBlock._gen_sample_event(profile_event=profile_event_2),
                TestEventBlock._gen_sample_event(debug_event=debug_event_2),
            ],
        )

        profile_event_3 = TestEventBlock._gen_sample_profile_event(
            name="profile_3", instruction_id=1, time=(5, 6), delegate_debug_id=100
        )
        debug_event_3 = TestEventBlock._gen_sample_debug_event(
            instruction_id=1, delegate_debug_id=100
        )
        profile_event_4 = TestEventBlock._gen_sample_profile_event(
            name="profile_4", instruction_id=2, time=(7, 8), delegate_debug_id=100
        )
        run_data_3 = TestEventBlock._gen_sample_run_data(
            "signature_b",
            [
                TestEventBlock._gen_sample_event(debug_event=debug_event_3),
                TestEventBlock._gen_sample_event(profile_event=profile_event_3),
                TestEventBlock._gen_sample_event(profile_event=profile_event_4),
            ],
        )

        return ETDumpFlatCC(version=0, run_data=[run_data_1, run_data_2, run_data_3])

    @staticmethod
    def _get_sample_etdump_flatcc_debug_events_only(
        event_name: str,
        delegate_debug_id: str,
    ) -> flatcc.ETDumpFlatCC:
        """
        Helper for getting a sample ETDumpFlatCC object with RunData signature_a
        and (debug_event_delegated, debug_event_non_delegated, no profile event)
        """
        debug_event_delegated = TestEventBlock._gen_sample_debug_event(
            instruction_id=1, delegate_debug_id=delegate_debug_id, name=event_name
        )
        debug_event_non_delegated = TestEventBlock._gen_sample_debug_event(
            instruction_id=1, name=event_name
        )
        run_data_1 = TestEventBlock._gen_sample_run_data(
            "signature_a",
            [
                TestEventBlock._gen_sample_event(debug_event=debug_event_delegated),
                TestEventBlock._gen_sample_event(
                    debug_event=debug_event_non_delegated
                ),
            ],
        )

        return ETDumpFlatCC(version=0, run_data=[run_data_1])

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Tests ~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    def test_gen_from_etdump(self) -> None:
        """
        Test "e2e" generation of EventBlocks given an ETDump
            - Generated via EventBlock.gen_from_etdump

        Specifically it tests for external correctness:
        - Correct number of EventBlocks
        - Correct number of Events and Raw Data values (iterations)
        """

        etdump: ETDumpFlatCC = TestEventBlock._get_sample_etdump_flatcc()
        blocks: List[EventBlock] = EventBlock._gen_from_etdump(etdump)

        self.assertEqual(len(blocks), 2, f"Expected 2 runs, got {len(blocks)}")

        # One EventBlock should have 1 event with 2 iterations
        # The other EventBlock should have 2 events with 1 iteration each
        run_counts = set()
        for block in blocks:
            perf_data = block.events[0].perf_data
            # Every sample event is a profile event, so perf_data must exist;
            # fail loudly instead of silently skipping the block.
            self.assertIsNotNone(perf_data)
            run_counts.add((len(block.events), len(perf_data.raw)))
        self.assertSetEqual(run_counts, {(1, 2), (2, 1)})

    def test_gen_from_etdump_profiling_and_debugging(self) -> None:
        """
        Test "e2e" generation of EventBlocks given an ETDump with both profiling and debugging events
            - Generated via EventBlock.gen_from_etdump

        Specifically it tests for external correctness:
        - Correct number of EventBlocks
        - Correct number of raw perf_data and debug_data for each Event
        """
        etdump: ETDumpFlatCC = (
            TestEventBlock._get_sample_etdump_flatcc_profiling_and_debugging()
        )
        blocks: List[EventBlock] = EventBlock._gen_from_etdump(etdump)

        self.assertEqual(len(blocks), 2, f"Expected 2 runs, got {len(blocks)}")

        # One EventBlock should have 1 event with 2 iterations
        # and 1 debug data (because we only populate debug data in the first iteration)
        self.assertEqual(len(blocks[0].events), 1)
        perf_data = blocks[0].events[0].perf_data
        self.assertIsNotNone(perf_data)
        self.assertEqual(len(perf_data.raw), 2)
        self.assertEqual(len(blocks[0].events[0].debug_data), 1)

        # The other EventBlock should have 2 events with 1 iteration each,
        # and only the first event has debug data
        self.assertEqual(len(blocks[1].events), 2)
        for event in blocks[1].events:
            perf_data = event.perf_data
            self.assertIsNotNone(perf_data)
            self.assertEqual(len(perf_data.raw), 1)
        self.assertEqual(len(blocks[1].events[0].debug_data), 1)
        self.assertEqual(len(blocks[1].events[1].debug_data), 0)

    def test_gen_from_etdump_inconsistent_debug_data(self) -> None:
        """
        Make sure AssertionError is thrown when intermediate outputs are different across
        different iterations of a model run
        """
        etdump: ETDumpFlatCC = (
            TestEventBlock._get_sample_etdump_flatcc_inconsistent_debug_data()
        )
        with self.assertRaises(AssertionError):
            EventBlock._gen_from_etdump(etdump)

    def test_gen_from_etdump_debug_events_only(self) -> None:
        """
        Test generation of EventBlocks given an ETDump with only debugging events

        Specifically it tests:
        - Correct number of EventBlocks and Events
        - Correct name of each Event
        """
        event_name = "test_debug_event_only"
        delegate_debug_id = "debug_id"
        etdump: ETDumpFlatCC = (
            TestEventBlock._get_sample_etdump_flatcc_debug_events_only(
                event_name=event_name,
                delegate_debug_id=delegate_debug_id,
            )
        )
        event_blocks = EventBlock._gen_from_etdump(etdump)
        self.assertEqual(len(event_blocks), 1)
        self.assertEqual(len(event_blocks[0].events), 2)
        # Delegated event uses delegate_debug_id as event name
        self.assertEqual(event_blocks[0].events[0].name, delegate_debug_id)
        # Non delegated event uses event_name as event name
        self.assertEqual(event_blocks[0].events[1].name, event_name)

    def test_inspector_event_generation(self) -> None:
        """
        Test Inspector.Event derivation from various ProfileEvent cases
        - Non Delegated
        - Delegate with Int Debug ID
        - Delegate with String Debug ID
        """

        def _test_profile_event_generation(
            name: str,
            instruction_id: int,
            delegate_debug_id_int: Optional[int] = None,
            delegate_debug_id_str: Optional[str] = None,
            scale_factor: int = 1000,
        ) -> None:
            """
            Helper function for testing that the provided ProfileEvent fields are
            properly translated to Inspector.ProfileEventSignature and Inspector.Event
            """
            # Use an explicit None check rather than `or`: a falsy but valid
            # int id such as 0 must still be treated as a delegate debug id.
            delegate_debug_id: Optional[Union[int, str]] = (
                delegate_debug_id_int
                if delegate_debug_id_int is not None
                else delegate_debug_id_str
            )
            profile_event: flatcc.ProfileEvent = (
                TestEventBlock._gen_sample_profile_event(
                    name,
                    instruction_id,
                    (0, 1),
                    delegate_debug_id,
                )
            )

            # Test Signature Generation
            profile_signature = ProfileEventSignature._gen_from_event(profile_event)
            expected_signature = ProfileEventSignature(
                name,
                instruction_id,
                delegate_debug_id_int,
                delegate_debug_id_str,
            )
            self.assertEqual(profile_signature, expected_signature)

            event_signature = EventSignature(
                instruction_id=instruction_id,
                profile_event_signature=profile_signature,
            )

            # Test Event Generation
            durations = [10, 20, 30]
            delegate_debug_metadatas = ["metadata_0", "metadata_1", "metadata_2"]
            is_delegated = delegate_debug_id is not None
            profile_events: List[flatcc.ProfileEvent] = [
                TestEventBlock._gen_sample_profile_event(
                    name,
                    instruction_id,
                    (0, time),
                    delegate_debug_id,
                    delegate_debug_metadatas[index] if is_delegated else None,
                )
                for index, time in enumerate(durations)
            ]
            instruction_events = [
                InstructionEvent(
                    signature=InstructionEventSignature(
                        instruction_id=instruction_id, chain_index=0
                    ),
                    profile_events=[profile_event],
                )
                for profile_event in profile_events
            ]
            event = Event._gen_from_inference_events(
                event_signature, instruction_events, scale_factor=scale_factor
            )

            expected_event = Event(
                name=str(delegate_debug_id) if is_delegated else name,
                perf_data=PerfData(
                    [float(duration) / scale_factor for duration in durations]
                ),
                delegate_debug_identifier=delegate_debug_id,
                is_delegated_op=is_delegated,
                _delegate_debug_metadatas=(
                    delegate_debug_metadatas if is_delegated else []
                ),
                _instruction_id=event_signature.instruction_id,
            )
            self.assertEqual(event, expected_event)

            # Test delegate_debug_metadata parsing: when a parser is attached,
            # the public property returns the parser's output.
            if is_delegated:
                parsed_event = Event(
                    name=str(delegate_debug_id),
                    perf_data=PerfData(
                        [float(duration) / scale_factor for duration in durations]
                    ),
                    delegate_debug_identifier=delegate_debug_id,
                    is_delegated_op=is_delegated,
                    _delegate_debug_metadatas=delegate_debug_metadatas,
                    _instruction_id=event_signature.instruction_id,
                    _delegate_metadata_parser=lambda metadatas: {
                        "joined": "-".join(metadatas)
                    },
                )
                self.assertEqual(
                    parsed_event.delegate_debug_metadatas,
                    {"joined": "-".join(delegate_debug_metadatas)},
                )

        # Non Delegated
        _test_profile_event_generation("non-delegate", 1)

        # Delegate with Int Debug ID
        _test_profile_event_generation("delegate", 1, 100)

        # Delegate with String Debug ID
        _test_profile_event_generation("delegate", 1, None, "identifier")

        # Manipulating the scale factor
        _test_profile_event_generation(
            "delegate", 1, None, "identifier", scale_factor=10000
        )

    def test_gen_resolve_debug_handles(self) -> None:
        """
        Test that gen_resolve_debug_handles() correctly populates the EventBlock
        """

        def _gen_event_helper(events: List[ProfileEvent]) -> Event:
            """
            Helper function to generate an Event given a set of ProfileEvents.
            The signature is derived from the first ProfileEvent; each
            ProfileEvent becomes its own InstructionEvent (one iteration).
            """
            profile_event = events[0]
            profile_signature = ProfileEventSignature._gen_from_event(profile_event)
            event_signature = EventSignature(
                profile_event.instruction_id, profile_signature
            )
            instruction_events = [
                InstructionEvent(
                    signature=InstructionEventSignature(
                        instruction_id=profile_event.instruction_id,
                        chain_index=profile_event.chain_index,
                    ),
                    profile_events=[iteration_event],
                )
                for iteration_event in events
            ]
            return Event._gen_from_inference_events(event_signature, instruction_events)

        # Create Test Data

        # Non-Delegated
        non_delegated_profile_events_1 = [
            TestEventBlock._gen_sample_profile_event("non_del_1", 0, (0, 1)),
            TestEventBlock._gen_sample_profile_event("non_del_1", 0, (0, 1)),
        ]
        non_delegated_profile_events_2 = [
            TestEventBlock._gen_sample_profile_event("non_del_2", 1, (0, 1)),
        ]
        non_delegated_event_1 = _gen_event_helper(non_delegated_profile_events_1)
        non_delegated_event_2 = _gen_event_helper(non_delegated_profile_events_2)

        # Delegated
        delegated_profile_events_1 = [
            TestEventBlock._gen_sample_profile_event("del_1", 0, (0, 1), 10),
            TestEventBlock._gen_sample_profile_event("del_1", 0, (0, 1), 10),
            TestEventBlock._gen_sample_profile_event("del_1", 0, (0, 1), 10),
        ]
        delegated_profile_events_2 = [
            TestEventBlock._gen_sample_profile_event("del_2", 2, (0, 1), 20),
        ]
        delegated_event_1 = _gen_event_helper(delegated_profile_events_1)
        delegated_event_2 = _gen_event_helper(delegated_profile_events_2)

        # Create Test EventBlock
        event_block = EventBlock(
            name="test_name_1",
            events=[
                non_delegated_event_1,
                non_delegated_event_2,
                delegated_event_1,
                delegated_event_2,
            ],
        )

        # Create Test Maps
        # handle_map: instruction id (as str) -> debug handles
        # delegate_map: instruction id (as str) -> delegate metadata whose
        #   "delegate_map" maps delegate debug id -> debug handles
        handle_map = {"0": [100], "1": [110], "2": [120]}
        delegate_map = {
            "0": DelegateMetadata(
                {
                    "name": "delegate",
                    "delegate_map": {10: (100, 1000)},
                }
            ),
            "2": DelegateMetadata(
                {
                    "name": "delegate_2",
                    "delegate_map": {20: (200,)},
                }
            ),
        }
        event_block._gen_resolve_debug_handles(handle_map, delegate_map)

        # Verify Results
        for event in event_block.events:
            # To satisfy type checker
            assert event._instruction_id is not None
            if (
                delegate_debug_identifier := event.delegate_debug_identifier
            ) is not None:
                # Delegated
                metadata = delegate_map[str(event._instruction_id)]
                self.assertEqual(event.delegate_backend_name, metadata["name"])
                self.assertEqual(
                    event.debug_handles,
                    metadata["delegate_map"][delegate_debug_identifier],  # pyre-ignore
                )
            else:
                # Non Delegated
                self.assertEqual(
                    event.debug_handles, handle_map[str(event._instruction_id)]
                )
657