xref: /aosp_15_r20/external/executorch/devtools/inspector/_inspector.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-unsafe
8
9import dataclasses
10import logging
11import sys
12import warnings
13from collections import defaultdict, OrderedDict
14from dataclasses import dataclass
15from functools import cached_property
16from typing import (
17    Any,
18    Callable,
19    Dict,
20    IO,
21    List,
22    Mapping,
23    Optional,
24    Sequence,
25    Tuple,
26    TypeAlias,
27    TypedDict,
28    Union,
29)
30
31import executorch.devtools.etdump.schema_flatcc as flatcc
32
33import numpy as np
34import pandas as pd
35
36from executorch.devtools.debug_format.et_schema import OperatorGraph, OperatorNode
37from executorch.devtools.etdump.schema_flatcc import (
38    DebugEvent,
39    ETDumpFlatCC,
40    ProfileEvent,
41)
42from executorch.devtools.etrecord import ETRecord, parse_etrecord
43from executorch.devtools.inspector._inspector_utils import (
44    calculate_time_scale_factor,
45    create_debug_handle_to_op_node_mapping,
46    display_or_print_df,
47    EDGE_DIALECT_GRAPH_KEY,
48    EXCLUDED_COLUMNS_WHEN_PRINTING,
49    EXCLUDED_EVENTS_WHEN_PRINTING,
50    find_populated_event,
51    FORWARD,
52    gen_etdump_object,
53    gen_graphs_from_etrecord,
54    inflate_runtime_output,
55    is_debug_output,
56    is_inference_output_equal,
57    ProgramOutput,
58    RESERVED_FRAMEWORK_EVENT_NAMES,
59    TimeScale,
60    verify_debug_data_equivalence,
61)
62from executorch.exir import ExportedProgram
63
64
# Module-level logger, named after this module via __name__.
log: logging.Logger = logging.getLogger(__name__)
66
67
68# Signature of an InstructionEvent
@dataclass(frozen=True, order=True)
class InstructionEventSignature:
    # Frozen so that instances are hashable; InstructionEvent.gen_from_events
    # uses them as dictionary keys when collating events.
    instruction_id: int
    chain_index: int
    # Delegate debug identifiers taken from the logged event (int and string
    # flavours); both default to None.
    delegate_id: Optional[int] = None
    delegate_id_str: Optional[str] = None
75
76
77# Aggregated Runtime Events for a single instruction
@dataclass
class InstructionEvent:
    """
    Profile and debug events that share one InstructionEventSignature.

    profile_events / debug_events stay None until the first event of that
    kind is seen for the signature.
    """

    signature: InstructionEventSignature
    profile_events: Optional[List[ProfileEvent]] = None
    debug_events: Optional[List[DebugEvent]] = None

    @staticmethod
    def gen_from_events(run_events: List[flatcc.Event]) -> List["InstructionEvent"]:
        """
        Collate the events of a single ETDump run into InstructionEvents.

        Events are grouped by (instruction_id, chain_index, delegate ids) and
        the result preserves the order in which each group was first seen.
        Debug events that are run outputs are dropped, although they still
        create an (empty) InstructionEvent for their signature.
        """
        collated: Dict[InstructionEventSignature, InstructionEvent] = OrderedDict()

        for raw_event in run_events:
            # Pick out whichever concrete event was actually logged
            logged: Union[DebugEvent, ProfileEvent] = find_populated_event(raw_event)

            key = InstructionEventSignature(
                instruction_id=logged.instruction_id,
                chain_index=logged.chain_index,
                delegate_id=logged.delegate_debug_id_int,
                delegate_id_str=logged.delegate_debug_id_str,
            )

            # Fetch the bucket for this signature, creating it on first sight
            if key not in collated:
                collated[key] = InstructionEvent(signature=key)
            bucket = collated[key]

            if isinstance(logged, ProfileEvent):
                if bucket.profile_events is None:
                    bucket.profile_events = []
                bucket.profile_events.append(logged)
                continue

            if not isinstance(logged, DebugEvent):
                continue

            # run_output events are not tracked per instruction
            if is_debug_output(logged.debug_entry):
                continue

            if bucket.debug_events is None:
                bucket.debug_events = []
            bucket.debug_events.append(logged)

        return list(collated.values())
125
126
127# Signature of a ProfileEvent
@dataclass(frozen=True, order=True)
class ProfileEventSignature:
    name: str
    instruction_id: Optional[int]
    delegate_id: Optional[int] = None
    delegate_id_str: Optional[str] = None

    @staticmethod
    def _gen_from_event(event: ProfileEvent) -> "ProfileEventSignature":
        """
        Build the signature of a ProfileEvent

        ETDump encodes an unpopulated field as "" (strings) or -1 (ints);
        those placeholders are mapped back to None here. A falsy name is
        normalized to "".
        """
        instruction_id = event.instruction_id
        if instruction_id == -1:
            instruction_id = None

        delegate_id = event.delegate_debug_id_int
        if delegate_id == -1:
            delegate_id = None

        delegate_id_str = event.delegate_debug_id_str
        if delegate_id_str == "":
            delegate_id_str = None

        return ProfileEventSignature(
            name=event.name or "",
            instruction_id=instruction_id,
            delegate_id=delegate_id,
            delegate_id_str=delegate_id_str,
        )
149
150
151# Signature of a DebugEvent
@dataclass(frozen=True, order=True)
class DebugEventSignature:
    name: str = ""
    instruction_id: Optional[int] = -1
    delegate_id: Optional[int] = None
    delegate_id_str: Optional[str] = None

    @staticmethod
    def _gen_from_event(event: DebugEvent) -> "DebugEventSignature":
        """
        Build the signature of a DebugEvent

        ETDump encodes an unpopulated field as "" (strings) or -1 (ints);
        those placeholders are mapped back to None here. A falsy name is
        normalized to "".
        """
        instruction_id = event.instruction_id
        if instruction_id == -1:
            instruction_id = None

        delegate_id = event.delegate_debug_id_int
        if delegate_id == -1:
            delegate_id = None

        delegate_id_str = event.delegate_debug_id_str
        if delegate_id_str == "":
            delegate_id_str = None

        return DebugEventSignature(
            name=event.name or "",
            instruction_id=instruction_id,
            delegate_id=delegate_id,
            delegate_id_str=delegate_id_str,
        )
173
174
175# Signature of an Event inside of a Run
@dataclass(frozen=True, order=True)
class EventSignature:
    """
    Identifies an Event inside of a Run.

    The pair (profile_event_signature, debug_event_signature) is already a
    sufficient identifier; instruction_id is duplicated at the top level
    purely for convenience.
    """

    instruction_id: int
    profile_event_signature: Optional[ProfileEventSignature] = None
    debug_event_signature: Optional[DebugEventSignature] = None

    @staticmethod
    def gen_from_instruction_event(
        instruction_event: InstructionEvent,
    ) -> List[Tuple["EventSignature", InstructionEvent]]:
        """
        Derive EventSignatures from an InstructionEvent.

        Returns one (EventSignature, InstructionEvent) pair per profile event,
        where the paired InstructionEvent is a copy narrowed to that single
        profile event. When profile_events is None, a single pair holding the
        original InstructionEvent is returned instead.
        """
        instruction_id = instruction_event.signature.instruction_id

        # The debug signature comes from the first debug event, if there is one
        debug_signature = None
        debug_events = instruction_event.debug_events
        if debug_events:
            debug_signature = DebugEventSignature._gen_from_event(debug_events[0])

        profile_events = instruction_event.profile_events
        if profile_events is None:
            # Debug-only instruction: nothing to split on
            debug_only = EventSignature(
                instruction_id=instruction_id,
                debug_event_signature=debug_signature,
            )
            return [(debug_only, instruction_event)]

        pairs: List[Tuple["EventSignature", InstructionEvent]] = []
        for profile_event in profile_events:
            event_signature = EventSignature(
                instruction_id=instruction_id,
                profile_event_signature=ProfileEventSignature._gen_from_event(
                    profile_event
                ),
                debug_event_signature=debug_signature,
            )
            narrowed = dataclasses.replace(
                instruction_event, profile_events=[profile_event]
            )
            pairs.append((event_signature, narrowed))
        return pairs
234
235
236# Signature of a Run
@dataclass(frozen=True, order=True)
class RunSignature:
    """
    Args:
        name: Name of the run
        events: Tuple of EventSignatures that correspond to the run
        bundled_input_index: Index of the bundled input used to generate the debug output
    """

    name: str
    # Variable-length tuple (hashable, unlike a list) of the run's event signatures
    events: Optional[Tuple[EventSignature, ...]] = None
    bundled_input_index: Optional[int] = None
249
250
# Typing for mapping Event.delegate_debug_identifiers to debug_handle(s).
# The identifier key is either an int or a str; the value is a tuple of debug handles.
DelegateIdentifierDebugHandleMap: TypeAlias = Union[
    Mapping[int, Tuple[int, ...]], Mapping[str, Tuple[int, ...]]
]

# Typing for Dict containing delegate metadata
DelegateMetadata = TypedDict(
    "DelegateMetadata",
    {"name": str, "delegate_map": DelegateIdentifierDebugHandleMap},
)
261
262
@dataclass
class PerfData:
    """
    Raw timing samples of an event plus summary statistics derived from them.

    Args:
        raw: The individual timing samples.

    `raw` is declared as a dataclass field so that the generated `__eq__` and
    `__repr__` actually take the samples into account (with a hand-written
    `__init__` and no declared fields, every instance compared equal and
    printed as `PerfData()`).
    """

    raw: List[float]

    @property
    def p10(self) -> float:
        """10th percentile of the samples."""
        return np.percentile(self.raw, 10)

    @property
    def p50(self) -> float:
        """50th percentile (median) of the samples."""
        return np.percentile(self.raw, 50)

    @property
    def p90(self) -> float:
        """90th percentile of the samples."""
        return np.percentile(self.raw, 90)

    @property
    def avg(self) -> float:
        """Arithmetic mean of the samples."""
        return np.mean(self.raw)

    @property
    def min(self) -> float:
        """Smallest sample."""
        return min(self.raw)

    @property
    def max(self) -> float:
        """Largest sample."""
        return max(self.raw)
291
292
293@dataclass
294class Event:
295    """
296    An Event corresponds to an operator instance with perf data retrieved from the runtime and other metadata from `ETRecord`.
297
298    Args:
299        name: Name of the profiling `Event`, empty if no profiling event.
        perf_data: Performance data associated with the event retrieved from the runtime (available attributes: p10, p50, p90, avg, min and max).
        op_types: List of op types corresponding to the event.
302        delegate_debug_identifier: Supplemental identifier used in combination with instruction id.
303        debug_handles: Debug handles in the model graph to which this event is correlated.
        stack_traces: A dictionary mapping the name of each associated op to its stack trace.
305        module_hierarchy: A dictionary mapping the name of each associated op to its module hierarchy.
306        is_delegated_op: Whether or not the event was delegated.
307        delegate_backend_name: Name of the backend this event was delegated to.
308
309        _delegate_debug_metadatas: A list of raw delegate debug metadata in string, one for each profile event.
310            Available parsed (if parser provided) as Event.delegate_debug_metadatas
311            Available as Event.raw_delegate_debug_metadatas
312
313        debug_data: A list containing intermediate data collected.
314
315        _instruction_id: Instruction Identifier for Symbolication
316        _delegate_metadata_parser: Optional Parser for _delegate_debug_metadatas
317    """
318
319    name: str
320    perf_data: Optional[PerfData] = None
321    op_types: List[str] = dataclasses.field(default_factory=list)
322    delegate_debug_identifier: Optional[Union[int, str]] = None
323    debug_handles: Optional[Union[int, Sequence[int]]] = None
324    stack_traces: Dict[str, str] = dataclasses.field(default_factory=dict)
325    module_hierarchy: Dict[str, Dict] = dataclasses.field(default_factory=dict)
326    is_delegated_op: Optional[bool] = None
327    delegate_backend_name: Optional[str] = None
328    _delegate_debug_metadatas: List[str] = dataclasses.field(default_factory=list)
329
330    debug_data: ProgramOutput = dataclasses.field(default_factory=list)
331    _instruction_id: Optional[int] = None
332
333    _delegate_metadata_parser: Optional[Callable[[List[str]], Dict[str, Any]]] = None
334    _delegate_time_scale_converter: Optional[
335        Callable[[Union[int, str], Union[int, float]], Union[int, float]]
336    ] = None
337
338    @cached_property
339    def delegate_debug_metadatas(self) -> Union[List[str], Dict[str, Any]]:
340        """
341        Returns the parsed _delegate_debug_metadatas if a parser is available
342        Otherwise returns the raw _delegate_debug_metadatas
343        """
344        if not self.is_delegated_op or self._delegate_metadata_parser is None:
345            return self._delegate_debug_metadatas
346        return self._delegate_metadata_parser(self._delegate_debug_metadatas)
347
    @property
    def raw_delegate_debug_metadatas(self) -> List[str]:
        """
        Return the raw unparsed _delegate_debug_metadatas

        Unlike delegate_debug_metadatas, no parser is ever applied here.
        """
        return self._delegate_debug_metadatas
354
355    def to_dataframe(self, _units="") -> pd.DataFrame:
356        """
357        Convert the Event into a pandas DataFrame
358
359        Args:
360            None
361
362        Returns:
363            A pandas DataFrame with the Event data
364        """
365        event_dict = self.asdict(_units=_units)
366        return pd.DataFrame(event_dict)
367
368    # Override the default implementation of dataclass.asdict to handle null perf data
369    def asdict(self, _units="") -> dict:
370        """
371        Convert the Event into a dict
372
373        Args:
374            None
375
376        Returns:
377            A dict with the Event data
378        """
379
380        def truncated_list(long_list: List[str]) -> str:
381            return f"['{long_list[0]}', '{long_list[1]}' ... '{long_list[-1]}'] ({len(long_list)} total)"
382
383        return {
384            "event_name": self.name,
385            "raw": [self.perf_data.raw if self.perf_data else None],
386            "p10" + _units: self.perf_data.p10 if self.perf_data else None,
387            "p50" + _units: self.perf_data.p50 if self.perf_data else None,
388            "p90" + _units: self.perf_data.p90 if self.perf_data else None,
389            "avg" + _units: self.perf_data.avg if self.perf_data else None,
390            "min" + _units: self.perf_data.min if self.perf_data else None,
391            "max" + _units: self.perf_data.max if self.perf_data else None,
392            "op_types": [
393                (
394                    self.op_types
395                    if len(self.op_types) < 5
396                    else truncated_list(self.op_types)
397                )
398            ],
399            "delegate_debug_identifier": self.delegate_debug_identifier,
400            "stack_traces": [self.stack_traces],
401            "module_hierarchy": [self.module_hierarchy],
402            "is_delegated_op": self.is_delegated_op,
403            "delegate_backend_name": self.delegate_backend_name,
404            "debug_data": [self.debug_data],
405        }
406
407    @staticmethod
408    def _gen_from_inference_events(
409        signature: EventSignature,
410        events: List[InstructionEvent],
411        scale_factor: float = 1.0,
412        output_buffer: Optional[bytes] = None,
413        delegate_metadata_parser: Optional[
414            Callable[[List[str]], Dict[str, Any]]
415        ] = None,
416        delegate_time_scale_converter: Optional[
417            Callable[[Union[int, str], Union[int, float]], Union[int, float]]
418        ] = None,
419    ) -> "Event":
420        """
421        Given an EventSignature and a list of Events with that signature,
422        return an Event object matching the EventSignature, with perf_data
423        populated from the list of ProfileEvents and debug_data populated from
424        the list of DebugEvents.
425
426        An optional inverse scale factor can be provided to adjust the event timestamps
427        An optional buffer can be provided to inflate etdump references
428        An optional delegate_metadata_parser can be provided to parse the delegate metadata
429        """
430
431        profile_event_signature = signature.profile_event_signature
432        debug_event_signature = signature.debug_event_signature
433
434        # Event is gradually populated in this function
435        ret_event: Event = Event(
436            name="",
437            _instruction_id=signature.instruction_id,
438            _delegate_metadata_parser=delegate_metadata_parser,
439            _delegate_time_scale_converter=delegate_time_scale_converter,
440        )
441
442        # Populate fields from profile events
443        Event._populate_profiling_related_fields(
444            ret_event, profile_event_signature, events, scale_factor
445        )
446
447        # Populate fields from debug events
448        Event._populate_debugging_related_fields(
449            ret_event, debug_event_signature, events, output_buffer
450        )
451
452        return ret_event
453
454    @staticmethod
455    def _calculate_elapsed_time(start_time, end_time):
456        # We're assuming if there's a wraparound in the time values, then
457        # the time representation of that platform only contains 32 bits.
458        # This should be fine for now, but ideally we should source the max
459        # time value from the platform using etdump.
460        max_uint32 = 2**32 - 1
461        if start_time > end_time:
462            if (start_time > max_uint32) or (end_time > max_uint32):
463                raise ValueError(
464                    f"Expected start_time ({start_time}) and end_time ({end_time}) to be less than {max_uint32} for cases where there is wrap-around of time values."
465                )
466            # Handle wraparound
467            elapsed_time = (max_uint32 - start_time) + end_time
468        else:
469            # Normal case
470            elapsed_time = end_time - start_time
471        return elapsed_time
472
473    @staticmethod
474    def _populate_event_signature_fields(
475        ret_event: "Event",
476        event_signature: Optional[Union[ProfileEventSignature, DebugEventSignature]],
477    ) -> None:
478        """
479        Given a partially constructed Event, populate the fields related to
480        the profile event signature or debug event signature
481
482        Fields Updated:
483            name
484            delegate_debug_identifier
485            is_delegated_op
486        """
487        # TODO: T201347372 Push the None check to ealier in the stack.
488        if event_signature is not None:
489            if event_signature.delegate_id is not None:  # 0 is a valid value
490                delegate_debug_identifier = event_signature.delegate_id
491            else:
492                delegate_debug_identifier = event_signature.delegate_id_str or None
493
494            # Use the delegate identifier as the event name if delegated
495            is_delegated_op = delegate_debug_identifier is not None
496            name = (
497                event_signature.name
498                if not is_delegated_op
499                else str(delegate_debug_identifier)
500            )
501
502            # Update fields
503            # This is for older version of etdump that doesn't have the name field for debug events, we don't update the name field
504            if name:
505                ret_event.name = name
506            ret_event.delegate_debug_identifier = delegate_debug_identifier
507            ret_event.is_delegated_op = is_delegated_op
508
509    @staticmethod
510    def _populate_profiling_related_fields(
511        ret_event: "Event",
512        profile_event_signature: Optional[ProfileEventSignature],
513        events: List[InstructionEvent],
514        scale_factor: float,
515    ) -> None:
516        """
517        Given a partially constructed Event, populate the fields related to
518        the profile events
519
520        Fields Updated:
521            name
522            delegate_debug_identifier
523            is_delegated_op
524            perf_data
525            delegate_debug_metadatas
526        """
527
528        # Fill out fields from profile event signature
529        Event._populate_event_signature_fields(ret_event, profile_event_signature)
530
531        # Fill out fields from profile event
532        data = []
533        delegate_debug_metadatas = []
534        for event in events:
535            if (profile_events := event.profile_events) is not None:
536                if len(profile_events) != 1:
537                    raise ValueError(
538                        f"Expected exactly one profile event per InstructionEvent when generating Inspector Event, but got {len(profile_events)}"
539                    )
540
541                profile_event = profile_events[0]
542
543                # Scale factor should only be applied to non-delegated ops
544                if (
545                    ret_event.is_delegated_op
546                    and (convert_time_scale := ret_event._delegate_time_scale_converter)
547                    is not None
548                ):
549                    scaled_time = Event._calculate_elapsed_time(
550                        convert_time_scale(ret_event.name, profile_event.start_time),
551                        convert_time_scale(ret_event.name, profile_event.end_time),
552                    )
553                # If it's not a delegated op then we can just use the raw time values
554                # and then scale them according to the scale factor that was passed in.
555                elif not ret_event.is_delegated_op:
556                    scaled_time = (
557                        float(
558                            Event._calculate_elapsed_time(
559                                profile_event.start_time, profile_event.end_time
560                            )
561                        )
562                        / scale_factor
563                    )
564                # If there was no scale factor passed in just take a difference of the
565                # end and start times.
566                else:
567                    scaled_time = float(
568                        Event._calculate_elapsed_time(
569                            profile_event.start_time, profile_event.end_time
570                        )
571                    )
572
573                data.append(scaled_time)
574                delegate_debug_metadatas.append(
575                    profile_event.delegate_debug_metadata
576                    if profile_event.delegate_debug_metadata
577                    else ""
578                )
579
580        # Update fields
581        if len(data) > 0:
582            ret_event.perf_data = PerfData(data)
583        if any(delegate_debug_metadatas):
584            ret_event._delegate_debug_metadatas = delegate_debug_metadatas
585
586    @staticmethod
587    def _populate_debugging_related_fields(
588        ret_event: "Event",
589        debug_event_signature: Optional[DebugEventSignature],
590        events: List[InstructionEvent],
591        output_buffer: Optional[bytes] = None,
592    ) -> None:
593        """
594        Given a partially constructed Event, populate the fields related to
595        the debug events
596
597        Fields Updated:
598            name
599            delegate_debug_identifier
600            is_delegated_op
601            debug_data
602        """
603
604        # Fill out fields from debug event signature
605        Event._populate_event_signature_fields(ret_event, debug_event_signature)
606
607        debug_data: List[flatcc.Value] = []
608        for event in events:
609            if (debug_events := event.debug_events) is None:
610                continue
611
612            # Populate on the first iteration only, then verify equivalence for others
613            if len(debug_data) == 0:
614                debug_data = [debug_event.debug_entry for debug_event in debug_events]
615            else:
616                for debug_event, value in zip(debug_events, debug_data):
617                    v1 = inflate_runtime_output(debug_event.debug_entry, output_buffer)
618                    v2 = inflate_runtime_output(value, output_buffer)
619                    assert is_inference_output_equal(
620                        v1, v2
621                    ), """Corresponding debug events in multiple iterations of the model
622                    must have the same debug entry values. This is not the case for the
623                    intermediate data present in this ETDump and indicates potential issues
624                    with the model/runtime."""
625
626        ret_event.debug_data = [
627            inflate_runtime_output(debug_value, output_buffer)
628            for debug_value in debug_data
629        ]
630
631    def _associate_with_op_graph_nodes(
632        self,
633        debug_handle_to_op_node_map: Dict[int, OperatorNode],
634    ) -> None:
635        """
636        Helper function to populate the stack_traces, module_hierarchy and op_types attributes
637        based on the debug handles of this event
638        """
639
640        # Framework events aren't logically associated with any nodes
641        if self.name in RESERVED_FRAMEWORK_EVENT_NAMES:
642            return
643
644        if (debug_handles := self.debug_handles) is None:
645            return
646
647        if isinstance(debug_handles, int):
648            debug_handles = [debug_handles]
649
650        for handle in debug_handles:
651            node = debug_handle_to_op_node_map.get(handle)
652            # Attach node metadata including stack traces, module hierarchy and op_types to this event
653            if node is not None and (metadata := node.metadata) is not None:
654                self.stack_traces[node.name] = metadata.get("stack_trace")
655                self.module_hierarchy[node.name] = metadata.get("nn_module_stack")
656                if node.op:
657                    # TODO: consider having this as a dict from node.name -> node.op
658                    self.op_types += [node.op]
659
660
661@dataclass
662class EventBlock:
663    r"""
664    An `EventBlock` contains a collection of events associated with a particular profiling/debugging block retrieved from the runtime.
665    Each `EventBlock` represents a pattern of execution. For example, model initiation and loading lives in a single `EventBlock`.
666    If there's a control flow, each branch will be represented by a separate `EventBlock`.
667
668    Args:
669        name: Name of the profiling/debugging block.
670        events: List of `Event`\ s associated with the profiling/debugging block.
671
672        bundled_input_idx: Index of the Bundled Input that this EventBlock corresponds to.
673        run_output: Run output extracted from the encapsulated Events
674    """
675
676    name: str
677    events: List[Event] = dataclasses.field(default_factory=list)
678    source_time_scale: TimeScale = TimeScale.NS
679    target_time_scale: TimeScale = TimeScale.MS
680    bundled_input_index: Optional[int] = None
681    run_output: Optional[ProgramOutput] = None
682    reference_output: Optional[ProgramOutput] = None
683
684    def to_dataframe(
685        self, include_units: bool = False, include_delegate_debug_data: bool = False
686    ) -> pd.DataFrame:
687        """
688        Converts the EventBlock into a DataFrame with each row being an event instance
689
690        Note: Rows that have an event_name = OPERATOR_CALL correspond to the perf of the
691            previous operator + framework tax of making said operator call.
692
693        Args:
694            include_units: Whether headers should include units (default false)
695            include_delegate_debug_data: Whether to show the delegate debug data
696
697        Returns:
698            A pandas DataFrame containing the data of each Event instance in this EventBlock.
699        """
700
701        units = " (" + self.target_time_scale.value + ")" if include_units else ""
702
703        df = pd.concat([e.to_dataframe(units) for e in self.events], ignore_index=True)
704        df.insert(
705            0,
706            "event_block_name",
707            np.asarray([self.name for _ in range(len(self.events))]),
708            allow_duplicates=True,
709        )
710
711        # Add Delegate Debug Metadata columns
712        if include_delegate_debug_data:
713            delegate_data = []
714            for event in self.events:
715                if (metadata := event.delegate_debug_metadatas) is not None and len(
716                    metadata
717                ) > 0:
718                    if isinstance(metadata, list):
719                        delegate_data.append(
720                            pd.Series([metadata], index=["delegate_debug_metadata"])
721                        )
722                    elif isinstance(metadata, dict):
723                        delegate_data.append(pd.Series(metadata))
724                    else:
725                        raise ValueError(
726                            f"Unexpected type for delegate_debug_metadata: {type(metadata)}"
727                        )
728                else:
729                    delegate_data.append(pd.Series())
730
731            if any(not data.empty for data in delegate_data):
732                df = pd.concat([df, pd.DataFrame(delegate_data)], axis=1)
733
734        return df
735
736    @staticmethod
737    def _gen_from_etdump(
738        etdump: ETDumpFlatCC,
739        source_time_scale: TimeScale = TimeScale.NS,
740        target_time_scale: TimeScale = TimeScale.MS,
741        output_buffer: Optional[bytes] = None,
742        delegate_metadata_parser: Optional[
743            Callable[[List[str]], Dict[str, Any]]
744        ] = None,
745        delegate_time_scale_converter: Optional[
746            Callable[[Union[int, str], Union[int, float]], Union[int, float]]
747        ] = None,
748    ) -> List["EventBlock"]:
749        """
750        Given an etdump, generate a list of EventBlocks corresponding to the
751        contents.
752
753        An optional (inverse) scale factor can be provided to adjust the
754        etdump timestamps associated with each EventBlocks
755
756        An optional buffer to inflate etdump references
757
758        An optional delegate metadata parser function to parse delegate profiling metadata
759        """
760
761        # Map each RunSignatures to instances of its constituent events.
762        #   The value of the map is a GroupedRunInstance which contains:
763        #   (1) a map from each EventSignature to InstructionEvents with the signature
764        #   (2) the run output for this RunSignature
765        @dataclass
766        class GroupedRunInstances:
767            events: OrderedDict[EventSignature, List[InstructionEvent]]
768            run_output: ProgramOutput
769
770        run_groups: Mapping[RunSignature, GroupedRunInstances] = defaultdict(
771            lambda: GroupedRunInstances(OrderedDict(), [])
772        )
773
774        # Collect all the run data
775        for run in etdump.run_data:
776            if (run_events := run.events) is None:
777                continue
778
779            # Collate the run_events into InstructionEvents
780            instruction_events: List[InstructionEvent] = (
781                InstructionEvent.gen_from_events(run_events)
782            )
783
784            # Map EventSignatures to the InstructionEvents
785            event_signatures: Dict[EventSignature, InstructionEvent] = OrderedDict()
786            for instruction_event in instruction_events:
787                if (
788                    instruction_event.debug_events is None
789                    and instruction_event.profile_events is None
790                ):
791                    # Currently corresponds to run output
792                    continue
793
794                generated_event_signatures: List[
795                    Tuple[EventSignature, InstructionEvent]
796                ] = EventSignature.gen_from_instruction_event(instruction_event)
797                for (
798                    event_signature,
799                    filtered_instruction_event,
800                ) in generated_event_signatures:
801                    event_signatures[event_signature] = filtered_instruction_event
802
803            # Create a RunSignature from the EventSignatures
804            run_signature = RunSignature(
805                name=run.name,
806                events=tuple(event_signatures.keys()),
807                bundled_input_index=run.bundled_input_index,
808            )
809
810            # Update the Run Groups, indexed on the RunSignature
811            run_signature_events: OrderedDict[
812                EventSignature, List[InstructionEvent]
813            ] = run_groups[run_signature].events
814            for event_signature, event in event_signatures.items():
815                run_signature_events.setdefault(event_signature, []).append(event)
816
817            # Populate (or Verify if already populated) Run Outputs
818            run_outputs: ProgramOutput = EventBlock._collect_run_outputs(
819                run_events, output_buffer
820            )
821            if len(existing_run_outputs := run_groups[run_signature].run_output) == 0:
822                existing_run_outputs.extend(run_outputs)
823            else:
824                verify_debug_data_equivalence(existing_run_outputs, run_outputs)
825
826        # Construct the EventBlocks
827        event_blocks = []
828        scale_factor = calculate_time_scale_factor(source_time_scale, target_time_scale)
829        for run_signature, grouped_run_instance in run_groups.items():
830            run_group: OrderedDict[EventSignature, List[InstructionEvent]] = (
831                grouped_run_instance.events
832            )
833            run_outputs: ProgramOutput = grouped_run_instance.run_output
834
835            # Construct the Events
836            events: List[Event] = [
837                Event._gen_from_inference_events(
838                    signature,
839                    instruction_events,
840                    scale_factor,
841                    output_buffer,
842                    delegate_metadata_parser,
843                    delegate_time_scale_converter,
844                )
845                for signature, instruction_events in run_group.items()
846            ]
847
848            # Add the EventBlock to the return list
849            event_blocks.append(
850                EventBlock(
851                    name=run_signature.name,
852                    events=events,
853                    source_time_scale=source_time_scale,
854                    target_time_scale=target_time_scale,
855                    bundled_input_index=run_signature.bundled_input_index,
856                    run_output=run_outputs,
857                )
858            )
859
860        return event_blocks
861
862    @staticmethod
863    def _collect_run_outputs(
864        events: List[flatcc.Event], output_buffer: Optional[bytes] = None
865    ) -> ProgramOutput:
866        """
867        Given a list of events, search the events for ProgramOutputs (aka lists of InferenceOutputs) marked
868        as run outputs
869        """
870
871        output_events = []
872        for event in events:
873            if event.debug_event is None:
874                continue
875            if event.debug_event.debug_entry is None:
876                raise RuntimeError(
877                    "Debug entry inside debug event should not be empty!"
878                )
879            if is_debug_output(event.debug_event.debug_entry):
880                output_events += [event]
881
882        return [
883            inflate_runtime_output(debug_event.debug_entry, output_buffer)
884            for output_event in output_events
885            if (debug_event := output_event.debug_event) is not None
886        ]
887
888    # TODO: Considering changing ETRecord deserialization logic to cast the ints in string format to actual ints
889    def _gen_resolve_debug_handles(
890        self,
891        handle_map: Dict[str, List[int]],
892        delegate_map: Optional[Dict[str, DelegateMetadata]] = None,
893    ):
894        """
895        Given mappings from instruction id to debug handles, populate the
896        debug_handles field of all underlying events
897
898        If the event is delegated, index with the instruction_id and delegate_debug_identifier
899        to obtain the debug_handle via the delegate map
900        """
901        for event in self.events:
902            # Check if instruction_id is present in the event
903            if event._instruction_id is None:
904                continue
905
906            # Check for the instruction_id in handle map
907            if (instruction_id := str(event._instruction_id)) not in handle_map:
908                continue
909
910            # For non-delegated event, handles are found in handle_map
911            if (delegate_debug_id := event.delegate_debug_identifier) is None:
912                event.debug_handles = handle_map[instruction_id]
913
914                # DELEGATE_CALL is a special non-delegated event and benefits from having the name populated
915                if (
916                    event.name == "DELEGATE_CALL"
917                    and delegate_map is not None
918                    and (delegate_metadata := delegate_map.get(instruction_id))
919                    is not None
920                ):
921                    event.delegate_backend_name = delegate_metadata.get("name", "")
922
923                continue
924
925            # Check that the delegated event has a corresponding mapping
926            if (
927                delegate_map is None
928                or (delegate_metadata := delegate_map.get(instruction_id)) is None
929            ):
930                event.debug_handles = handle_map[instruction_id]
931                log.warning(
932                    f" No delegate mapping found for delegate with instruction id {event._instruction_id}"
933                )
934                continue
935
936            # For delegated events, handles are found via delegateMetadata
937            event.delegate_backend_name = delegate_metadata.get("name", "")
938            delegate_metadata_delegate_map = delegate_metadata.get("delegate_map") or {}
939
940            # delegate_debug_id can be either int based or string based, therefore we need to check both
941            debug_handles = delegate_metadata_delegate_map.get(
942                delegate_debug_id  # pyre-ignore
943            )
944            if debug_handles is not None:
945                event.debug_handles = debug_handles
946            else:
947                event.debug_handles = delegate_metadata_delegate_map.get(
948                    str(delegate_debug_id)  # pyre-ignore
949                )
950                for key, value in delegate_metadata_delegate_map.items():
951                    if key in str(delegate_debug_id):
952                        event.debug_handles = value
953
954
955class Inspector:
956    """
957    APIs for examining model architecture and performance stats.
958
959    Public Attributes:
960        event_blocks: List["EventBlocks"]. Structured data from ETDump (correlated with ETRecord if provided).
961
962    Private Attributes:
963        _etrecord: Optional[ETRecord]. File under etrecord_path deserialized into an object.
964    """
965
966    def __init__(
967        self,
968        etdump_path: Optional[str] = None,
969        etdump_data: Optional[bytes] = None,
970        etrecord: Optional[Union[ETRecord, str]] = None,
971        source_time_scale: TimeScale = TimeScale.NS,
972        target_time_scale: TimeScale = TimeScale.MS,
973        debug_buffer_path: Optional[str] = None,
974        delegate_metadata_parser: Optional[
975            Callable[[List[str]], Dict[str, Any]]
976        ] = None,
977        delegate_time_scale_converter: Optional[
978            Callable[[Union[int, str], Union[int, float]], Union[int, float]]
979        ] = None,
980        enable_module_hierarchy: bool = False,
981    ) -> None:
982        r"""
983        Initialize an `Inspector` instance with the underlying `EventBlock`\ s populated with data from the provided ETDump path or binary,
984        and optional ETRecord path.
985
986        Args:
987            etdump_path: Path to the ETDump file. Either this parameter or etdump_data should be provided.
988            etdump_data: ETDump binary. Either this parameter or etdump_path should be provided.
989            etrecord: Optional ETRecord object or path to the ETRecord file.
990            source_time_scale: The time scale of the performance data retrieved from the runtime. The default time hook implentation in the runtime returns NS.
991            target_time_scale: The target time scale to which the users want their performance data converted to. Defaults to MS.
992            debug_buffer_path: Debug buffer file path that contains the debug data referenced by ETDump for intermediate and program outputs.
993            delegate_metadata_parser: Optional function to parse delegate metadata from an Profiling Event. Expected signature of the function is:
994                    (delegate_metadata_list: List[bytes]) -> Union[List[str], Dict[str, Any]]
995            delegate_time_scale_converter: Optional function to convert the time scale of delegate profiling data. If not given, use the conversion ratio of
996                    target_time_scale/source_time_scale.
997            enable_module_hierarchy: Enable submodules in the operator graph. Defaults to False.
998
999        Returns:
1000            None
1001        """
1002
1003        if (source_time_scale == TimeScale.CYCLES) ^ (
1004            target_time_scale == TimeScale.CYCLES
1005        ):
1006            raise RuntimeError(
1007                "For TimeScale in cycles both the source and target time scale have to be in cycles."
1008            )
1009        self._source_time_scale = source_time_scale
1010        self._target_time_scale = target_time_scale
1011
1012        if delegate_time_scale_converter is None:
1013            scale_factor = calculate_time_scale_factor(
1014                source_time_scale, target_time_scale
1015            )
1016            delegate_time_scale_converter = (
1017                lambda event_name, input_time: input_time / scale_factor
1018            )
1019
1020        if etrecord is None:
1021            self._etrecord = None
1022        elif isinstance(etrecord, ETRecord):
1023            self._etrecord = etrecord
1024        elif isinstance(etrecord, str):
1025            self._etrecord = parse_etrecord(etrecord_path=etrecord)
1026        else:
1027            raise TypeError("Unsupported ETRecord type")
1028
1029        if (etdump_path is None) == (etdump_data is None):
1030            raise ValueError(
1031                "Expecting exactly one of etdump_path or etdump_data to be specified."
1032            )
1033
1034        # Create EventBlocks from ETDump
1035        etdump = gen_etdump_object(etdump_path=etdump_path, etdump_data=etdump_data)
1036        if debug_buffer_path is not None:
1037            with open(debug_buffer_path, "rb") as f:
1038                output_buffer = f.read()
1039        else:
1040            output_buffer = None
1041            warnings.warn(
1042                "Output Buffer not found. Tensor Debug Data will not be available.",
1043                stacklevel=1,
1044            )
1045
1046        self.event_blocks = EventBlock._gen_from_etdump(
1047            etdump=etdump,
1048            source_time_scale=self._source_time_scale,
1049            target_time_scale=self._target_time_scale,
1050            output_buffer=output_buffer,
1051            delegate_metadata_parser=delegate_metadata_parser,
1052            delegate_time_scale_converter=delegate_time_scale_converter,
1053        )
1054
1055        # Connect ETRecord to EventBlocks
1056        self.op_graph_dict: Optional[Mapping[str, OperatorGraph]] = None
1057
1058        # _consume_etrecord() will populate the _reference_outputs dict
1059        # Key str is method name; value is list of ProgramOutputs because of list of test cases
1060        self._reference_outputs: Dict[str, List[ProgramOutput]] = {}
1061        self._enable_module_hierarchy = enable_module_hierarchy
1062        self._consume_etrecord()
1063
1064    def _consume_etrecord(self) -> None:
1065        """
1066        If an ETRecord is provided, connect it to the EventBlocks and populate the Event metadata.
1067
1068        Steps:
1069            1. Debug Handle Symbolification:
1070                For each Event, find the debug_handle counterparts using
1071                ETRecord's debug_handle_map and delegate_map
1072
1073            2. Event Metadata Association:
1074                For each Event, populate its metadata from OperatorGraph Nodes,
1075                generated from ETRecord. The debug_handle is used to
1076                identify the corresponding OperatorGraph Nodes.
1077
1078            3. Reference Outputs Extraction:
1079                If there're reference outputs saved in ETRecord, assign each reference output to the corresponding
1080                EventBlock based on the method name (currently assumes only "forward") and the
1081                bundled_input_index of the EventBlock.
1082        """
1083
1084        if self._etrecord is None:
1085            return
1086
1087        # (1) Debug Handle Symbolification
1088        for event_block in self.event_blocks:
1089            event_block._gen_resolve_debug_handles(
1090                self._etrecord._debug_handle_map[FORWARD],
1091                (
1092                    self._etrecord._delegate_map[FORWARD]
1093                    if self._etrecord._delegate_map is not None
1094                    else None
1095                ),
1096            )
1097
1098        # (2) Event Metadata Association
1099        self.op_graph_dict = gen_graphs_from_etrecord(
1100            etrecord=self._etrecord,
1101            enable_module_hierarchy=self._enable_module_hierarchy,
1102        )
1103        debug_handle_to_op_node_map = create_debug_handle_to_op_node_mapping(
1104            self.op_graph_dict[EDGE_DIALECT_GRAPH_KEY],
1105        )
1106        for event_block in self.event_blocks:
1107            for event in event_block.events:
1108                event._associate_with_op_graph_nodes(
1109                    debug_handle_to_op_node_map=debug_handle_to_op_node_map,
1110                )
1111
1112        # (3) Reference Outputs Extraction
1113        if self._etrecord._reference_outputs is not None:
1114            self._reference_outputs = self._etrecord._reference_outputs
1115            # Associate each reference output to the corresponding event block
1116            for event_block in self.event_blocks:
1117                index = event_block.bundled_input_index
1118                if index is not None:
1119                    event_block.reference_output = self._reference_outputs[FORWARD][
1120                        index
1121                    ]
1122
1123    def to_dataframe(
1124        self,
1125        include_units: bool = True,
1126        include_delegate_debug_data: bool = False,
1127    ) -> pd.DataFrame:
1128        """
1129        Args:
1130            include_units: Whether headers should include units (default true)
1131            include_delegate_debug_data: Whether to include delegate debug metadata (default false)
1132
1133        Returns:
1134            Returns a pandas DataFrame of the Events in each EventBlock in the inspector, with each row representing an Event.
1135        """
1136
1137        df_list = [
1138            event_block.to_dataframe(
1139                include_units=include_units,
1140                include_delegate_debug_data=include_delegate_debug_data,
1141            )
1142            for event_block in self.event_blocks
1143        ]
1144        return pd.concat(df_list, ignore_index=True)
1145
1146    def print_data_tabular(
1147        self,
1148        file: IO[str] = sys.stdout,
1149        include_units: bool = True,
1150        include_delegate_debug_data: bool = False,
1151    ) -> None:
1152        """
1153        Displays the underlying EventBlocks in a structured tabular format, with each row representing an Event.
1154
1155        Args:
1156            file: Which IO stream to print to. Defaults to stdout.
1157                Not used if this is in an IPython environment such as a Jupyter notebook.
1158            include_units: Whether headers should include units (default true)
1159            include_delegate_debug_data: Whether to include delegate debug metadata (default false)
1160
1161        Returns:
1162            None
1163        """
1164        combined_df = self.to_dataframe(include_units, include_delegate_debug_data)
1165
1166        # Filter out some columns and rows for better readability when printing
1167        filtered_column_df = combined_df.drop(columns=EXCLUDED_COLUMNS_WHEN_PRINTING)
1168        for filter_name in EXCLUDED_EVENTS_WHEN_PRINTING:
1169            filtered_column_df = filtered_column_df[
1170                ~filtered_column_df["event_name"].str.contains(filter_name)
1171            ]
1172        filtered_column_df.reset_index(drop=True, inplace=True)
1173
1174        display_or_print_df(filtered_column_df, file)
1175
1176    # TODO: write unit test
1177    def find_total_for_module(self, module_name: str) -> float:
1178        """
1179        Returns the total average compute time of all operators within the specified module.
1180
1181        Args:
1182            module_name: Name of the module to be aggregated against.
1183
1184        Returns:
1185            Sum of the average compute time (in seconds) of all operators within the module with "module_name".
1186        """
1187
1188        total = 0.0
1189        for block in self.event_blocks:
1190            for event in block.events:
1191                module_hierarchy = event.module_hierarchy.values()
1192                for hierarchy in module_hierarchy:
1193                    if not hierarchy:
1194                        continue
1195                    found = any(module_name in key for key in hierarchy.keys())
1196                    if found:
1197                        if event.perf_data is not None:
1198                            total += event.perf_data.avg
1199                        break
1200        return total
1201
1202    def get_op_list(
1203        self, event_block: str, show_delegated_ops: Optional[bool] = True
1204    ) -> Dict[str, List[Event]]:
1205        """
1206        Return a map of op_types to Events of that op_type
1207        """
1208        # TODO: implement
1209        return {}
1210
1211    def write_tensorboard_artifact(self, path: str) -> None:
1212        """
1213        Write to the provided path, the artifacts required for visualization in TensorBoard
1214        """
1215        # TODO: implement
1216        pass
1217
1218    def get_exported_program(
1219        self, graph: Optional[str] = None
1220    ) -> Optional[ExportedProgram]:
1221        """
1222        Access helper for ETRecord, defaults to returning the Edge Dialect program.
1223
1224        Args:
1225            graph: Optional name of the graph to access. If None, returns the Edge Dialect program.
1226
1227        Returns:
1228            The ExportedProgram object of "graph".
1229        """
1230        if self._etrecord is None:
1231            log.warning(
1232                "Exported program is only available when a valid etrecord_path was provided at the time of Inspector construction"
1233            )
1234            return None
1235        return (
1236            self._etrecord.edge_dialect_program
1237            if graph is None
1238            else self._etrecord.graph_map.get(graph)
1239        )
1240