# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import _operator
import copy
import json
import logging
import os
import re
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import torch
from executorch.exir import control_flow, memory, memory_planning
from executorch.exir.common import override_logger
from executorch.exir.delegate import executorch_call_delegate
from executorch.exir.dialects.backend._ops import BackendOpOverload
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.dynamic_shape import DynamicMemoryPlanningMode
from executorch.exir.error import InternalError
from executorch.exir.operator.convert import (
    get_out_args_from_opoverload,
    is_out_variant,
    to_out_variant,
    to_scratch_op,
)

from executorch.exir.pass_base import ExportPass
from executorch.exir.pass_manager import PassManager, PassType
from executorch.exir.passes.const_prop_pass import ConstPropPass
from executorch.exir.passes.debug_handle_generator_pass import DebugHandleGeneratorPass

from executorch.exir.passes.executorch_prim_ops_registry import _EXECUTORCH_SYM_OPS
from executorch.exir.passes.insert_write_back_for_buffers_pass import (
    insert_write_back_for_buffers_pass,
)
from executorch.exir.passes.memory_format_ops_pass import MemoryFormatOpsPass
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
from executorch.exir.passes.normalize_transpose_pass import NormalizeTransposePass
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.remove_noop_pass import RemoveNoopPass, RemoveToCopyPass
from executorch.exir.passes.replace_aten_with_edge_pass import OpReplacePass
from executorch.exir.passes.replace_broken_ops_with_function_ops_pass import (
    ReplaceBrokenOpsWithFunctionalOpsPass,
)
from executorch.exir.passes.replace_edge_with_backend_pass import EdgeToBackendOpsPass
from executorch.exir.passes.replace_sym_size_op_pass import ReplaceSymSizeOpPass
from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass
from executorch.exir.passes.spec_prop_pass import SpecPropPass
from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
from executorch.exir.passes.sym_to_tensor_pass import SymToTensorPass
from executorch.exir.passes.weights_to_outputs_pass import weights_to_outputs_pass
from torch import fx
from torch._subclasses import FakeTensor
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from torch.fx.passes.shape_prop import TensorMetadata

__all__ = [
    "ExportPass",
    "ConstPropPass",
    "QuantFusionPass",
    "OpReplacePass",
    "EdgeToBackendOpsPass",
    "MemoryFormatOpsPass",
    "MemoryPlanningPass",
    "HintBasedSymShapeEvalPass",
    "insert_write_back_for_buffers_pass",
    "weights_to_outputs_pass",
]

# Recursive alias for the kinds of values that may appear in an fx node's
# args / kwargs.
Argument = Optional[
    Union[
        Tuple["Argument", ...],
        List["Argument"],
        Dict[str, "Argument"],
        slice,
        torch.fx.Node,
        str,
        int,
        float,
        bool,
        complex,
        torch.dtype,
        torch.Tensor,
        torch.device,
        torch.memory_format,
        torch.layout,
    ]
]


def update_args(
    args: Tuple[Argument, ...], key: int, val: torch.fx.Node
) -> Tuple[Argument, ...]:
    """
    Return a copy of an argument container with one entry replaced.

    The input container is never mutated. This can be used with both args
    and kwargs: although the annotations describe the positional (tuple)
    case, a dict is also accepted at runtime, in which case `key` is used as
    the dict key and a shallow copy of the dict is returned.

    Args:
        args: the tuple of positional arguments (or dict of keyword
            arguments) to copy.
        key: index (or dict key) of the entry to replace.
        val: the node to store at `key`.

    Returns:
        A new tuple (or dict) identical to `args` except for the entry at
        `key`.
    """
    if isinstance(args, dict):
        new_dict = copy.copy(args)
        new_dict[key] = val
        return new_dict

    assert isinstance(args, tuple)
    new_tuple = list(args)
    new_tuple[key] = val
    return tuple(new_tuple)


class DebugPass(PassBase):
    """
    A pass that logs (at debug level) the generated code, every node of the
    graph and a few per-op statistics. The nodes of the graph are not
    modified; the graph module is only recompiled so that its code can be
    printed.
    """

    def __init__(
        self,
        msg: str = "",
        enable_debug_pass: bool = True,
        show_src: bool = False,
        show_full_path: bool = False,
        show_all_frames: bool = False,
        path_filter: Optional[str] = None,
        show_spec: bool = False,
        log_filename: Optional[str] = None,
    ) -> None:
        """
        msg: optional message logged before anything else
        enable_debug_pass: when False the pass returns immediately
        show_src: whether to show source code that generated each fx Node
        show_full_path: whether to show the full path of source code or just the filename
        show_all_frames: control for each node whether show only the last frame or all the frames.
        path_filter: a regular expression to filter the path of the stackframes
        show_spec: whether to log the tensor specs and the 'val' metadata of each node
        log_filename: if provided, the output will also be written to this path.
            Existing content in this file will be discarded.
        """
        self.msg = msg
        self.enable_debug_pass = enable_debug_pass
        self.show_src = show_src
        self.show_full_path = show_full_path
        self.show_all_frames = show_all_frames
        self.show_spec = show_spec
        self.log_filename = log_filename
        if path_filter:
            self.path_filter_re = re.compile(path_filter)  # pyre-ignore
        else:
            self.path_filter_re = None

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Logs the graph and counts the number of nodes per op type and the
        number of call_function nodes per target.
        """
        if not self.enable_debug_pass:
            return PassResult(graph_module, False)
        # it doesn't make sense to mute the DebugPass if user already
        # specify self.enable_debug_pass to be true
        with override_logger(filename=self.log_filename):
            self.callWithLoggerEnabled(graph_module)
        return PassResult(graph_module, True)

    def printFrames(self, node: fx.Node) -> None:
        """
        Log the stack frames recorded in node.meta['stack_trace'], if any.
        Nothing is logged unless self.show_src is set.

        The DebugPass may be used for graphs generated by both the old exir
        dispatch tracer and the new pt2 tracer.
        The former stores the 'stack_trace' field as a json string;
        the latter stores the 'stack_trace' field as a free form string like:
        ```
        File "/data/sandcastle/boxes/fbsource/buck-out/v2/gen/fbcode/20c706e99f51cf3a/executorch/test/end2end/__end2end__/end2end#link-tree/executorch/test/end2end/test_end2end.py", line 150, in forward
        o = o * a
        ```
        This method handles both formats. In future, maybe we can drop the
        support for old exir dispatch tracer.
        """
        if not (
            self.show_src
            and "stack_trace" in node.meta
            and len(node.meta["stack_trace"]) > 0
        ):
            return

        try:
            stack_trace = json.loads(node.meta["stack_trace"])
        except json.decoder.JSONDecodeError:
            # Free form (pt2 style) stack trace: log it verbatim.
            logging.debug(node.meta["stack_trace"])
            return

        frame_list = []  # tuple of filename, frame name, line number and line
        for frame in stack_trace:
            filename = frame["filename"]
            name = frame["name"]
            lineno = frame["lineno"]
            line = frame["line"]
            if not self.show_full_path:
                filename = os.path.basename(filename)
            # For full paths, keep only the part after the link-tree marker.
            mark = "#link-tree/"
            if mark in filename:
                filename = filename.split(mark)[-1]

            if not self.path_filter_re or self.path_filter_re.search(filename):
                frame_list.append((filename, name, lineno, line))

        if not self.show_all_frames:
            frame_list = frame_list[-1:]
        for filename, name, lineno, line in frame_list:
            logging.debug(f" > {filename}:{lineno} in {name}: {line}")

    def callWithLoggerEnabled(self, graph_module: torch.fx.GraphModule) -> None:
        """
        Do the actual logging work of the pass: recompile the graph module,
        log its code, log every node (optionally with tensor specs and source
        frames) and finally log the per-op-type and per-target counters.
        """
        if self.msg:
            logging.debug(self.msg)
        logging.debug("Enter debug_pass")
        graph_module.recompile()
        logging.debug(f"Code is:\n{graph_module.code}")
        op_to_cnt = defaultdict(int)  # stats for op type
        func_to_cnt = defaultdict(int)  # stats for targets in call_function type
        logging.debug("Nodes:")
        for idx, node in enumerate(graph_module.graph.nodes):
            # TODO: better to print python code along with TensorSpecs
            logging.debug(f"{idx:4}: {node.format_node()}")
            if self.show_spec:
                specs = memory_planning.get_node_tensor_specs(node)
                for spec in specs:
                    logging.debug(f" {spec.debug()}")
                logging.debug(f" val: {node.meta.get('val', None)}")
            self.printFrames(node)
            op_to_cnt[node.op] += 1

            if node.op == "call_function":
                target = str(node.target)
                func_to_cnt[target] += 1

        logging.debug("-- node op type stat --")
        for op, cnt in op_to_cnt.items():
            logging.debug(f" op {op}, cnt {cnt}")

        logging.debug("-- call_function stat --")
        for fn, cnt in func_to_cnt.items():
            logging.debug(f" fn {fn}, cnt {cnt}")


# Skip these ops when converting to out variants. They will be handled and
# removed by the emitter.
# pyre-ignore
to_out_var_skiplist: Set[Callable[[Any], Any]] = {
    _operator.getitem,
    torch.ops.higher_order.cond,
    control_flow.while_loop,
    # memory.alloc will be added after the to_out_variant pass so usually
    # we won't see it in the input graph to the to_out_variant pass, unless
    # it's retraced after running to_out_variant with the first trace.
    memory.alloc,
    memory.view,
    executorch_call_delegate,
    torch.ops.aten.copy_.default,
}
to_out_var_skiplist.update(_EXECUTORCH_SYM_OPS)


def make_alloc_node(
    graph_module: torch.fx.GraphModule,
    val: Union[
        Optional[FakeTensor], List[Optional[FakeTensor]], Tuple[Optional[FakeTensor]]
    ],
    tensor_meta: Union[
        Optional[TensorMetadata],
        List[Optional[TensorMetadata]],
        Tuple[Optional[TensorMetadata]],
    ],
) -> torch.fx.Node:
    """
    Create a `memory.alloc` call_function node at the graph's current insert
    point and return it.

    The allocation spec is a (shape, dtype) pair, or a list of such pairs when
    `val` is a list/tuple. For each entry the FakeTensor is preferred; the
    TensorMetadata is only consulted when the FakeTensor is None. `val` and
    `tensor_meta` are stored unchanged in the new node's meta under the keys
    "val" and "tensor_meta".

    Note: tensor_metadata is only used in the case of a Tensor subclass, since
    fakifying a tensor subclass is not supported right now

    Raises:
        InternalError: if neither a FakeTensor nor a TensorMetadata is
            available for the value (or for one of the entries of the list).
    """
    if val is None:
        if tensor_meta is not None:
            assert isinstance(tensor_meta, TensorMetadata)
            alloc_spec = (tensor_meta.shape, tensor_meta.dtype)
        else:
            raise InternalError(
                "Memory allocator node needs FakeTensor val or TensorMetadata to proceed"
            )
    elif isinstance(val, FakeTensor):
        alloc_spec = (val.shape, val.dtype)
    else:
        assert isinstance(val, list) or isinstance(val, tuple)
        assert isinstance(tensor_meta, list) or isinstance(tensor_meta, tuple)
        alloc_spec: List[memory.AllocSpec] = []
        for v, t in zip(val, tensor_meta):
            if v is not None:
                # pyre-fixme[6]: For 1st argument expected
                #  `Union[List[Tuple[List[int], dtype]], Tuple[List[int], dtype]]` but
                #  got `Tuple[Size, dtype]`.
                alloc_spec.append((v.shape, v.dtype))
            elif t is not None:
                # pyre-fixme[6]: For 1st argument expected
                #  `Union[List[Tuple[List[int], dtype]], Tuple[List[int], dtype]]` but
                #  got `Tuple[Size, dtype]`.
                alloc_spec.append((t.shape, t.dtype))
            else:
                raise InternalError(
                    "Memory allocator node needs FakeTensor val or TensorMetadata to proceed"
                )

    # pyre-fixme[6]
    alloc = graph_module.graph.call_function(memory.alloc, (alloc_spec,))
    alloc.meta["val"] = val
    alloc.meta["tensor_meta"] = tensor_meta
    return alloc


class ToOutVarPass(PassBase):
    """
    Rewrites call_function nodes so that they target the out variant of their
    operator, inserting `memory.alloc` nodes for the out arguments.

    Args:
        ignore_to_out_var_failure: when False (the default), `call` raises a
            RuntimeError listing every op for which no out variant could be
            found; when True such ops are left untouched.
    """

    def __init__(self, ignore_to_out_var_failure: bool = False) -> None:
        self.ignore_to_out_var_failure = ignore_to_out_var_failure

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:  # noqa: C901
        """
        Converts all of the functions to contain an out variant if it does not exist

        Submodules referenced by cond / map_impl / while_loop nodes are
        processed recursively. Targets from the `builtins` / `_operator`
        modules and targets listed in `to_out_var_skiplist` are left as is.

        Raises:
            RuntimeError: if a remaining target is not an op overload, or if
                some ops have no out variant and
                `ignore_to_out_var_failure` is False.
        """
        missing_out_vars: Set[str] = set()

        def get_submodule(node: torch.fx.Node) -> torch.fx.GraphModule:
            assert node.op == "get_attr"
            return getattr(graph_module, node.target)

        for node in graph_module.graph.nodes:
            if node.op != "call_function":
                continue

            target = node.target
            if target == torch.ops.higher_order.cond:
                self.call(get_submodule(node.args[1]))
                self.call(get_submodule(node.args[2]))
                continue
            if target == torch.ops.higher_order.map_impl:
                self.call(get_submodule(node.args[0]))
                continue
            elif target == control_flow.while_loop:
                self.call(get_submodule(node.args[0]))
                self.call(get_submodule(node.args[1]))
                continue
            elif getattr(target, "__module__", None) in ("builtins", "_operator"):
                continue
            elif target in to_out_var_skiplist:
                continue
            if not isinstance(
                target, (torch._ops.OpOverload, EdgeOpOverload, BackendOpOverload)
            ):
                raise RuntimeError(f"Require an op overload for target: {target}")

            op_name = target._schema.name
            overload_name = target._schema.overload_name
            if is_out_variant(op_name, overload_name):
                # TODO (zhxchen17) Remove this after functionalization is always on.
                if "out" in node.kwargs and isinstance(node.kwargs["out"], fx.Node):
                    out = node.kwargs["out"]
                    # An existing out buffer that is not a memory.alloc node and
                    # is used only by this node is swapped for a fresh alloc.
                    if out.target is not memory.alloc and len(out.users) == 1:
                        with graph_module.graph.inserting_before(node):
                            alloc = make_alloc_node(
                                graph_module,
                                node.meta["val"],
                                node.meta["tensor_meta"],
                            )
                        out.replace_all_uses_with(alloc)
                        graph_module.graph.erase_node(out)
                continue

            try:
                if isinstance(target, (EdgeOpOverload, BackendOpOverload)):
                    out_var_target = target.to_out_variant()
                    out_args_names = get_out_args_from_opoverload(out_var_target)
                else:
                    out_var_target, out_args_names = to_out_variant(target)
            except RuntimeError as e:
                # Record the failure on the graph module and keep going so
                # that all the missing out variants can be reported at once.
                # pyre-fixme[16]: `GraphModule` has no attribute
                #  `encounter_to_out_var_failure`.
                graph_module.encounter_to_out_var_failure = True
                logging.info(
                    f"Failed converting '{target}' to its out variant with error: '{e}'"
                )
                missing_out_vars.add(op_name)
                continue

            assert out_var_target
            out_var_kwargs = {}

            # Pool functional target's kwargs into out-variant's kwargs
            for arg in out_var_target._schema.arguments:
                if arg.name in out_args_names:
                    continue
                if arg.name in node.kwargs:
                    out_var_kwargs[arg.name] = node.kwargs[arg.name]

            with graph_module.graph.inserting_before(node):
                if len(out_args_names) == 1:
                    alloc_node = make_alloc_node(
                        graph_module, node.meta["val"], node.meta["tensor_meta"]
                    )
                    out_var_kwargs[out_args_names[0]] = alloc_node
                    # The out variant returns nothing: users of the node must
                    # read the result from the allocated buffer instead.
                    if len(out_var_target._schema.returns) == 0:
                        node.replace_all_uses_with(alloc_node)
                else:
                    # If the op has multiple out args, we assume the node's
                    # metadata contains a fake tensor with the same size and type
                    fake_tensor_list = node.meta["val"]
                    tensor_metadatas = node.meta["tensor_meta"]
                    assert isinstance(
                        fake_tensor_list, (list, tuple)
                    ), "Expected a list/tuple of tensors when the op has multiple out arguments"
                    assert len(out_args_names) == len(
                        fake_tensor_list
                    ), f"Expected {len(out_args_names)} tensor specs, but got {len(node.meta['val'])}"
                    for out_arg_name, val, tensor_meta in zip(
                        out_args_names, fake_tensor_list, tensor_metadatas
                    ):
                        if val is None:
                            out_var_kwargs[out_arg_name] = None
                            continue
                        assert isinstance(val, FakeTensor)
                        out_var_kwargs[out_arg_name] = make_alloc_node(
                            graph_module, val, tensor_meta
                        )

            node.target = out_var_target
            node.kwargs = out_var_kwargs

        if (not self.ignore_to_out_var_failure) and len(missing_out_vars) > 0:
            raise RuntimeError(f"Missing out variants: {missing_out_vars}")
        return PassResult(graph_module, True)


def _arg_meta_val(arg: Argument) -> object:
    """
    Return the 'val' metadata of an fx node argument (None when the node has
    no such entry). Arguments that are not fx nodes (e.g. scalars or None)
    carry no metadata and are returned unchanged.
    """
    if isinstance(arg, torch.fx.Node):
        return arg.meta.get("val")
    return arg


def to_scratch_op_pass(graph_module: torch.fx.GraphModule) -> PassResult:
    """
    Convert out variant ops into their scratch op counterpart.

    For every call_function node whose target is an OpOverload for which
    `to_scratch_op` returns an op, the target's `get_scratch_metas` attribute
    is called with the 'val' metadata of the node's args / kwargs. A
    `memory.alloc` node is inserted before the node for every entry of the
    returned mapping and passed to the scratch op as an extra keyword
    argument.

    Non-node arguments (scalars, None, ...) are forwarded to
    `get_scratch_metas` as is instead of failing on a missing `.meta`
    attribute.

    Raises:
        RuntimeError: if the target has a scratch op but no
            `get_scratch_metas` attribute.
    """
    for node in graph_module.graph.nodes:
        if node.op != "call_function":
            continue
        target = node.target
        if not isinstance(target, torch._ops.OpOverload):
            # ignore ops that are not OpOverload. Examples are operator.getitem,
            # memory.alloc etc.
            continue

        scratch_op = to_scratch_op(target)
        if not scratch_op:
            continue

        args_vals = [_arg_meta_val(arg) for arg in node.args]
        kwargs_vals = {name: _arg_meta_val(arg) for name, arg in node.kwargs.items()}
        get_scratch_metas = getattr(target, "get_scratch_metas", None)
        if not get_scratch_metas:
            raise RuntimeError(
                "The get_scratch_metas attribute is not found on the out variant op when converting it to a scratch op. Make sure you have imported the module that attaches the get_scratch_metas attribute to the out variant op."
            )
        scratch_metas = get_scratch_metas(*args_vals, **kwargs_vals)
        scratch_kwargs = {}
        with graph_module.graph.inserting_before(node):
            for name, val in scratch_metas.items():
                scratch = make_alloc_node(graph_module, val, None)
                scratch_kwargs[name] = scratch
        node.target = scratch_op
        kwargs = dict(node.kwargs)
        kwargs.update(scratch_kwargs)
        node.kwargs = kwargs
        logging.debug(f"Out variant {target} is converted to scratch op {scratch_op}")
    return PassResult(graph_module, True)


def dead_code_elimination_pass(graph_module: torch.fx.GraphModule) -> PassResult:
    """
    Run fx dead code elimination on every GraphModule yielded by
    `graph_module.modules()` and recompile each of them.
    """
    for subgm in graph_module.modules():
        if not isinstance(subgm, torch.fx.GraphModule):
            continue
        subgm.graph.eliminate_dead_code()
        subgm.recompile()
    return PassResult(graph_module, True)


# Passes to convert a graph module from ATen to Edge IR

base_pre_op_replace_passes: List[Callable[[torch.nn.Module], PassResult]] = PassManager(
    passes=[
        # ReplaceSymSizeOpPass need to be run before other passes which inherits
        # from ExportPass. ExportPass can not handle OpOverloadPacket in its
        # call_function method. The ReplaceSymSizeOpPass pass converts sym size
        # ops from OpOverloadPacket to OpOverload.
        ReplaceSymSizeOpPass(),
        NormalizeTransposePass(),
        ReplaceBrokenOpsWithFunctionalOpsPass(),
        ScalarToTensorPass(),
        SymToTensorPass(),
        RemoveNoopPass(),
        RemoveToCopyPass(),
    ]
).passes

base_post_op_replace_passes: List[Callable[[torch.nn.Module], PassResult]] = (
    PassManager(
        passes=[
            dead_code_elimination_pass,
            DebugHandleGeneratorPass(),
        ]
    ).passes
)


def propagate_dynamic_shape(
    dynamic_memory_planning_mode: DynamicMemoryPlanningMode = DynamicMemoryPlanningMode.UPPER_BOUND,
) -> List[PassType]:
    """
    Return the list of passes to run on the GraphModule to propagate the
    dynamic shape information.

    Mainly used to provide dynamic shape information for delegation.

    Note: `dynamic_memory_planning_mode` is accepted for API compatibility but
    is not read by this function; the returned list is always
    [SpecPropPass(), HintBasedSymShapeEvalPass()].
    """
    return [
        SpecPropPass(),
        HintBasedSymShapeEvalPass(),
    ]