# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import math
import sys
from enum import Enum
from typing import Dict, IO, List, Mapping, Optional, Tuple, TypeAlias, Union

import executorch.devtools.etdump.schema_flatcc as flatcc

import pandas as pd

import torch

from executorch.devtools.debug_format.base_schema import OperatorNode

from executorch.devtools.debug_format.et_schema import FXOperatorGraph, OperatorGraph
from executorch.devtools.etdump.schema_flatcc import (
    DebugEvent,
    ETDumpFlatCC,
    ProfileEvent,
    ScalarType,
    Tensor,
    Value,
    ValueType,
)

from executorch.devtools.etdump.serialize import deserialize_from_etdump_flatcc
from executorch.devtools.etrecord import ETRecord

from tabulate import tabulate

FORWARD = "forward"
EDGE_DIALECT_GRAPH_KEY = "edge_dialect_graph_module"

RESERVED_FRAMEWORK_EVENT_NAMES = [
    "Method::init",
    "Program::load_method",
    "Method::execute",
]
EXCLUDED_COLUMNS_WHEN_PRINTING = [
    "raw",
    "delegate_debug_identifier",
    "stack_traces",
    "module_hierarchy",
    "debug_data",
]
EXCLUDED_EVENTS_WHEN_PRINTING = {"OPERATOR_CALL"}


class TimeScale(Enum):
    NS = "ns"
    US = "us"
    MS = "ms"
    S = "s"
    CYCLES = "cycles"


# Number of units of each time scale per second. CYCLES has no fixed relation
# to wall-clock time; it is given factor 1, so converting between CYCLES and
# S is an identity.
TIME_SCALE_DICT = {
    TimeScale.NS: 1000000000,
    TimeScale.US: 1000000,
    TimeScale.MS: 1000,
    TimeScale.S: 1,
    TimeScale.CYCLES: 1,
}


def calculate_time_scale_factor(
    source_time_scale: TimeScale, target_time_scale: TimeScale
) -> float:
    """
    Calculate the factor (source divided by target) between two time scales.

    Dividing a value expressed in ``source_time_scale`` by this factor yields
    the value in ``target_time_scale`` (e.g. NS -> MS gives 1e6).
    """
    return TIME_SCALE_DICT[source_time_scale] / TIME_SCALE_DICT[target_time_scale]


# Model Debug Output
InferenceOutput: TypeAlias = Union[
    torch.Tensor, List[torch.Tensor], int, float, str, bool, None
]
ProgramOutput: TypeAlias = List[InferenceOutput]


def is_inference_output_equal(
    output1: InferenceOutput, output2: InferenceOutput
) -> bool:
    """
    Compare whether two InferenceOutputs are equal.

    Tensors are compared with ``torch.equal`` (same shape and values). Lists of
    tensors must have the same length and be pairwise ``torch.equal``. Any
    other pair of values is compared with ``==``.
    """
    if isinstance(output1, torch.Tensor) and isinstance(output2, torch.Tensor):
        return torch.equal(output1, output2)
    if isinstance(output1, list) and isinstance(output2, list):
        # zip() stops at the shorter list, so the lengths must be checked
        # explicitly; otherwise [a] and [a, b] would be reported as equal.
        return len(output1) == len(output2) and all(
            torch.equal(t1, t2) for t1, t2 in zip(output1, output2)
        )
    return bool(output1 == output2)


def _parse_tensor_value(
    tensor: Optional[Tensor], output_buffer: Optional[bytes]
) -> torch.Tensor:
    """
    Given an ETDump Tensor object and the debug output buffer, extract the
    tensor's data into a torch.Tensor.

    The data is read from ``output_buffer`` starting at ``tensor.offset`` and
    reinterpreted with the dtype of ``tensor.scalar_type`` and shape
    ``tensor.sizes``.

    Args:
        tensor: ETDump tensor record (scalar type, sizes and offset).
        output_buffer: Raw buffer holding tensor data, or ``None``.

    Returns:
        The decoded tensor. When ``output_buffer`` is ``None`` or the tensor
        holds zero bytes, a zero-filled tensor of the recorded shape and dtype.

    Raises:
        ValueError: If ``tensor`` or its offset is ``None``, or if the tensor's
            byte range extends past the end of ``output_buffer``.
        RuntimeError: If the tensor's scalar type is not supported.
    """

    def get_scalar_type_size(scalar_type: ScalarType) -> Tuple[torch.dtype, int]:
        """
        Return the torch dtype and the element size in bytes of the scalar type.
        """
        get_scalar_type_size_map = {
            ScalarType.BYTE: (torch.uint8, 1),
            ScalarType.CHAR: (torch.int8, 1),
            ScalarType.BOOL: (torch.bool, 1),
            ScalarType.BITS16: (torch.uint16, 2),
            ScalarType.SHORT: (torch.int16, 2),
            ScalarType.HALF: (torch.float16, 2),
            ScalarType.INT: (torch.int, 4),
            ScalarType.FLOAT: (torch.float, 4),
            ScalarType.DOUBLE: (torch.double, 8),
            ScalarType.LONG: (torch.long, 8),
        }
        if scalar_type in get_scalar_type_size_map:
            return get_scalar_type_size_map[scalar_type]
        else:
            raise RuntimeError(
                f"Unsupported scalar type in get_scalar_type_size : {scalar_type}"
            )

    if tensor is None:
        raise ValueError("Tensor cannot be None")
    if tensor.offset is None:
        raise ValueError("Tensor offset cannot be None")

    torch_dtype, dtype_size = get_scalar_type_size(tensor.scalar_type)

    if output_buffer is None:
        # Empty buffer provided. Cannot deserialize tensors.
        return torch.zeros(tensor.sizes, dtype=torch_dtype)

    tensor_bytes_size = math.prod(tensor.sizes) * dtype_size
    if tensor_bytes_size == 0:
        # Empty tensor. Return empty tensor.
        return torch.zeros(tensor.sizes, dtype=torch_dtype)

    start = tensor.offset
    end = start + tensor_bytes_size
    if end > len(output_buffer):
        raise ValueError(
            f"Tensor data range [{start}, {end}) exceeds the output buffer "
            f"size of {len(output_buffer)} bytes"
        )

    # Copy the slice into a writable bytearray: torch.frombuffer warns about
    # (and would alias) read-only buffers such as `bytes`.
    tensor_data = bytearray(memoryview(output_buffer)[start:end])
    return torch.frombuffer(tensor_data, dtype=torch_dtype).view(tensor.sizes)


def inflate_runtime_output(
    value: Value, output_buffer: Optional[bytes]
) -> InferenceOutput:
    """
    Parse the given ETDump Value object into an InferenceOutput object.

    Handles INT, BOOL, FLOAT, DOUBLE, TENSOR and TENSOR_LIST values; tensors
    are decoded from ``output_buffer``. Any other value type is not inflated
    and ``None`` is returned.

    Raises:
        ValueError: If the payload matching ``value.val`` is missing.
    """

    if value.val == ValueType.INT.value:
        if value.int_value is None:
            raise ValueError("Expected Int value, `None` provided")
        return value.int_value.int_val
    if value.val == ValueType.BOOL.value:
        if value.bool_value is None:
            raise ValueError("Expected Bool value, `None` provided")
        return value.bool_value.bool_val
    if value.val == ValueType.FLOAT.value:
        if value.float_value is None:
            raise ValueError("Expected Float value, `None` provided")
        return value.float_value.float_val
    if value.val == ValueType.DOUBLE.value:
        if value.double_value is None:
            raise ValueError("Expected Double value, `None` provided")
        return value.double_value.double_val
    if value.val == ValueType.TENSOR.value:
        return _parse_tensor_value(value.tensor, output_buffer)
    if value.val == ValueType.TENSOR_LIST.value:
        if value.tensor_list is None:
            raise ValueError("Expected TensorList value, `None` provided")
        return [
            _parse_tensor_value(t, output_buffer) for t in value.tensor_list.tensors
        ]
    # Unhandled value types are not inflated.
    return None


def find_populated_event(event: flatcc.Event) -> Union[ProfileEvent, DebugEvent]:
    """
    Given a ETDump Event object, find the populated event.

    The profile event takes precedence over the debug event when both are set.

    Raise an error if no populated event can be found.
    """
    if event.profile_event is not None:
        return event.profile_event

    if event.debug_event is not None:
        return event.debug_event

    raise ValueError("Unable to find populated event")
def _tensor_debug_data_equal(tensor_a: torch.Tensor, tensor_b: torch.Tensor) -> bool:
    """Return True if both tensors have the same shape and elementwise-equal values."""
    # Compare shapes explicitly so broadcasting cannot make, e.g., a [1]-tensor
    # "equal" to a [3]-tensor holding the same value three times.
    return tensor_a.shape == tensor_b.shape and bool(
        # pyre-fixme[6]: For 1st argument expected `Tensor` but got `bool`.
        torch.all(tensor_a == tensor_b)
    )


# TODO: Optimize by verifying prior to inflating the tensors
def verify_debug_data_equivalence(
    existing_data: ProgramOutput, new_data: ProgramOutput
) -> None:
    """
    Verify that the lists of inference_outputs are equivalent.

    The checks use explicit ``raise`` statements instead of ``assert`` so that
    they still run when Python is started with ``-O``.

    Args:
        existing_data: Previously recorded debug outputs.
        new_data: Newly recorded debug outputs to compare against.

    Raises:
        AssertionError: If the lists differ in length, a pair of entries
            differs in type, or a pair of tensors / tensor lists / scalars
            differs in value.
    """
    if len(existing_data) != len(new_data):
        raise AssertionError(
            "Unequal debug data length encountered. Expected to be equal."
        )

    for output_a, output_b in zip(existing_data, new_data):
        if not isinstance(output_a, type(output_b)):
            raise AssertionError(
                "Debug Data Types are different. Expected to be equal."
            )

        if isinstance(output_a, torch.Tensor):
            if not _tensor_debug_data_equal(output_a, output_b):
                raise AssertionError(
                    "Tensors Debug Data is different. Expected to be equal."
                )
        elif isinstance(output_a, list):
            # `==` on two lists of multi-element tensors raises "Boolean value
            # of Tensor ... is ambiguous", so compare the items one by one.
            lists_equal = len(output_a) == len(output_b) and all(
                (
                    _tensor_debug_data_equal(item_a, item_b)
                    if isinstance(item_a, torch.Tensor)
                    and isinstance(item_b, torch.Tensor)
                    else item_a == item_b
                )
                for item_a, item_b in zip(output_a, output_b)
            )
            if not lists_equal:
                raise AssertionError(
                    "Tensors Debug Data is different. Expected to be equal."
                )
        elif not (output_a == output_b):
            raise AssertionError(
                "Scalar Debug Data is different. Expected to be equal"
            )


def is_debug_output(value: Value) -> bool:
    """
    Returns True if the given flatcc.Value is a debug output, i.e. its
    ``output`` field is set and carries a true ``bool_val``.
    """
    return value.output is not None and value.output.bool_val


def gen_graphs_from_etrecord(
    etrecord: ETRecord, enable_module_hierarchy: bool = False
) -> Mapping[str, OperatorGraph]:
    """
    Build operator graphs for the programs recorded in an ETRecord.

    Every entry of ``etrecord.graph_map`` is converted with
    ``FXOperatorGraph.gen_operator_graph`` and stored under its own name. The
    edge dialect program, when present, is added under
    ``EDGE_DIALECT_GRAPH_KEY``.

    Args:
        etrecord: Record holding the exported programs.
        enable_module_hierarchy: Forwarded to ``gen_operator_graph``.

    Returns:
        Mapping from graph name to operator graph (empty if nothing is recorded).
    """
    op_graph_map = {}
    if etrecord.graph_map is not None:
        op_graph_map = {
            name: FXOperatorGraph.gen_operator_graph(
                exported_program.graph_module,
                enable_module_hierarchy=enable_module_hierarchy,
            )
            for name, exported_program in etrecord.graph_map.items()
        }
    if etrecord.edge_dialect_program is not None:
        op_graph_map[EDGE_DIALECT_GRAPH_KEY] = FXOperatorGraph.gen_operator_graph(
            etrecord.edge_dialect_program.graph_module,
            enable_module_hierarchy=enable_module_hierarchy,
        )

    return op_graph_map


def create_debug_handle_to_op_node_mapping(
    op_graph: OperatorGraph,
) -> Dict[int, OperatorNode]:
    """
    Traverse all operator nodes of ``op_graph`` (recursing into nested
    subgraphs) and map each debug handle to the operator node whose metadata
    contains it.

    Raises:
        ValueError: If two operator nodes carry the same debug handle.
    """
    debug_handle_to_op_node_map: Dict[int, OperatorNode] = {}

    # Recursively searches through the metadata of nodes
    def _extract_debug_handles(graph: OperatorGraph):
        for element in graph.elements:
            if isinstance(element, OperatorGraph):
                _extract_debug_handles(element)
            elif isinstance(element, OperatorNode) and element.metadata is not None:
                metadata = element.metadata
                debug_handle = metadata.get("debug_handle")
                if debug_handle is not None:
                    existing_entry = debug_handle_to_op_node_map.get(debug_handle)
                    if existing_entry is not None:
                        raise ValueError(
                            f"Duplicated debug handle {str(debug_handle)} shared between {element.name} and {existing_entry.name}. "
                            "No two op nodes of the same graph should have the same debug handle."
                        )
                    debug_handle_to_op_node_map[debug_handle] = element

    # Start traversing
    _extract_debug_handles(op_graph)
    return debug_handle_to_op_node_map


def gen_etdump_object(
    etdump_path: Optional[str] = None, etdump_data: Optional[bytes] = None
) -> ETDumpFlatCC:
    """
    Deserialize an ETDump from raw bytes or from a file.

    ``etdump_data`` is used when given. Otherwise the file at ``etdump_path`` is
    read in binary mode.

    NOTE(review): supplying both arguments is not rejected; ``etdump_path`` is
    then silently ignored despite what the error message says.

    Raises:
        ValueError: If neither ``etdump_path`` nor ``etdump_data`` is provided.
    """
    # Gen event blocks from etdump
    if etdump_data is None and etdump_path is not None:
        with open(etdump_path, "rb") as buff:
            etdump_data = buff.read()

    if etdump_data is None:
        raise ValueError(
            "Unable to get ETDump data. One and only one of etdump_path and etdump_data must be specified."
        )

    return deserialize_from_etdump_flatcc(etdump_data)


def display_or_print_df(df: pd.DataFrame, file: Optional[IO[str]] = None):
    """
    Show ``df`` as a styled table in an IPython session, else print it as text.

    Args:
        df: The DataFrame to show.
        file: Stream for the text fallback. ``None`` (the default) means the
            *current* ``sys.stdout``, resolved at call time, so redirections of
            stdout made after import are honoured.
    """

    def style_text_size(val, size=12):
        return f"font-size: {size}px"

    try:
        from IPython import get_ipython
        from IPython.display import display

        if get_ipython() is None:
            raise RuntimeError(
                "Environment unable to support IPython. Fall back to print()."
            )
        styler = df.style
        # pandas 2.1 renamed Styler.applymap to Styler.map (applymap is
        # deprecated and later removed); prefer the new name when present.
        apply_elementwise = getattr(styler, "map", None) or styler.applymap
        display(apply_elementwise(style_text_size))
        return
    except Exception:
        # Best effort: any failure (IPython missing, not in a notebook, styling
        # errors) falls back to plain-text output below.
        pass

    # print(file=None) writes to the current sys.stdout.
    print(
        tabulate(df, headers="keys", tablefmt="fancy_grid"),
        file=file,
    )


def plot_metric(result: List[Optional[float]], metric_name: str):
    """
    Draw a bar chart of per-output metric values, save it to
    ``{metric_name}_output_plot.png`` in the working directory and show it.

    ``None`` entries (outputs for which the metric could not be computed) and
    non-finite values (e.g. an infinite SNR for identical outputs) get no bar
    and are annotated at zero height. They are ignored when choosing the
    y-axis limits, since matplotlib cannot use infinite or NaN limits.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    # Clear the current figure, otherwise this plot will be on top of previous plots
    plt.clf()
    plt.figure(figsize=(8, 6))

    def _is_plottable(value) -> bool:
        return value is not None and math.isfinite(value)

    x_axis = np.arange(len(result))
    # matplotlib cannot draw None/inf heights; NaN leaves the slot empty.
    heights = [value if _is_plottable(value) else math.nan for value in result]
    bars = plt.bar(x_axis, heights, width=0.5)
    plt.grid(True, which="major", axis="y")
    num_ticks = len(x_axis) if len(x_axis) > 5 else 5
    interval = 1 if num_ticks < 20 else 5
    plt.xticks(list(range(num_ticks))[::interval])
    plt.xlabel("Output value index")
    plt.ylabel(metric_name)
    plt.title(f"Metric {metric_name}")

    # Add value annotations to each bar
    for bar, value in zip(bars, result):
        plt.text(
            bar.get_x() + bar.get_width() / 2,
            value if _is_plottable(value) else 0,
            "N/A" if value is None else str(value),
            ha="center",
            va="bottom",
        )

    plottable_values = [value for value in result if _is_plottable(value)]
    max_value = max(plottable_values, default=0.0) * 1.25
    min_value = min(plottable_values, default=0.0) * 1.25

    # Cosine similarity has range [-1, 1], so we set y-axis limits accordingly.
    if metric_name == "cosine_similarity":
        max_value = 1.0
        if min_value >= 0:
            min_value = 0
        else:
            min_value = -1.0

    plt.ylim(min(0, min_value), max(0, max_value))

    plt.savefig(f"{metric_name}_output_plot.png")  # Save the plot to a file
    plt.show()


def _to_floating(tensor: torch.Tensor) -> torch.Tensor:
    """Return ``tensor`` unchanged if floating point, else cast to float32."""
    # Integer/bool tensors must be promoted *before* any arithmetic: uint8
    # subtraction wraps around (0 - 1 == 255), small integer products
    # overflow, and bool subtraction is not supported at all.
    return tensor if tensor.is_floating_point() else tensor.to(torch.float32)


def calculate_mse(
    ref_values: ProgramOutput, values: ProgramOutput
) -> List[Optional[float]]:
    """
    Mean squared error for each pair of tensor outputs, rounded to 2 decimals.

    Pairs where either side is not a tensor yield ``None``.
    """

    def mean_squared_error(a: torch.Tensor, b: torch.Tensor):
        diff = _to_floating(a) - _to_floating(b)
        return round((torch.pow(diff.to(torch.float32), 2)).mean().item(), 2)

    results = []
    for ref_value, value in zip(ref_values, values):
        # TODO T171811011: extend the implementation of each metrics function to support value types other than tensor type
        if isinstance(ref_value, torch.Tensor) and isinstance(value, torch.Tensor):
            results.append(mean_squared_error(ref_value, value))
        else:
            results.append(None)

    return results


def calculate_snr(
    ref_values: ProgramOutput, values: ProgramOutput
) -> List[Optional[float]]:
    """
    Signal-to-noise ratio in dB (``10 * log10(P_signal / P_noise)``) for each
    pair of tensor outputs, where the noise is ``ref_value - value``. The result
    is rounded to 2 decimals.

    Identical tensors have zero noise power and therefore give ``inf``. Pairs
    where either side is not a tensor yield ``None``.
    """

    def signal_to_noise(signal: torch.Tensor, noise: torch.Tensor):
        signal = signal.type(torch.float32)
        noise = noise.type(torch.float32)
        signal_power = torch.mean(torch.pow(signal, 2))
        noise_power = torch.mean(torch.pow(noise, 2))
        snr = 10 * torch.log10(signal_power / noise_power)
        return round(snr.item(), 2)

    results = []
    for ref_value, value in zip(ref_values, values):
        # TODO T171811011: extend the implementation of each metrics function to support value types other than tensor type
        if isinstance(ref_value, torch.Tensor) and isinstance(value, torch.Tensor):
            diff = _to_floating(ref_value) - _to_floating(value)
            snr = signal_to_noise(ref_value, diff)
            results.append(snr)
        else:
            results.append(None)

    return results


def calculate_cosine_similarity(
    ref_values: ProgramOutput, values: ProgramOutput
) -> List[Optional[float]]:
    """
    Cosine similarity of each pair of tensor outputs (both flattened),
    rounded to 2 decimals.

    A zero-magnitude tensor makes the result NaN. Pairs where either side is
    not a tensor yield ``None``.

    Raises:
        ValueError: If a pair of tensors has different shapes.
    """

    def cosine_similarity(tensor1: torch.Tensor, tensor2: torch.Tensor):
        # Ensure that the tensors have the same shape
        if tensor1.shape != tensor2.shape:
            raise ValueError("Input tensors must have the same shape")

        tensor1 = _to_floating(tensor1)
        tensor2 = _to_floating(tensor2)

        # Calculate the dot product
        dot_product = torch.sum(tensor1 * tensor2)

        # Calculate the magnitudes
        magnitude1 = torch.sqrt(torch.sum(torch.pow(tensor1, 2)))
        magnitude2 = torch.sqrt(torch.sum(torch.pow(tensor2, 2)))

        # Calculate the cosine similarity
        similarity = dot_product / (magnitude1 * magnitude2)

        return round(similarity.item(), 2)  # Convert the result to a Python float

    results = []
    for ref_value, value in zip(ref_values, values):
        # TODO T171811011: extend the implementation of each metrics function to support value types other than tensor type
        if isinstance(ref_value, torch.Tensor) and isinstance(value, torch.Tensor):
            results.append(cosine_similarity(ref_value, value))
        else:
            results.append(None)

    return results


def compare_results(
    reference_output: ProgramOutput,
    run_output: ProgramOutput,
    metrics: Optional[List[str]] = None,
    plot: bool = False,
) -> Dict[str, List[Optional[float]]]:
    """
    Compares the results of two runs and returns a dictionary of metric names -> lists of metric values. This list matches
    the reference output & run output lists, so essentially we compare each pair of values in those two lists.

    Args:
        reference_output: Reference program output.
        run_output: Program output to compare with reference output.
        metrics: List of requested metric names ("snr", "mse",
            "cosine_similarity"). Defaults to all available metrics; unknown
            names are ignored.
        plot: Whether to plot the results. If False, the values are printed.

    Returns:
        Dictionary of metric names to lists of values. An entry is ``None``
        where the pair of outputs is not a pair of tensors.
    """

    results = {}
    metrics_functions = {
        "snr": calculate_snr,
        "mse": calculate_mse,
        "cosine_similarity": calculate_cosine_similarity,
    }
    for supported_metric in metrics_functions:
        if metrics is None or supported_metric in metrics:
            result = metrics_functions[supported_metric](reference_output, run_output)
            results[supported_metric] = result

            if plot:
                plot_metric(result, supported_metric)
            else:
                print(supported_metric)
                print("-" * 20)
                for index, value in enumerate(result):
                    if value is None:
                        # Metric not computable for this pair (non-tensor output).
                        print(f"{index:<5}{'N/A':>8}")
                    else:
                        print(f"{index:<5}{value:>8.5f}")
                print("\n")

    return results