xref: /aosp_15_r20/external/executorch/devtools/inspector/_inspector_utils.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-unsafe
8
9import math
10import sys
11from enum import Enum
12from typing import Dict, IO, List, Mapping, Optional, Tuple, TypeAlias, Union
13
14import executorch.devtools.etdump.schema_flatcc as flatcc
15
16import pandas as pd
17
18import torch
19
20from executorch.devtools.debug_format.base_schema import OperatorNode
21
22from executorch.devtools.debug_format.et_schema import FXOperatorGraph, OperatorGraph
23from executorch.devtools.etdump.schema_flatcc import (
24    DebugEvent,
25    ETDumpFlatCC,
26    ProfileEvent,
27    ScalarType,
28    Tensor,
29    Value,
30    ValueType,
31)
32
33from executorch.devtools.etdump.serialize import deserialize_from_etdump_flatcc
34from executorch.devtools.etrecord import ETRecord
35
36from tabulate import tabulate
37
# NOTE(review): presumably the name of the default program method; it is not
# referenced in the visible part of this file - confirm against callers.
FORWARD = "forward"
# Key under which gen_graphs_from_etrecord() stores the operator graph built
# from the ETRecord's edge dialect program.
EDGE_DIALECT_GRAPH_KEY = "edge_dialect_graph_module"

# NOTE(review): judging by the name, event names emitted by the runtime
# framework itself rather than by model operators - confirm against callers.
RESERVED_FRAMEWORK_EVENT_NAMES = [
    "Method::init",
    "Program::load_method",
    "Method::execute",
]
# NOTE(review): judging by the name, columns dropped before a table is
# printed; not used in the visible part of this file - confirm.
EXCLUDED_COLUMNS_WHEN_PRINTING = [
    "raw",
    "delegate_debug_identifier",
    "stack_traces",
    "module_hierarchy",
    "debug_data",
]
# NOTE(review): judging by the name, event names skipped when printing;
# not used in the visible part of this file - confirm.
EXCLUDED_EVENTS_WHEN_PRINTING = {"OPERATOR_CALL"}
54
55
class TimeScale(Enum):
    """
    Units in which a time measurement can be expressed.

    Each member's value is the short string label of the unit. The numeric
    scale of every member is looked up in TIME_SCALE_DICT.
    """

    NS = "ns"  # nanoseconds
    US = "us"  # microseconds
    MS = "ms"  # milliseconds
    S = "s"  # seconds
    CYCLES = "cycles"  # raw cycle count rather than a wall-clock unit
62
63
# Number of units of each time scale that make up one second.
# CYCLES is given the same scale as S (1).
# NOTE(review): presumably because a cycle count has no fixed relation to
# wall-clock time, so a factor that mixes CYCLES with a real time unit is not
# a meaningful conversion - confirm with callers.
TIME_SCALE_DICT = {
    TimeScale.NS: 1000000000,
    TimeScale.US: 1000000,
    TimeScale.MS: 1000,
    TimeScale.S: 1,
    TimeScale.CYCLES: 1,
}
71
72
def calculate_time_scale_factor(
    source_time_scale: TimeScale, target_time_scale: TimeScale
) -> float:
    """
    Return the ratio between the scales of two time units.

    The result is the TIME_SCALE_DICT entry of the source scale divided by the
    entry of the target scale. Since those entries are "units per second",
    dividing a value expressed in the source unit by this factor yields the
    value in the target unit (e.g. NS -> MS gives 1e6, and 5e6 ns / 1e6 = 5 ms).

    Raises:
        KeyError: if either scale has no entry in TIME_SCALE_DICT.
    """
    source_units_per_second = TIME_SCALE_DICT[source_time_scale]
    target_units_per_second = TIME_SCALE_DICT[target_time_scale]
    return source_units_per_second / target_units_per_second
80
81
# Model Debug Output
# A single output value: a tensor, a list of tensors, a Python scalar
# (int, float, str, bool) or None.
InferenceOutput: TypeAlias = Union[
    torch.Tensor, List[torch.Tensor], int, float, str, bool, None
]
# A list of InferenceOutput values.
ProgramOutput: TypeAlias = List[InferenceOutput]
87
88
def is_inference_output_equal(
    output1: InferenceOutput, output2: InferenceOutput
) -> bool:
    """
    Compare whether two InferenceOutputs are equal.

    - Two tensors are equal when ``torch.equal`` says so (same shape and
      same element values).
    - Two lists are equal when they have the same length and every pair of
      corresponding tensors is equal under ``torch.equal``.
    - Anything else falls back to the ``==`` operator.

    Args:
        output1: First value to compare.
        output2: Second value to compare.

    Returns:
        True if the two outputs are considered equal, False otherwise.
    """
    if isinstance(output1, torch.Tensor) and isinstance(output2, torch.Tensor):
        return torch.equal(output1, output2)

    if isinstance(output1, list) and isinstance(output2, list):
        # zip() stops at the shorter list, so without the explicit length check
        # a list would compare equal to any longer list it is a prefix of.
        if len(output1) != len(output2):
            return False
        return all(torch.equal(t1, t2) for t1, t2 in zip(output1, output2))

    return bool(output1 == output2)
101
102
def _parse_tensor_value(
    tensor: Optional[Tensor], output_buffer: Optional[bytes]
) -> torch.Tensor:
    """
    Given an ETDump Tensor object and the output buffer, extract a torch.Tensor.

    The tensor data is read from ``output_buffer`` starting at ``tensor.offset``
    and spanning ``prod(tensor.sizes) * element_size`` bytes, interpreted with
    the torch dtype matching ``tensor.scalar_type`` and viewed as
    ``tensor.sizes``.

    Args:
        tensor: ETDump tensor description (sizes, scalar_type, offset).
        output_buffer: Raw bytes holding the tensor data, or None.

    Returns:
        The extracted tensor. A zero-filled tensor of the described shape and
        dtype is returned when no buffer is provided or when the tensor has
        zero elements.

    Raises:
        ValueError: if ``tensor`` or ``tensor.offset`` is None, or if the
            described byte range does not fit inside ``output_buffer``.
        RuntimeError: if ``tensor.scalar_type`` is not supported.
    """

    def get_scalar_type_size(scalar_type: ScalarType) -> Tuple[torch.dtype, int]:
        """
        Return the (torch dtype, element size in bytes) pair for a scalar type
        """
        get_scalar_type_size_map = {
            ScalarType.BYTE: (torch.uint8, 1),
            ScalarType.CHAR: (torch.int8, 1),
            ScalarType.BOOL: (torch.bool, 1),
            ScalarType.BITS16: (torch.uint16, 2),
            ScalarType.SHORT: (torch.int16, 2),
            ScalarType.HALF: (torch.float16, 2),
            ScalarType.INT: (torch.int, 4),
            ScalarType.FLOAT: (torch.float, 4),
            ScalarType.DOUBLE: (torch.double, 8),
            ScalarType.LONG: (torch.long, 8),
        }
        if scalar_type in get_scalar_type_size_map:
            return get_scalar_type_size_map[scalar_type]
        raise RuntimeError(
            f"Unsupported scalar type in get_scalar_type_size : {scalar_type}"
        )

    # Report the two failure modes separately so the message names the field
    # that is actually missing.
    if tensor is None:
        raise ValueError("Tensor cannot be None")
    if tensor.offset is None:
        raise ValueError("Tensor offset cannot be None")

    torch_dtype, dtype_size = get_scalar_type_size(tensor.scalar_type)

    if output_buffer is None:
        # Empty buffer provided. Cannot deserialize tensors.
        return torch.zeros(tensor.sizes, dtype=torch_dtype)

    tensor_bytes_size = math.prod(tensor.sizes) * dtype_size
    if tensor_bytes_size == 0:
        # Empty tensor. Return empty tensor.
        return torch.zeros(tensor.sizes, dtype=torch_dtype)

    start = tensor.offset
    end = start + tensor_bytes_size
    # Slicing past the end of a buffer silently truncates, which would only
    # surface later as an obscure frombuffer/view error. Fail early instead.
    if start < 0 or end > len(output_buffer):
        raise ValueError(
            f"Tensor data range [{start}, {end}) is out of bounds for an "
            f"output buffer of {len(output_buffer)} bytes"
        )

    # torch.frombuffer shares memory with the buffer it is given and warns on
    # read-only buffers such as bytes, so hand it a private writable copy of
    # just the bytes belonging to this tensor.
    tensor_bytes = bytearray(memoryview(output_buffer)[start:end])
    return torch.frombuffer(tensor_bytes, dtype=torch_dtype).view(tensor.sizes)
151
152
def inflate_runtime_output(
    value: Value, output_buffer: Optional[bytes]
) -> InferenceOutput:
    """
    Convert an ETDump Value object into the matching InferenceOutput.

    Int, Bool, Float and Double values are unwrapped from their container
    field. Tensor and TensorList values are turned into torch tensors through
    _parse_tensor_value, reading their data from output_buffer. Any other value
    type produces None.

    Raises:
        ValueError: if the container matching the declared value type is None.
    """
    kind = value.val

    # (type tag, container field on Value, payload field on the container,
    #  error raised when the container is missing)
    scalar_layouts = (
        (
            ValueType.INT.value,
            "int_value",
            "int_val",
            "Expected Int value, `None` provided",
        ),
        (
            ValueType.BOOL.value,
            "bool_value",
            "bool_val",
            "Expected Bool value, `None` provided",
        ),
        (
            ValueType.FLOAT.value,
            "float_value",
            "float_val",
            "Expected Float value, `None` provided",
        ),
        (
            ValueType.DOUBLE.value,
            "double_value",
            "double_val",
            "Expected Double value, `None` provided",
        ),
    )
    for type_tag, container_field, payload_field, missing_message in scalar_layouts:
        if kind != type_tag:
            continue
        container = getattr(value, container_field)
        if container is None:
            raise ValueError(missing_message)
        return getattr(container, payload_field)

    if kind == ValueType.TENSOR.value:
        return _parse_tensor_value(value.tensor, output_buffer)

    if kind == ValueType.TENSOR_LIST.value:
        tensor_list = value.tensor_list
        if tensor_list is None:
            raise ValueError("Expected TensorList value, `None` provided")
        inflated = []
        for entry in tensor_list.tensors:
            inflated.append(_parse_tensor_value(entry, output_buffer))
        return inflated

    # No recognized value type: nothing to inflate.
    return None
184
185
def find_populated_event(event: flatcc.Event) -> Union[ProfileEvent, DebugEvent]:
    """
    Return whichever sub-event of an ETDump Event is populated.

    The profile event takes precedence over the debug event when both are set.

    Raises:
        ValueError: if neither the profile event nor the debug event is set.
    """
    for candidate in (event.profile_event, event.debug_event):
        if candidate is not None:
            return candidate

    raise ValueError("Unable to find populated event")
199
200
# TODO: Optimize by verifying prior to inflating the tensors
def verify_debug_data_equivalence(
    existing_data: ProgramOutput, new_data: ProgramOutput
) -> None:
    """
    Verify that two lists of inference outputs are equivalent.

    Both lists must have the same length, and every pair of corresponding
    entries must have compatible types and equal contents. Tensors are equal
    when they have the same shape and the same element values; lists are equal
    when they have the same length and pairwise equal entries; everything else
    is compared with ``==``.

    Raises:
        AssertionError: with a message naming the first mismatch found.
    """

    def _outputs_match(a, b) -> bool:
        # Return True if two inflated outputs hold the same data.
        if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
            # Compare shapes first: `a == b` broadcasts, so tensors of
            # different shapes could otherwise be reported as equal (or raise).
            return a.shape == b.shape and bool(torch.all(a == b))
        if isinstance(a, list) and isinstance(b, list):
            # A plain `a == b` on lists of tensors raises, because list
            # equality needs the truth value of each elementwise result.
            return len(a) == len(b) and all(
                _outputs_match(item_a, item_b) for item_a, item_b in zip(a, b)
            )
        return bool(a == b)

    assert len(existing_data) == len(
        new_data
    ), "Unequal debug data length encountered. Expected to be equal."

    for output_a, output_b in zip(existing_data, new_data):
        assert isinstance(
            output_a, type(output_b)
        ), "Debug Data Types are different. Expected to be equal."

        if isinstance(output_a, (torch.Tensor, list)):
            assert _outputs_match(
                output_a, output_b
            ), "Tensors Debug Data is different. Expected to be equal."
        else:
            assert (
                output_a == output_b
            ), "Scalar Debug Data is different. Expected to be equal"
228
229
def is_debug_output(value: Value) -> bool:
    """
    Tell whether the given flatcc.Value is flagged as a debug output.

    A value without an ``output`` field is not a debug output; otherwise the
    flag stored in ``output.bool_val`` decides.
    """
    output_flag = value.output
    if output_flag is None:
        return False
    return output_flag.bool_val
235
236
def gen_graphs_from_etrecord(
    etrecord: ETRecord, enable_module_hierarchy: bool = False
) -> Mapping[str, OperatorGraph]:
    """
    Build an operator graph for every program stored in an ETRecord.

    Each entry of ``etrecord.graph_map`` (when present) produces a graph under
    the same name. The edge dialect program (when present) is stored under
    EDGE_DIALECT_GRAPH_KEY.

    Args:
        etrecord: Record holding the exported programs.
        enable_module_hierarchy: Forwarded to FXOperatorGraph.gen_operator_graph.

    Returns:
        Mapping from graph name to the generated operator graph.
    """
    graphs_by_name = {}

    named_programs = etrecord.graph_map
    if named_programs is not None:
        for graph_name, program in named_programs.items():
            graphs_by_name[graph_name] = FXOperatorGraph.gen_operator_graph(
                program.graph_module,
                enable_module_hierarchy=enable_module_hierarchy,
            )

    edge_program = etrecord.edge_dialect_program
    if edge_program is not None:
        graphs_by_name[EDGE_DIALECT_GRAPH_KEY] = FXOperatorGraph.gen_operator_graph(
            edge_program.graph_module,
            enable_module_hierarchy=enable_module_hierarchy,
        )

    return graphs_by_name
256
257
def create_debug_handle_to_op_node_mapping(
    op_graph: OperatorGraph,
) -> Dict[int, OperatorNode]:
    """
    Map every debug handle found in ``op_graph`` to the operator node owning it.

    The graph is walked depth-first, descending into nested operator graphs.
    An operator node contributes an entry when its metadata holds a
    "debug_handle" value that is not None.

    Raises:
        ValueError: if two operator nodes carry the same debug handle.
    """
    handle_to_node: Dict[int, OperatorNode] = {}

    def _collect(subgraph: OperatorGraph):
        # Record the handles of subgraph's nodes, descending into nested graphs.
        for element in subgraph.elements:
            if isinstance(element, OperatorGraph):
                _collect(element)

            # Only operator nodes that carry metadata can hold a debug handle.
            if not isinstance(element, OperatorNode) or element.metadata is None:
                continue

            debug_handle = element.metadata.get("debug_handle")
            if debug_handle is None:
                continue

            existing_entry = handle_to_node.get(debug_handle)
            if existing_entry is not None:
                raise ValueError(
                    f"Duplicated debug handle {str(debug_handle)} shared between {element.name} and {existing_entry.name}. "
                    "No two op nodes of the same graph should have the same debug handle."
                )
            handle_to_node[debug_handle] = element

    _collect(op_graph)
    return handle_to_node
287
288
def gen_etdump_object(
    etdump_path: Optional[str] = None, etdump_data: Optional[bytes] = None
) -> ETDumpFlatCC:
    """
    Deserialize an ETDump from in-memory bytes or from a file.

    ``etdump_data`` is used when given; the file at ``etdump_path`` is only
    read when no data was passed in.

    Raises:
        ValueError: if neither ``etdump_data`` nor ``etdump_path`` is provided.
    """
    raw_bytes = etdump_data
    if raw_bytes is None:
        if etdump_path is None:
            raise ValueError(
                "Unable to get ETDump data. One and only one of etdump_path and etdump_data must be specified."
            )
        with open(etdump_path, "rb") as etdump_file:
            raw_bytes = etdump_file.read()

    return deserialize_from_etdump_flatcc(raw_bytes)
303
304
def display_or_print_df(df: pd.DataFrame, file: IO[str] = sys.stdout):
    """
    Show a DataFrame as a styled table in IPython, or print it as text.

    When running inside an IPython session the frame is rendered through
    ``IPython.display.display`` with a fixed font size. In every other case
    (IPython missing, no active session, or rendering failed) the frame is
    printed as a "fancy_grid" table via tabulate.

    Args:
        df: The DataFrame to show.
        file: Text stream the plain-text table is written to.
    """
    try:
        from IPython import get_ipython
        from IPython.display import display

        if get_ipython() is not None:

            def style_text_size(val, size=12):
                return f"font-size: {size}px"

            styler = df.style
            # Styler.applymap was deprecated in pandas 2.1 in favour of
            # Styler.map; use the new name when it exists.
            style_cells = getattr(styler, "map", None) or styler.applymap
            display(style_cells(style_text_size))
            return
    except Exception:
        # Rich display is best effort: any failure (IPython or a styling
        # dependency not installed, rendering error) falls through to the
        # plain-text table below. `Exception` rather than a bare `except` so
        # that KeyboardInterrupt and SystemExit still propagate.
        pass

    print(
        tabulate(df, headers="keys", tablefmt="fancy_grid"),
        file=file,
    )
325
326
def plot_metric(result: List[float], metric_name: str):
    """
    Draw a bar chart of one metric, save it to a PNG and show it.

    One bar is drawn per entry of ``result`` and annotated with the entry's
    value. Entries that cannot be drawn (None, NaN, +/-inf) get a zero-height
    bar but keep their annotation, and are ignored when the y-axis limits are
    computed. The chart is saved as ``"{metric_name}_output_plot.png"`` in the
    current working directory.

    Args:
        result: Metric value for every output index.
        metric_name: Name of the metric; used for the labels and the file name.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    def _is_finite_number(value) -> bool:
        # None marks an output the metric could not be computed for.
        return value is not None and math.isfinite(value)

    finite_values = [value for value in result if _is_finite_number(value)]
    bar_heights = [value if _is_finite_number(value) else 0.0 for value in result]

    # Clear the current figure, otherwise this plot will be on top of previous plots
    plt.clf()
    plt.figure(figsize=(8, 6))

    x_axis = np.arange(len(result))
    bars = plt.bar(x_axis, bar_heights, width=0.5)
    plt.grid(True, which="major", axis="y")
    num_ticks = len(x_axis) if len(x_axis) > 5 else 5
    interval = 1 if num_ticks < 20 else 5
    plt.xticks(list(range(num_ticks))[::interval])
    plt.xlabel("Output value index")
    plt.ylabel(metric_name)
    plt.title(f"Metric {metric_name}")

    # Add value annotations to each bar
    for bar, value in zip(bars, result):
        plt.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height(),
            str(value),
            ha="center",
            va="bottom",
        )

    # Leave 25% headroom. Only finite values may drive the limits: max()/min()
    # fail on None or an empty list, and axis limits cannot be NaN or inf.
    max_value = max(finite_values, default=0.0) * 1.25
    min_value = min(finite_values, default=0.0) * 1.25

    # Cosine similarity has range [-1, 1], so we set y-axis limits accordingly.
    if metric_name == "cosine_similarity":
        max_value = 1.0
        if min_value >= 0:
            min_value = 0
        else:
            min_value = -1.0

    lower_limit = min(0, min_value)
    upper_limit = max(0, max_value)
    # Identical limits carry no information; keep matplotlib's automatic range.
    if lower_limit < upper_limit:
        plt.ylim(lower_limit, upper_limit)

    plt.savefig(f"{metric_name}_output_plot.png")  # Save the plot to a file
    plt.show()
370
371
def calculate_mse(ref_values: ProgramOutput, values: ProgramOutput):
    """
    Compute the mean squared error between corresponding outputs.

    Args:
        ref_values: Reference outputs.
        values: Outputs to compare against the reference.

    Returns:
        A list with one entry per output pair: the MSE rounded to 2 decimal
        places when both outputs are tensors, None otherwise.
    """

    def _as_float(tensor: torch.Tensor) -> torch.Tensor:
        # Integer and bool tensors must be converted before subtracting:
        # unsigned subtraction wraps around and bool subtraction raises.
        return tensor if tensor.is_floating_point() else tensor.to(torch.float32)

    def mean_squared_error(a: torch.Tensor, b: torch.Tensor):
        difference = (_as_float(a) - _as_float(b)).to(torch.float32)
        return round(torch.pow(difference, 2).mean().item(), 2)

    results = []
    for ref_value, value in zip(ref_values, values):
        # TODO T171811011: extend the implementation of each metrics function to support value types other than tensor type
        if isinstance(ref_value, torch.Tensor) and isinstance(value, torch.Tensor):
            results.append(mean_squared_error(ref_value, value))
        else:
            results.append(None)

    return results
385
386
def calculate_snr(ref_values: ProgramOutput, values: ProgramOutput):
    """
    Compute the signal-to-noise ratio (in dB) between corresponding outputs.

    The reference output is the signal and the difference between reference
    and compared output is the noise. Identical outputs give ``inf``.

    Args:
        ref_values: Reference outputs.
        values: Outputs to compare against the reference.

    Returns:
        A list with one entry per output pair: the SNR rounded to 2 decimal
        places when both outputs are tensors, None otherwise.
    """

    def _as_float(tensor: torch.Tensor) -> torch.Tensor:
        # Integer and bool tensors must be converted before subtracting:
        # unsigned subtraction wraps around and bool subtraction raises.
        return tensor if tensor.is_floating_point() else tensor.to(torch.float32)

    def signal_to_noise(signal: torch.Tensor, noise: torch.Tensor):
        signal = signal.type(torch.float32)
        noise = noise.type(torch.float32)
        signal_power = torch.mean(torch.pow(signal, 2))
        noise_power = torch.mean(torch.pow(noise, 2))
        snr = 10 * torch.log10(signal_power / noise_power)
        return round(snr.item(), 2)

    results = []
    for ref_value, value in zip(ref_values, values):
        # TODO T171811011: extend the implementation of each metrics function to support value types other than tensor type
        if isinstance(ref_value, torch.Tensor) and isinstance(value, torch.Tensor):
            diff = _as_float(ref_value) - _as_float(value)
            snr = signal_to_noise(ref_value, diff)
            results.append(snr)
        else:
            results.append(None)

    return results
407
408
def calculate_cosine_similarity(ref_values: ProgramOutput, values: ProgramOutput):
    """
    Compute the cosine similarity between corresponding outputs.

    Args:
        ref_values: Reference outputs.
        values: Outputs to compare against the reference.

    Returns:
        A list with one entry per output pair: the cosine similarity rounded
        to 2 decimal places when both outputs are tensors, None otherwise.
        The similarity is NaN when either tensor has zero magnitude.

    Raises:
        ValueError: if a pair of tensors does not have the same shape.
    """

    def _as_float(tensor: torch.Tensor) -> torch.Tensor:
        # Products and squares computed in an integer dtype wrap around, and
        # in a 16-bit float dtype they overflow, so do the math in float32
        # unless the input already has at least that precision.
        if tensor.dtype in (torch.float32, torch.float64):
            return tensor
        return tensor.to(torch.float32)

    def cosine_similarity(tensor1: torch.Tensor, tensor2: torch.Tensor):
        # Ensure that the tensors have the same shape
        if tensor1.shape != tensor2.shape:
            raise ValueError("Input tensors must have the same shape")

        tensor1 = _as_float(tensor1)
        tensor2 = _as_float(tensor2)

        # Calculate the dot product
        dot_product = torch.sum(tensor1 * tensor2)

        # Calculate the magnitudes
        magnitude1 = torch.sqrt(torch.sum(torch.pow(tensor1, 2)))
        magnitude2 = torch.sqrt(torch.sum(torch.pow(tensor2, 2)))

        # Calculate the cosine similarity
        similarity = dot_product / (magnitude1 * magnitude2)

        return round(similarity.item(), 2)  # Convert the result to a Python float

    results = []
    for ref_value, value in zip(ref_values, values):
        # TODO T171811011: extend the implementation of each metrics function to support value types other than tensor type
        if isinstance(ref_value, torch.Tensor) and isinstance(value, torch.Tensor):
            results.append(cosine_similarity(ref_value, value))
        else:
            results.append(None)

    return results
436
437
def compare_results(
    reference_output: ProgramOutput,
    run_output: ProgramOutput,
    metrics: Optional[List[str]] = None,
    plot: bool = False,
) -> Dict[str, List[Optional[float]]]:
    """
    Compares the results of two runs and returns a dictionary of metric names -> lists of metric values. This list matches
    the reference output & run output lists, so essentially we compare each pair of values in those two lists.

    Supported metrics are "snr", "mse" and "cosine_similarity"; any other
    requested name is ignored. Each metric is either plotted (``plot=True``)
    or printed to stdout as an index/value table.

    Args:
        reference_output: Reference program output.
        run_output: Program output to compare with reference output.
        metrics: List of requested metric names. Defaults to all available metrics.
        plot: Whether to plot the results.

    Returns:
        Dictionary of metric names to lists of metric values. An entry is None
        when the metric could not be computed for that pair of outputs (the
        metric functions only handle pairs of tensors).
    """

    results = {}
    metrics_functions = {
        "snr": calculate_snr,
        "mse": calculate_mse,
        "cosine_similarity": calculate_cosine_similarity,
    }
    for supported_metric, metric_function in metrics_functions.items():
        if metrics is not None and supported_metric not in metrics:
            continue

        result = metric_function(reference_output, run_output)
        results[supported_metric] = result

        if plot:
            plot_metric(result, supported_metric)
        else:
            print(supported_metric)
            print("-" * 20)
            for index, value in enumerate(result):
                if value is None:
                    # None cannot be rendered with a float format spec.
                    print(f"{index:<5}{'N/A':>8}")
                else:
                    print(f"{index:<5}{value:>8.5f}")
            print("\n")

    return results
479