# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

"""
Intermediate Representation of ExecuTorch Concepts in Developer Tools
"""

from __future__ import annotations

import operator
import warnings

from collections import defaultdict

from enum import Enum
from types import NoneType
from typing import Any, Dict, List, Optional, Set, Tuple

import torch
from executorch import exir
from executorch.devtools.debug_format.base_schema import (
    Node,
    OperatorGraph,
    OperatorNode,
    ValueNode,
)
from torch._subclasses import FakeTensor


# Keywords used in debug_format Metadata
class RESERVED_METADATA_ARG(Enum):
    DEBUG_HANDLE = "debug_handle"
    MODULE_STACK = "nn_module_stack"
    SOURCE_FN_STACK = "source_fn_stack"
    MODULE_TYPE = "module_type"
    PROFILE_START_TIME = "profile_start_time"
    PROFILE_END_TIME = "profile_end_time"
    LOAD_START_TIME = "load_start_time"
    LOAD_END_TIME = "load_end_time"
    MEMORY_USAGE = "memory_usage"
    DEBUG_ENTRY = "debug_entry"
    STACK_TRACE = "stack_trace"
    DEBUG_DATA = "debug_data"

    METRICS_KEYWORD = "metrics"
    PROFILE_SUMMARY_COLDSTART = "Coldstart"
    PROFILE_SUMMARY_AVERAGE = "Average"
    PROFILE_SUMMARY_P90 = "P90"
    PROFILE_SUMMARY_P10 = "P10"
    PROFILE_SUMMARY_MIN = "Min"
    PROFILE_SUMMARY_MAX = "Max"

    AGGREGATED_OP_TABLE = "Aggregated op stats"
    RUN_SUMMARY_INDIVIDUAL_RUNS_TABLE = "Run summary individual stats"
    RUN_SUMMARY_TABLE = "Aggregated run summary stats"
    OP_INSTANCE_SUMMARY_TABLE = "Individual op stats"

    TABLES_KEYWORD = "tables"
    KV_KEYWORD = "kv"
    MODEL_LOAD_TIME_KEY = "Model load time (ms)"


# Representation of an FX GraphModule as an OperatorGraph
class FXOperatorGraph(OperatorGraph):
    @staticmethod
    def _get_node_name(node: torch.fx.Node) -> str:
        if node.target == operator.getitem:
            # pyre-ignore[9]: Incompatible variable type
            node = node.args[0]
            assert isinstance(
                node, torch.fx.Node
            ), f"First argument of getitem must be a torch fx node. Got {node.args[0]}"

        # Adding the "_" to the node name prevents TensorBoard from collapsing
        # nodes with similar names that only differ by an integer at the end of
        # their name.
        return node.name + "_"

    @staticmethod
    def _get_op_name(node: torch.fx.Node) -> str:
        # pyre-ignore[16]: has no attribute `__name__`.
        return node.target.__name__

    # Given a node and its metadata (if it contains a module stack), update the provided module mappings
    @staticmethod
    def _update_module_mapping(
        node: Node,
        module_mapping: Dict[Tuple[str, str], List[Node]],
        metadata: Dict[str, Any],
    ):
        if (
            source_fn_stack := metadata.get("source_fn_stack")
        ) is not None and "nn_module_stack" in metadata:
            # (module name, module type)
            source_fn = source_fn_stack[-1]
            module_type = (
                source_fn[1] if isinstance(source_fn[1], str) else source_fn[1].__name__
            )
            module_mapping[(source_fn[0], module_type)].append(node)

    @staticmethod
    def _parse_args(  # noqa: C901
        node: torch.fx.Node,
        nodes: Dict[str, Node],
        const_count: int,
        module_mapping: Dict[Tuple[str, str], List[Node]],
        enable_module_hierarchy: bool,
    ) -> Tuple[List[Node], int]:
        inputs = []
        op = node.op
        name = node.name
        args = node.args
        kwargs = node.kwargs
        named_args = None
        if node.op == "call_function" and hasattr(node.target, "_schema"):
            # pyre-ignore
            named_args = node.target._schema.arguments

        for index, arg in enumerate(args):
            if isinstance(arg, torch.fx.node.Node):
                if arg.target == exir.memory.alloc:
                    continue
                arg_name = FXOperatorGraph._get_node_name(arg)
            elif isinstance(arg, (int, float, torch.dtype, str)):
                # e.g. The "0" from node.args of squeeze_copy (mm_default, 0)
                if named_args and len(named_args) > index:
                    arg_name = named_args[index].name + "_" + str(const_count)
                else:
                    arg_name = "CONST_" + str(const_count)
                const_count += 1
                const_node = ValueNode(arg_name, val=str(arg))
                nodes[arg_name] = const_node
                if enable_module_hierarchy:
                    FXOperatorGraph._update_module_mapping(
                        const_node, module_mapping, node.meta
                    )
            elif isinstance(arg, list):
                arg_name: List[str] = []
                for list_arg in arg:
                    if isinstance(list_arg, (int, float)):
                        # Consider the whole list of ints/floats as a single constant and
                        # stringify that.
                        if named_args and len(named_args) > index:
                            arg_name = [named_args[index].name + "_" + str(const_count)]
                        else:
                            arg_name = ["CONST_" + str(const_count)]
                        const_count += 1
                        const_node = ValueNode(arg_name[0], val=arg)
                        nodes[arg_name[0]] = const_node
                        if enable_module_hierarchy:
                            FXOperatorGraph._update_module_mapping(
                                const_node, module_mapping, node.meta
                            )
                        break
                    elif isinstance(list_arg, torch.fx.node.Node):
                        arg_name += [FXOperatorGraph._get_node_name(list_arg)]
                    elif list_arg is None:
                        arg_name += ["CONST_NONE" + str(const_count)]
                        const_count += 1
                        const_node = ValueNode(arg_name[-1], val=str(arg))
                        nodes[arg_name[-1]] = const_node
                        if enable_module_hierarchy:
                            FXOperatorGraph._update_module_mapping(
                                const_node, module_mapping, node.meta
                            )
                    else:
                        raise Exception(
                            f"Unsupported argument encountered in list {arg}, {type(list_arg)}"
                        )
            elif isinstance(arg, NoneType):
                continue
            else:
                raise Exception(
                    f"Unsupported argument encountered {op}, {name}, {arg}, {type(arg)}"
                )

            if isinstance(arg_name, list):
                for val in arg_name:
                    inputs.append(nodes[val])
            else:
                inputs.append(nodes[arg_name])
        for _, node in kwargs.items():
            # We can ignore the out kwarg as that's mostly used to pass in the output tensor
            # which has been memory planned. The same value is also returned by the operator
            # which is then consumed by other nodes in the graph.
            if (
                isinstance(node, torch.fx.node.Node)
                and node.target == exir.memory.alloc
            ):
                continue
            else:
                warnings.warn(
                    f"Unsupported kwarg encountered: {name}, {kwargs}", stacklevel=1
                )

        return inputs, const_count

    # Given an FX GraphModule, parse it into an OperatorGraph
    @staticmethod
    def gen_operator_graph(
        model: torch.fx.GraphModule,
        skip_stack_trace: Optional[bool] = False,
        enable_module_hierarchy: bool = False,
    ) -> FXOperatorGraph:
        graph: torch.fx.Graph = model.graph

        nodes = {}
        input_nodes = {}
        output_nodes = {}
        out_variant_output_nodes = set()
        module_mapping = defaultdict(list)

        const_count = 0
        for fx_node in graph.nodes:
            if (
                fx_node.target == exir.memory.alloc
                or fx_node.target == operator.getitem
            ):
                continue
            op = fx_node.op
            name = FXOperatorGraph._get_node_name(fx_node)
            target = fx_node.target
            args = fx_node.args
            kwargs = fx_node.kwargs
            metadata = FXOperatorGraph._extract_metadata(fx_node.meta, skip_stack_trace)
            output_shapes = FXOperatorGraph._extract_output_shapes(
                fx_node.meta.get("val")
            )
            dtype = FXOperatorGraph._extract_output_dtype(fx_node.meta.get("val")) or ""

            assert (
                op != "call_module"
            ), f"Module call not yet supported in edge model graph [.toEdge()]: {name}, {str(target)}"
            assert (
                op != "call_method"
            ), f"Call Method not yet supported in edge model graph [.toEdge()]: {name}, {str(target)}"

            # Input
            if op == "placeholder":
                node = ValueNode(
                    name,
                    output_shapes=output_shapes,
                    metadata=metadata,
                    dtype=str(dtype),
                )  # val is default arg
                input_nodes[name] = node
            # Constants
            elif op == "get_attr":
                node = ValueNode(
                    name,
                    output_shapes=output_shapes,
                    metadata=metadata,
                    dtype=str(dtype),
                )
            # Output
            elif op == "output":
                assert len(args) == 1
                # The args of an op=="output" node are a wrapped list of return nodes ([ret_1, ret_2, ...], )
                in_nodes = [
                    (
                        nodes[FXOperatorGraph._get_node_name(ret)]
                        if ret is not None
                        else []
                    )
                    for ret in args[0]
                ]
                node = ValueNode(
                    name,
                    inputs=in_nodes,
                    output_shapes=output_shapes,
                    metadata=metadata,
                    dtype=str(dtype),
                )
                output_nodes[name] = node
            # Op Calls
            elif op == "call_function":
                inputs, const_count = FXOperatorGraph._parse_args(
                    fx_node, nodes, const_count, module_mapping, enable_module_hierarchy
                )
                named_args = []
                if fx_node.op == "call_function" and hasattr(fx_node.target, "_schema"):
                    named_args = [arg.name for arg in fx_node.target._schema.arguments]
                node = OperatorNode(
                    name,
                    inputs=inputs,
                    output_shapes=output_shapes,
                    metadata=metadata,
                    op=FXOperatorGraph._get_op_name(fx_node),
                    named_args=named_args,
                )
                if enable_module_hierarchy:
                    FXOperatorGraph._update_module_mapping(
                        node, module_mapping, fx_node.meta
                    )

                for kwarg_name, kwarg in kwargs.items():
                    if (
                        isinstance(kwarg, torch.fx.node.Node)
                        and kwarg.target == exir.memory.alloc
                        and kwarg_name == "out"
                    ):
                        nodes[FXOperatorGraph._get_node_name(kwarg)] = node
                        out_variant_output_nodes.add(
                            FXOperatorGraph._get_node_name(kwarg)
                        )
            else:
                raise Exception(f"Unsupported op type encountered {op}, {name}")

            nodes[name] = node
        return FXOperatorGraph._compose_op_graph(
            "base",
            nodes,
            input_nodes,
            output_nodes,
            out_variant_output_nodes,
            module_mapping,
        )

    @staticmethod
    def _compose_op_graph(
        name: str,
        nodes: Dict[str, Node],
        input_nodes: Dict[
            str, Node | OperatorGraph
        ],  # Never OperatorGraph, annotated for Pyre
        output_nodes: Dict[
            str, Node | OperatorGraph
        ],  # Never OperatorGraph, annotated for Pyre
        out_variant_output_nodes: Set[str],
        module_mapping: Dict[
            Tuple[str, str], List[Any]
        ],  # Any used here for Pyre, list of Nodes
    ):
        # Generate Module Graphs
        module_graphs: List[OperatorGraph] = []
        for module_key, module_nodes in module_mapping.items():
            module_element = OperatorGraph(
                graph_name=module_key[0],
                elements=module_nodes,
                metadata={"module_type": module_key[1]},
            )
            module_graphs.append(module_element)

            # Remove module nodes from the main graph
            for node in module_nodes:
                nodes.pop(node.name)

        main_nodes = [
            node
            for name, node in nodes.items()
            if name not in input_nodes
            and name not in output_nodes
            and name not in out_variant_output_nodes
        ]
        main_graph = FXOperatorGraph(
            graph_name="forward", elements=main_nodes + module_graphs
        )
        input_graph = FXOperatorGraph(
            graph_name="inputs", elements=list(input_nodes.values())
        )
        output_graph = FXOperatorGraph(
            graph_name="outputs", elements=list(output_nodes.values())
        )

        return FXOperatorGraph(
            graph_name=name,
            elements=[input_graph, main_graph, output_graph],
        )

    # Given a dict, extract only the utilized metadata
    @staticmethod
    def _extract_metadata(
        metadata: Dict[str, Any], skip_stack_trace: Optional[bool] = False
    ) -> Dict[str, Any]:
        ret = {}
        if RESERVED_METADATA_ARG.DEBUG_HANDLE.value in metadata:
            ret[RESERVED_METADATA_ARG.DEBUG_HANDLE.value] = metadata[
                RESERVED_METADATA_ARG.DEBUG_HANDLE.value
            ]
        if not skip_stack_trace and RESERVED_METADATA_ARG.STACK_TRACE.value in metadata:
            ret[RESERVED_METADATA_ARG.STACK_TRACE.value] = metadata[
                RESERVED_METADATA_ARG.STACK_TRACE.value
            ]
        if RESERVED_METADATA_ARG.MODULE_STACK.value in metadata:
            ret[RESERVED_METADATA_ARG.MODULE_STACK.value] = metadata[
                RESERVED_METADATA_ARG.MODULE_STACK.value
            ]
        return ret

    @staticmethod
    def _extract_output_shapes(val: Any) -> Optional[List[List[int]]]:
        if isinstance(val, (FakeTensor, torch.Tensor)):
            # If val is a single tensor
            return [list(val.shape)]
        elif isinstance(val, tuple) and all(
            isinstance(tensor, (FakeTensor, torch.Tensor)) for tensor in val
        ):
            # If val is a tuple of tensors
            shapes = [list(fake_tensor.shape) for fake_tensor in val]
            return shapes
        else:
            return None

    @staticmethod
    def _extract_output_dtype(val: Any) -> Optional[List[torch.dtype]]:
        if isinstance(val, (FakeTensor, torch.Tensor)):
            # If val is a single tensor
            return [val.dtype]
        elif isinstance(val, tuple) and all(
            isinstance(tensor, (FakeTensor, torch.Tensor)) for tensor in val
        ):
            # If val is a tuple of tensors
            dtypes = [fake_tensor.dtype for fake_tensor in val]
            return dtypes
        else:
            return None
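

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library API above): shows
# one way an edge-dialect GraphModule can be turned into an FXOperatorGraph.
# The toy module, its input shape, and the export pipeline below are
# assumptions for demonstration; only FXOperatorGraph.gen_operator_graph is
# defined in this file.
if __name__ == "__main__":

    class _ToyModel(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(4, 2)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    # Export to ATen, lower to edge dialect, then parse the edge GraphModule.
    exported = torch.export.export(_ToyModel(), (torch.randn(1, 4),))
    edge_gm = exir.to_edge(exported).exported_program().graph_module
    op_graph = FXOperatorGraph.gen_operator_graph(
        edge_gm, enable_module_hierarchy=True
    )
    print(op_graph)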