# mypy: allow-untyped-defs
import builtins
import logging
import operator
import typing
import warnings
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union

import torch
import torch.export._trace
from torch import _C
from torch._export.passes.replace_quantized_ops_with_standard_ops_pass import (
    replace_quantized_ops_with_standard_ops,
)
from torch.export.exported_program import ExportedProgram
from torch.export.graph_signature import (
    ConstantArgument,
    CustomObjArgument,
    InputKind,
    InputSpec,
    OutputKind,
    OutputSpec,
    TensorArgument,
)
from torch.fx import subgraph_rewriter


log = logging.getLogger(__name__)


def _get_param_count_list(method_graph, args_params):
    param_count_list = []
    for input_, arg_params_ in zip(method_graph.inputs(), args_params):
        if "PackedParams" in str(input_.type()):
            in_vars, _ = torch.jit._flatten(arg_params_)
            param_count_list.append(len(in_vars))
        else:
            param_count_list.append(arg_params_ is not None)

    return param_count_list


def _trace_and_get_graph_from_model(model, args):
    # A basic sanity check: make sure the state_dict keys are the same
    # before and after running the model. Fail fast!
    orig_state_dict_keys = torch.jit._unique_state_dict(model).keys()

    # Disable the Autocast cache because it replaces the kernel's weight and bias
    # with (undesired) constants.
    # No perf impact when there are reused weights, since https://github.com/pytorch/pytorch/pull/85665
    prev_autocast_cache_enabled = torch.is_autocast_cache_enabled()
    torch.set_autocast_cache_enabled(False)
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
        model,
        args,
        strict=False,
        _force_outplace=False,
        _return_inputs_states=True,
    )
    torch.set_autocast_cache_enabled(prev_autocast_cache_enabled)

    if orig_state_dict_keys != torch.jit._unique_state_dict(model).keys():
        raise RuntimeError(
            "state_dict changed after running the tracer; "
            "something weird is happening in your model!"
        )

    return trace_graph, torch_out


def _create_jit_graph(
    model: Union[torch.nn.Module, torch.jit.ScriptFunction], args: Sequence[Any]
) -> Tuple[torch.Graph, List["_C.IValue"], Any, Optional[torch.ScriptModule]]:
    if isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)):
        flattened_args = tuple(torch.jit._flatten(tuple(args))[0])
        torch_out = None

        if isinstance(model, torch.jit.ScriptModule):
            try:
                graph = model.forward.graph  # type: ignore[attr-defined]
            except AttributeError as e:
                raise RuntimeError("'forward' method must be a script method") from e
            _C._jit_pass_onnx_function_substitution(graph)
            freezed_module = _C._freeze_module(
                typing.cast(_C.ScriptModule, model._c), preserveParameters=True
            )
            module, params = _C._jit_onnx_list_model_parameters(freezed_module)
            method_graph = module._get_method("forward").graph
            args_params = tuple(args) + tuple(params)
            param_count_list = _get_param_count_list(method_graph, args_params)
            in_vars, _ = torch.jit._flatten(args_params)
            graph = _C._propagate_and_assign_input_shapes(
                method_graph, tuple(in_vars), param_count_list, False, False
            )
            return graph, params, torch_out, module

        # torch.jit.ScriptFunction
        params = []
        graph = model.graph
        _C._jit_pass_onnx_function_substitution(graph)
        param_count_list = _get_param_count_list(graph, args)
        graph = _C._propagate_and_assign_input_shapes(
            graph, flattened_args, param_count_list, False, False
        )
        return graph, params, torch_out, None

    graph, torch_out = _trace_and_get_graph_from_model(model, args)
    _C._jit_pass_onnx_lint(graph)
    state_dict = torch.jit._unique_state_dict(model)
    params = list(state_dict.values())
    graph_inputs = list(graph.inputs())
    user_input_num = len(graph_inputs) - len(state_dict)
    param_names = list(state_dict.keys())
    for i, inp in enumerate(graph_inputs):
        if i >= user_input_num:
            inp.setDebugName(param_names[i - user_input_num])
    _C._jit_pass_onnx_function_substitution(graph)
    return graph, params, torch_out, None


def list_add(a, b):
    return a + b


def list_append(container, element):
    return container + [element]


def execute_subgraph_from_prim_loop(
    subgraph, iter_idx, len_loop_local_arguments, *args, **kwargs
):
    """
    subgraph: GraphModule from sub-block.
    iter_idx: The index of the current iteration.
    len_loop_local_arguments: The number of loop local arguments in args.
    """

    # Loop local variables. The TS graph creates those as inputs because their
    # values are updated inside the loop.
    loop_local_args = args[:len_loop_local_arguments]
    # Global variables that are not passed in as inputs to the loop sub-blocks
    # but are directly used. Most of the time their values are not updated; the
    # only exception is when some operations perform inplace updates.
    global_args = args[len_loop_local_arguments:]
    return subgraph(*global_args, iter_idx, *loop_local_args, **kwargs)
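

# Illustrative example of the calling convention above (the names `body_gm`,
# `acc`, and `weight` are hypothetical): for a loop body that carries one
# loop-local variable `acc` and reads one global tensor `weight`, the converter
# emits a call like
#     execute_subgraph_from_prim_loop(body_gm, i, 1, acc, weight)
# which dispatches to body_gm(weight, i, acc), i.e. global arguments first,
# then the iteration index, then the loop-local arguments.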


def inplace_optimize_sym_size_div(gm: torch.fx.GraphModule):
    def pattern(im, dim, scale):
        sym_size_int = torch.ops.aten.sym_size.int(im, dim)
        scalar_tensor = torch.ops.aten.scalar_tensor(sym_size_int)
        div_scalar_mode = torch.ops.aten.div.Scalar_mode(
            scalar_tensor, scale, rounding_mode="trunc"
        )
        int_tensor = torch.ops.aten.Int.Tensor(div_scalar_mode)
        return int_tensor

    def replacement(im, dim, scale):
        sym_size_int = torch.ops.aten.sym_size.int(im, dim)
        return sym_size_int // scale

    replaced_patterns = subgraph_rewriter.replace_pattern(gm, pattern, replacement)


def is_valid_for_codegen(name):
    if len(name) == 0:
        raise RuntimeError("Empty argument name for codegen")
    if name[0].isdigit():
        return False
    return True


def normalize_name(name: str, prefix: str = "rename") -> str:
    name = name.replace(".", "_")
    if is_valid_for_codegen(name):
        return name
    return f"{prefix}_{name}"


def ir_name_to_func_name(name: str) -> str:
    """prim::If -> convert_prim_If"""
    name_list = name.split("::")
    return "convert_" + "_".join(name_list)


def get_node_as_placeholder_or_get_attr(fx_graph, name, is_top_level_graph):
    if is_top_level_graph:
        return fx_graph.get_attr(name)
    return fx_graph.placeholder(name)


_TORCH_DTYPE_TO_ENUM = {
    torch.uint8: 0,
    torch.int8: 1,
    torch.int16: 2,
    torch.int32: 3,
    torch.int64: 4,
    torch.float16: 5,
    torch.float32: 6,
    torch.float64: 7,
    torch.complex32: 8,
    torch.complex64: 9,
    torch.complex128: 10,
    torch.bool: 11,
    torch.qint8: 12,
    torch.quint8: 13,
    torch.bfloat16: 15,
}

_TORCH_ENUM_TO_DTYPE = {value: key for key, value in _TORCH_DTYPE_TO_ENUM.items()}


def get_dtype_as_int(tensor):
    """
    prim::dtype has the signature "(Tensor a) -> int", where it gets the dtype of
    the tensor and returns the integer corresponding to this dtype based on the
    enum in ScalarType.h
    """
    dtype = tensor.dtype
    if dtype not in _TORCH_DTYPE_TO_ENUM:
        raise RuntimeError(f"Unsupported dtype {dtype}")
    return _TORCH_DTYPE_TO_ENUM[dtype]


# These operators are automatically populated as instance methods of
# TS2FXGraphConverter named convert_<namespace>_<opname>().
# See __init__ for the method population implementation.
kind_to_standard_operators = {
    "prim::max": builtins.max,
    "prim::min": builtins.min,
    "prim::TupleIndex": operator.getitem,
    "aten::__is__": operator.is_,
    "aten::__isnot__": operator.is_not,
    "aten::__not__": operator.not_,
    "aten::__contains__": operator.contains,
    "prim::dtype": get_dtype_as_int,
    "aten::len": len,
    # Mapping from specialized op to its symbolic counterpart.
    # They currently do not have any other overrides.
    "aten::numel": torch.ops.aten.sym_numel,
    "aten::size": torch.ops.aten.sym_size,
    "aten::storage_offset": torch.ops.aten.sym_storage_offset,
    "aten::stride": torch.ops.aten.sym_stride,
}
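
# For reference, the handlers for this table are generated in
# TS2FXGraphConverter.__init__: ir_name_to_func_name("prim::dtype") gives
# "convert_prim_dtype", and the generated handler routes the node through
# _convert_standard_operators, which emits a call_function node whose target is
# the mapped callable (get_dtype_as_int in this example).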


def get_ir_value_parent_name_and_attr_name(node):
    irv_parent_name, irv_name = node.input().debugName(), node.output().debugName()
    attr_name = node.s("name")
    return irv_name, irv_parent_name, attr_name


def construct_fqn(ir, ref_map, name_map):
    name_list = []
    while ir in ref_map:
        name_list.append(name_map[ir])
        ir = ref_map[ir]
    return ".".join(reversed(name_list))


def get_block_to_lifted_attrs(graph: torch._C.Graph) -> Dict[torch._C.Block, Set[str]]:
    """
    Perform two passes to get a mapping of blocks to a set of FQNs of their lifted attributes.

    When a graph has control flow, the graph will be divided into multiple blocks. We want to convert
    each block to a graph which will be passed into torch.cond. A restriction for torch.cond is that model
    parameters/buffers are expected to be lifted as inputs to the subgraphs. Before converting the model,
    we run this pass, which will:
        1. Figure out which params/buffers are used within blocks by tracing the GetAttr calls.
        2. Process the graph bottom-up to find the lifted attributes of each block by taking the union
           of the attributes used in the current block and the lifted attributes of all its child blocks.

    Returns:
        A mapping of blocks to a set of FQNs of their lifted attributes.
    """

    # A map from a block to the arguments that are expected to be lifted for it.
    blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]] = {}

    # Reference map stores the input (i.e., src) and output (i.e., dest) IR of a
    # GetAttr node. By traversing this reference map, we can follow the full
    # chain of GetAttr aliases and figure out the FQN of an attribute.
    # E.g., %2 = GetAttr(linear)[%1] --> node_to_parent_map["%2"] = "%1"
    node_to_parent_map: Dict[str, str] = {}

    # Used for reconstructing the FQN of an attribute based on the reference map.
    # In a nutshell, for each GetAttr call, GetAttr(input IR, attribute name) -> output IR.
    # This name map stores which attribute name is requested for a src IR --> dest IR edge.
    # E.g., %2 = GetAttr(linear)[%1] --> node_to_attr_name["%2"] = "linear"
    node_to_attr_name: Dict[str, str] = {}

    def _dfs_get_attr_dependency(entry):
        """
        First DFS pass to construct the reference map and name map.
        """
        for node in entry.nodes():
            if node.kind() == "prim::GetAttr":
                (
                    irv_name,
                    irv_parent_name,
                    attr_name,
                ) = get_ir_value_parent_name_and_attr_name(node)
                node_to_parent_map[irv_name] = irv_parent_name
                node_to_attr_name[irv_name] = attr_name
            for block in node.blocks():
                _dfs_get_attr_dependency(block)

    def _map_blocks_to_lifted_attrs(entry):
        """
        Walk the graph in a bottom-up fashion to build the expected-to-be-lifted
        arguments for each block.
        """
        arguments: Set[str] = set()
        for node in entry.nodes():
            for block in node.blocks():
                # Recursively build.
                arguments = arguments.union(_map_blocks_to_lifted_attrs(block))
            if node.kind() == "prim::GetAttr":
                irv_name = node.output().debugName()
                # Skip intermediate GetAttrs, which will not result in a FQN anyway.
                # E.g., node_to_parent_name: {"%3": "%2", "%2": "%1"}
                #       node_to_attr_name: {"%3": "weight", "%2": "linear", "%1": "self"}
                #       There is only one FQN %3-->%2-->%1: self.linear.weight
                #       %2-->%1 is not a FQN: self.linear
                if irv_name not in set(node_to_parent_map.values()):
                    arguments.add(
                        construct_fqn(irv_name, node_to_parent_map, node_to_attr_name)
                    )
        if not isinstance(entry, torch._C.Graph):  # Skip the top level.
            blocks_to_lifted_attrs[entry] = arguments
        return arguments

    _dfs_get_attr_dependency(graph)
    _map_blocks_to_lifted_attrs(graph)

    return blocks_to_lifted_attrs
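
# Illustrative example (hypothetical module, for clarity only): for a forward()
# that branches on a flag and reads self.linear.weight only inside the taken
# branch, get_block_to_lifted_attrs maps each prim::If sub-block to
# {"linear.weight"}, so the conversion of that block can lift the parameter as
# an explicit input when it is rewritten for torch.cond.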


def get_attribute_fqn_from_ts_node(
    name_to_attribute_fqn: Dict[str, str], node: torch._C.Node
) -> str:
    def get_attr(name: str):
        if name in name_to_attribute_fqn:
            return name_to_attribute_fqn[name]
        else:
            raise ValueError(f"Attribute {name} not found")

    if node.kind() == "prim::SetAttr":
        input_name = next(node.inputs()).debugName()
    elif node.kind() == "prim::GetAttr":
        input_name = node.input().debugName()
    else:
        raise RuntimeError(
            f"Unexpected node kind when getting attribute fqn. node: {node} "
        )

    attr_name = node.s("name")
    root_attr_name = get_attr(input_name)
    attr_fqn = f"{root_attr_name}.{attr_name}" if root_attr_name else attr_name

    return attr_fqn


def get_op_overload(node: torch._C.Node):
    schema_str = node.schema()
    assert schema_str != "(no schema)", f"got empty schema for {node}"
    schema: torch._C.FunctionSchema = torch._C.parse_schema(schema_str)
    ns, op_name = str(schema.name).split("::")
    override = schema.overload_name

    try:
        op_overload_mod = getattr(torch.ops, ns)
        op_overload_packet = getattr(op_overload_mod, op_name)
        if override:
            op_overload = getattr(op_overload_packet, override)
        else:
            op_overload = op_overload_packet.default
    except Exception as e:
        raise RuntimeError(
            f"Unable to find operator {node.kind()} with schema {node.schema()}"
        ) from e

    return op_overload


class TS2FXGraphConverter:
    def __init__(
        self,
        ts_graph: Union[torch._C.Graph, torch._C.Block],
        name_to_param: Dict[str, torch.Tensor],
        name_to_buffer: Dict[str, torch.Tensor],
        blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]],
        name_to_non_tensor_attribute: Dict[str, Any],
        name_to_constant: Dict[str, Any],
    ):
        self.ts_graph = ts_graph
        self.name_to_param = name_to_param
        self.name_to_buffer = name_to_buffer

        self.fx_graph: torch.fx.Graph = torch.fx.Graph()
        self.input_specs: List[InputSpec] = []
        self.output_specs: List[OutputSpec] = []

        self.name_to_node: Dict[
            str, Union[torch.fx.Node, List[torch.fx.Node], Dict[Any, torch.fx.Node]]
        ] = {}
        self.name_to_constant: Dict[str, Any] = name_to_constant

        # Mapping from TorchScript node output name to attribute fully qualified name.
        self.name_to_attribute_fqn: Dict[str, str] = {}

        # Mapping from fully qualified name to real values or an fx graph node.
        # During conversion, this represents the current value of a non-tensor attribute.
        # One use case is:
        # def forward(self, x):
        #     c1 = self.count
        #     self.count += 1
        #     c2 = self.count
        #     return x + c1 + c2
        self.name_to_non_tensor_attribute_node: Dict[str, Any] = {}

        # Mapping from fully qualified name to the initial real value inputs.
        # We separate it from self.name_to_non_tensor_attribute_node since
        # we need the initial real values as input when we construct the fx.GraphModule.
        self.name_to_non_tensor_attribute: Dict[str, Any] = name_to_non_tensor_attribute

        self.subgraphs: Dict[str, torch.fx.GraphModule] = {}

        self.blocks_to_lifted_attrs = blocks_to_lifted_attrs

        # Populate methods for the standard operators.
        for k in kind_to_standard_operators.keys():
            handler_func_name = ir_name_to_func_name(k)
            # Create an indirect function call:
            # convert_<namespace>_<opname> --> lambda node: _convert_standard_operators(node)
            setattr(
                self,
                handler_func_name,
                lambda node: self._convert_standard_operators(node),
            )

        # This stores a list of return results that do not appear in the original TS
        # graph's outputs. The reason we maintain this is that some operations in a sub-block
        # might have inplace updates to a variable defined in the parent fx graph. After
        # the execution of that sub-block, the variable defined in the parent fx graph also
        # needs to be updated.
        self.name_update_from_subblock_to_parent: Set[str] = set()

    def _is_get_attr_node(self, fqn):
        return (
            fqn in self.name_to_buffer
            or fqn in self.name_to_param
            or (
                fqn in self.name_to_constant
                and isinstance(self.name_to_constant[fqn], torch.ScriptObject)
            )
        )

    def _convert_block_to_subgraph(self, node: torch._C.Node, arguments: List[str]):
        subgraph_nodes, subgraph_converters = [], []
        for block in node.blocks():
            subgraph_converter = TS2FXGraphConverter(
                block,
                self.name_to_param,
                self.name_to_buffer,
                self.blocks_to_lifted_attrs,
                {},
                self.name_to_constant,
            )
            subgraph_converter.name_to_attribute_fqn = self.name_to_attribute_fqn

            for block_arg in arguments:
                normalized_block_arg_name = normalize_name(block_arg)
                placeholder_node = subgraph_converter.fx_graph.placeholder(
                    normalized_block_arg_name
                )
                subgraph_converter.name_to_node[block_arg] = placeholder_node

            subgraph = subgraph_converter.convert()
            subgraph_name = self.add_subgraph(subgraph)
            subgraph_nodes.append(self.fx_graph.get_attr(subgraph_name))
            subgraph_converters.append(subgraph_converter)
        return subgraph_nodes, subgraph_converters

    def _identify_inputs_as_arguments(self, entry):
        """
        Identify inputs from the innermost sub-block. This is needed
        for nested sub-blocks when the input is hidden in the nested sub-block.
        E.g., example IR where the input is hidden in the nested sub-block:
            Graph[x.1]
                %1 = ...
                Block[]
                    Block[x.1]
                        %2 = x.1 ...
        """
        arguments: Set[str] = set()
        for block in entry.blocks():
            for block_node in block.nodes():
                for block_node_in in block_node.inputs():
                    if (
                        block_node_in.debugName() in self.name_to_node
                        and block_node_in.debugName() not in self.name_to_attribute_fqn
                    ):
                        arguments.add(block_node_in.debugName())
                arguments = arguments.union(
                    self._identify_inputs_as_arguments(block_node)
                )
        return arguments

    def is_top_level_graph(self):
        return isinstance(self.ts_graph, torch._C.Graph)

    def add_subgraph(self, subgraph) -> str:
        name = f"subgraph_{len(self.subgraphs)}"
        self.subgraphs[name] = subgraph
        return name

    def get_args_kwargs(self, node: torch._C.Node, schema):
        args = []
        kwargs = {}
        for input, schema_arg in zip(node.inputs(), schema.arguments):
            if schema_arg.kwarg_only:
                kwargs[schema_arg.name] = self.get_fx_value_by_ir_value(input)
            else:
                args.append(self.get_fx_value_by_ir_value(input))

        return tuple(args), kwargs

    def get_fx_value_by_ir_value(self, value: torch._C.Value):
        value_name = value.debugName()

        if value_name in self.name_to_node:
            input_node = self.name_to_node[value_name]
            return input_node
        elif value_name in self.name_to_constant:
            if isinstance(self.name_to_constant[value_name], torch.ScriptObject):
                return self.fx_graph.get_attr(value_name)
            return self.name_to_constant[value_name]
        else:
            raise ValueError(f"Input {value_name} not found")

    def get_fx_value_by_fqn(self, name):
        if name in self.name_to_node:
            fx_node = self.name_to_node[name]
        elif name in self.name_to_constant:
            fx_node = self.name_to_constant[name]
        elif name in self.name_to_non_tensor_attribute_node:
            fx_node = self.name_to_non_tensor_attribute_node[name]
        elif name in self.name_to_non_tensor_attribute:
            fx_node = self.name_to_non_tensor_attribute[name]
        else:
            raise ValueError(f"Attribute {name} not found")
        return fx_node

    def convert(self) -> torch.fx.GraphModule:
        self.convert_graph_inputs()

        for node in self.ts_graph.nodes():
            self.convert_node(node)

        self.convert_graph_outputs()

        # Pass parameters and buffers to the root for lookup.
        gm = torch.fx.GraphModule(
            {
                **self.subgraphs,
                **self.name_to_param,
                **self.name_to_buffer,
                **self.name_to_non_tensor_attribute,
                **self.name_to_constant,
            },
            self.fx_graph,
        )

        inplace_optimize_sym_size_div(gm)

        gm.graph.lint()

        return gm

    def convert_graph_inputs(self):
        for graph_input in self.ts_graph.inputs():
            name = graph_input.debugName()

            if name in self.name_to_param:
                normalized_name = normalize_name(name)
                self.input_specs.append(
                    InputSpec(
                        InputKind.PARAMETER,
                        arg=TensorArgument(name=normalized_name),
                        target=name,
                    )
                )
                fx_node = get_node_as_placeholder_or_get_attr(
                    self.fx_graph, name, self.is_top_level_graph()
                )
            elif name in self.name_to_buffer:
                normalized_name = normalize_name(name)
                self.input_specs.append(
                    InputSpec(
                        InputKind.BUFFER,
                        arg=TensorArgument(name=normalized_name),
                        target=name,
                        persistent=True,
                    )
                )
                fx_node = get_node_as_placeholder_or_get_attr(
                    self.fx_graph, name, self.is_top_level_graph()
                )
            elif name in self.name_to_constant:
                assert isinstance(
                    self.name_to_constant[name], torch.ScriptObject
                ), "Input conversion only handles ScriptObject"
                normalized_name = normalize_name(name)
                self.input_specs.append(
                    InputSpec(
                        InputKind.CUSTOM_OBJ,
                        arg=CustomObjArgument(
                            name=normalized_name, class_fqn=normalized_name
                        ),
                        target=name,
                        persistent=False,
                    )
                )
                fx_node = get_node_as_placeholder_or_get_attr(
                    self.fx_graph, name, self.is_top_level_graph()
                )
            elif isinstance(graph_input.type(), torch.ClassType):
                # Directly skip inputs that are ScriptObject but not used in the graph.
                continue
            else:
                normalized_name = normalize_name(name, prefix="input")
                self.input_specs.append(
                    InputSpec(
                        InputKind.USER_INPUT,
                        arg=TensorArgument(name=normalized_name),
                        target=name,
                    )
                )
                fx_node = self.fx_graph.placeholder(normalized_name)

            self.name_to_node[name] = fx_node

    def convert_aten_Float(self, node: torch._C.Node):
        def to_float_tensor(t):
            return t.to(dtype=torch.float).item()

        inp_list = [
            self.get_fx_value_by_ir_value(inp) for inp in node.inputs()
        ]  # noqa: C416
        fx_node = self.fx_graph.call_function(
            to_float_tensor,
            tuple(inp_list),
        )
        self.name_to_node[node.output().debugName()] = fx_node

    def convert_aten_tensor(self, node: torch._C.Node):
        """aten::tensor creates a constant tensor ad-hoc --> GetAttr"""
        args, kwargs = self.get_args_kwargs(node, torch.ops.aten.tensor.default._schema)

        for k in kwargs:
            if k == "requires_grad":
                kwargs[k] = bool(kwargs[k])  # 0 -> False, 1 -> True

        to_tensor = (
            torch.tensor
            if all(isinstance(a, int) for a in args)
            else torch._refs.tensor
        )

        def target(*args, **kwargs):
            if "dtype" in kwargs and kwargs["dtype"] is not None:
                kwargs["dtype"] = _TORCH_ENUM_TO_DTYPE[kwargs["dtype"]]
            return to_tensor(*args, **kwargs)

        # def to_dynamic_tensor(*args, **kwargs):
        #     if "dtype" in kwargs and kwargs["dtype"] is not None:
        #         kwargs["dtype"] = _TORCH_ENUM_TO_DTYPE[kwargs["dtype"]]
        #     return torch._refs.tensor(*args, **kwargs)

        output_name = node.output().debugName()
        fx_node = self.fx_graph.call_function(target, args, kwargs)
        self.name_to_node[output_name] = fx_node

    def convert_aten_append(self, node: torch._C.Node):
        # Special handling of python list append: "aten::append.t(t[](a!) self, t(c -> *) el) -> t[](a!)"

        # Inplace append to the list!! This is kinda crazy, as we are mutating the list in place.
        # This makes the converter "non-functional", and the result depends on the order in which
        # the nodes are converted. In a sense, the converter now becomes a stateful interpreter.
        warnings.warn(
            "Converting aten::append.t, which is an inplace mutation of the list. "
            "This makes the converter non-functional: the result depends on the order "
            "in which the append nodes are converted!"
        )

        args = tuple(self.get_fx_value_by_ir_value(inp) for inp in node.inputs())
        fx_node = self.fx_graph.call_function(list_append, args)
        self.name_to_node[node.output().debugName()] = fx_node

        # Inplace mutate arg[0], which is the python list.
        self.name_to_node[node.inputsAt(0).debugName()] = fx_node

        # Variables that need to be updated to the parent module.
        if not self.is_top_level_graph() and args[0].op == "placeholder":
            self.name_update_from_subblock_to_parent.add(node.inputsAt(0).debugName())

    def convert_prim_Constant(self, node: torch._C.Node):
        name = node.output().debugName()

        value: Any = None
        if node.hasAttribute("value"):
            constant_kind = node.kindOf("value")
            if constant_kind == "i":
                value = node.i("value")
            elif constant_kind == "f":
                value = node.f("value")
            elif constant_kind == "s":
                value = node.s("value")
            elif constant_kind == "t":
                alias_name = (
                    f"lifted_tensor_{name}"  # Follow naming convention from EP tracing.
                )
                fx_node = self.fx_graph.get_attr(alias_name)
                self.name_to_node[name] = fx_node
                name, value = alias_name, node.t("value")
            elif constant_kind == "ival":
                value = node.ival("value")
            else:
                raise ValueError(f"Unsupported constant type: {node.kindOf('value')}")
        else:
            value = None

        self.name_to_constant[name] = value

    def convert_prim_CallMethod(self, node: torch._C.Node):
        inp_list = [
            self.get_fx_value_by_ir_value(inp) for inp in node.inputs()
        ]  # noqa: C416
        fx_node = self.fx_graph.call_method(
            node.s("name"),
            tuple(inp_list),
        )
        self.name_to_node[node.output().debugName()] = fx_node

    def convert_prim_device(self, node: torch._C.Node):
        input_type = node.input().type()
        if input_type.isSubtypeOf(torch._C.TensorType.get()):
            device = input_type.device()  # type: ignore[attr-defined]
            output_name = node.output().debugName()
            self.name_to_constant[output_name] = device
        else:
            raise ValueError(f"Unsupported JitType ({input_type}) when getting device")

    def convert_prim_GetAttr(self, node: torch._C.Node):
        # Build the fully qualified name.
        attr_fqn = get_attribute_fqn_from_ts_node(self.name_to_attribute_fqn, node)
        output_name = node.output().debugName()
        self.name_to_attribute_fqn[output_name] = attr_fqn

        if self.is_top_level_graph():
            if self._is_get_attr_node(attr_fqn):
                # We insert a get_attr node for two reasons.
                # First, the TS graph does not lift tensor constants as input nodes, so
                # tensor constants may be ignored in convert_graph_inputs().
                # Second, attr_fqn may have been written to via SetAttr, so two
                # GetAttrs may give different values.
                self.name_to_node[output_name] = self.fx_graph.get_attr(attr_fqn)
            else:
                if attr_fqn not in self.name_to_non_tensor_attribute_node:
                    self.name_to_non_tensor_attribute_node[
                        attr_fqn
                    ] = self.name_to_non_tensor_attribute[attr_fqn]
                self.name_to_node[output_name] = self.name_to_non_tensor_attribute_node[
                    attr_fqn
                ]
        else:
            # Special support for if blocks, which do not allow SetAttr TorchScript
            # nodes or get_attr FX graph nodes.
            if self._is_get_attr_node(attr_fqn):
                self.name_to_node[output_name] = self.name_to_node[attr_fqn]

    def convert_prim_SetAttr(self, node: torch._C.Node):
        attr_fqn = get_attribute_fqn_from_ts_node(self.name_to_attribute_fqn, node)
        attr_value = tuple(node.inputs())[1]
        ts_graph_tensor_input = self.get_fx_value_by_ir_value(attr_value)
        if self._is_get_attr_node(attr_fqn):
            fx_attr_node = self.fx_graph.get_attr(attr_fqn)
            self.fx_graph.call_function(
                torch.Tensor.copy_, (fx_attr_node, ts_graph_tensor_input)
            )
        else:
            self.name_to_non_tensor_attribute_node[attr_fqn] = ts_graph_tensor_input

    def convert_call_function_op(self, node: torch._C.Node):
        target = get_op_overload(node)

        args, kwargs = self.get_args_kwargs(node, target._schema)

        fx_node = self.fx_graph.call_function(target, args, kwargs)

        # TODO: convert sourceRange() into stack_trace
        # fx_node.meta["stack_trace"] = node.sourceRange()

        if node.outputsSize() == 1:
            output_name = node.output().debugName()
            self.name_to_node[output_name] = fx_node
        else:
            for i, outp in enumerate(node.outputs()):
                output_name = outp.debugName()
                next_fx_node = self.fx_graph.call_function(
                    operator.getitem, (fx_node, i)
                )
                self.name_to_node[output_name] = next_fx_node

    def convert_prim_TupleConstruct(self, node: torch._C.Node):
        self._convert_prim_iterator(node)

    def convert_prim_ListConstruct(self, node: torch._C.Node):
        self._convert_prim_iterator(node)

    def _convert_prim_iterator(self, node: torch._C.Node):
        output_list = []
        for inp in node.inputs():
            output_list.append(self.get_fx_value_by_ir_value(inp))

        output_name = node.output().debugName()
        self.name_to_node[output_name] = output_list

    def convert_prim_DictConstruct(self, node: torch._C.Node):
        output_dict = {}
        k, v = None, None
        for i, inp in enumerate(node.inputs()):
            # We assume keys and values are stored in pairs in the DictConstruct:
            # the first element is the key and the following one is the value.
            if i % 2 == 0:
                k = self.get_fx_value_by_ir_value(inp)
            else:
                v = self.get_fx_value_by_ir_value(inp)
                assert (
                    k is not None and v is not None
                ), "DictConstruct has an empty key value pair."
                output_dict[k] = v
                k, v = None, None

        assert (
            k is None and v is None
        ), "DictConstruct has an odd number of elements (violating our assumption)."

        output_name = node.output().debugName()
        self.name_to_node[output_name] = output_dict

    def convert_prim_ListUnpack(self, node: torch._C.Node):
        self._convert_prim_unpack_iterator(node)

    def convert_prim_TupleUnpack(self, node: torch._C.Node):
        self._convert_prim_unpack_iterator(node)

    def _convert_prim_unpack_iterator(self, node: torch._C.Node):
        # Single input and multiple outputs for unpacking.
        for i, outp in enumerate(node.outputs()):
            outp_name = outp.debugName()
            inp = self.get_fx_value_by_ir_value(node.input())
            fx_node = self.fx_graph.call_function(operator.getitem, (inp, i))
            self.name_to_node[outp_name] = fx_node

    def convert_aten_Int(self, node: torch._C.Node):
        # Converts aten::Int as aten._to_copy + aten::_local_scalar_dense.
        target = torch.ops.aten._to_copy.default
        args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs())
        to_copy_node = self.fx_graph.call_function(target, args, {"dtype": torch.int32})

        fx_node = self.fx_graph.call_function(
            torch.ops.aten._local_scalar_dense.default, (to_copy_node,)
        )

        # TODO: convert sourceRange() into stack_trace
        # fx_node.meta["stack_trace"] = node.sourceRange()

        output_name = node.output().debugName()
        self.name_to_node[output_name] = fx_node

    def convert_prim_NumToTensor(self, node: torch._C.Node):
        # Converts prim::NumToTensor as aten.scalar_tensor.
        # prim::NumToTensor IRs are currently triggered by:
        # .size() https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/frontend/tracer.cpp#L950
        # .numel() https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/frontend/tracer.cpp#L971
        # For both of those APIs, torch.jit.trace implicitly sets the output tensor type
        # to be LongTensor.
        target = torch.ops.aten.scalar_tensor
        args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs())

        fx_node = self.fx_graph.call_function(target, args, {"dtype": torch.long})
        output_name = node.output().debugName()
        self.name_to_node[output_name] = fx_node

    def convert_prim_CreateObject(self, node: torch._C.Node):
        output_name = node.output().debugName()
        self.name_to_attribute_fqn[output_name] = ""

    def convert_aten__convolution(self, node: torch._C.Node):
        # Converts aten::_convolution as aten.convolution, since aten::_convolution
        # doesn't have a meta function.
        target = torch.ops.aten.convolution.default
        args, kwargs = self.get_args_kwargs(node, target._schema)

        fx_node = self.fx_graph.call_function(target, args, kwargs)

        output_name = node.output().debugName()
        self.name_to_node[output_name] = fx_node

    def convert_aten_div(self, node: torch._C.Node):
        target = get_op_overload(node)
        schema = target._schema

        args, kwargs = self.get_args_kwargs(node, schema)

        # Converts aten::div.Tensor_mode(x, tensor_constant)
        # as aten.div.Scalar_mode(x, tensor_constant.item()).
        if schema.overload_name == "Tensor_mode":
            arg1_name = args[1].name
            if arg1_name in self.name_to_constant and isinstance(
                self.name_to_constant[arg1_name], torch.Tensor
            ):
                tensor_constant = self.name_to_constant[arg1_name]
                if tensor_constant.numel() == 1:
                    updated_args = list(args)
                    updated_args[1] = self.name_to_constant[arg1_name].item()

                    fx_node = self.fx_graph.call_function(
                        torch.ops.aten.div.Scalar_mode,
                        tuple(updated_args),
                        kwargs,
                    )

                    # TODO: convert sourceRange() into stack_trace
                    # fx_node.meta["stack_trace"] = node.sourceRange()

                    output_name = node.output().debugName()
                    self.name_to_node[output_name] = fx_node
                    return

        self.convert_call_function_op(node)

    def convert_aten___getitem__(self, node: torch._C.Node):
        input_container, index = tuple(
            self.get_fx_value_by_ir_value(input) for input in node.inputs()
        )
        fx_node = self.fx_graph.call_function(
            operator.getitem, (input_container, index)
        )
        output_name = node.output().debugName()
        self.name_to_node[output_name] = fx_node

    def convert_aten_to(self, node: torch._C.Node):
        target = get_op_overload(node)
        args, kwargs = self.get_args_kwargs(node, target._schema)

        # Special handling of aten.to.dtype and aten.to.prim_dtype followed by an
        # inplace mutation op, because the aten.to + inplace_mutation_op pattern
        # would trigger a "cannot mutate tensors with frozen storage"
        # functionalization error. To work around the issue, we override the copy
        # argument to be True, so that the output is guaranteed not to be an alias
        # of the input.
        if target == torch.ops.aten.to.dtype or target == torch.ops.aten.to.prim_dtype:
            user_nodes = [use.user for use in node.output().uses()]
            user_targets = [
                get_op_overload(user_node)
                for user_node in user_nodes
                if user_node.schema() != "(no schema)"
            ]
            has_mutable_target = any(
                target._schema.is_mutable for target in user_targets
            )

            if has_mutable_target:
                assert len(args) >= 4
                new_args = list(args)
                new_args[3] = True  # copy, override to True
                fx_node = self.fx_graph.call_function(
                    torch.ops.aten.to.dtype, tuple(new_args)
                )
                # Temporary hack to work around https://github.com/pytorch/pytorch/issues/131679.
                # When this issue is fixed, the clone node will no longer be needed.
                clone_node = self.fx_graph.call_function(
                    torch.ops.aten.clone.default, (fx_node,)
                )
                output_name = node.output().debugName()
                self.name_to_node[output_name] = clone_node
                return

        self.convert_call_function_op(node)

    def convert_aten_add(self, node: torch._C.Node):
        if node.schema() == "(no schema)":
            if isinstance(node.inputsAt(0).type(), torch.ListType) and isinstance(
                node.inputsAt(1).type(), torch.ListType
            ):
                target = torch.ops.aten.add.t
            else:
                raise RuntimeError(f"unable to determine the target for {node}")
        else:
            target = get_op_overload(node)

        if target == torch.ops.aten.add.t:
            # Special handling of python list/tuple add: "aten::add.t(t[] a, t[] b) -> t[]",
            # which otherwise fails with
            # RuntimeError: aten::add() Expected a value of type 'List[t]' for argument 'a' but instead found type 'immutable_list'.
            args, kwargs = self.get_args_kwargs(node, target._schema)
            output_name = node.output().debugName()
            self.name_to_node[output_name] = self.fx_graph.call_function(list_add, args)
        else:
            self.convert_call_function_op(node)

    def _check_prim_loop_support(self, node):
        inputs = list(node.inputs())

        # TODO: (1/N) stage.
        if inputs[0].debugName() not in self.name_to_constant:
            raise RuntimeError(
                "prim::Loop currently cannot run with dynamic value of number of iterations."
            )

        # Make sure the condition is not updated in the subblock.
        subblock = next(node.blocks())
        condition_output_name = next(subblock.outputs()).debugName()
        for node in subblock.nodes():
            if (
                node.outputsSize() == 1
                and node.output().debugName() == condition_output_name
            ):
                raise RuntimeError(
                    "prim::Loop currently cannot run with dynamic value of condition."
                )
            if node.outputsSize() >= 2:
                for outp in node.outputs():
                    if outp.debugName() == condition_output_name:
                        raise RuntimeError(
                            "prim::Loop currently cannot run with dynamic value of condition."
                        )

    def convert_prim_Loop(self, node: torch._C.Node):
        inputs = list(node.inputs())
        self._check_prim_loop_support(node)

        num_iterations = self.get_fx_value_by_ir_value(inputs[0])

        # Find inputs.
        loop_local_arguments = [inp.debugName() for inp in inputs[2:]]

        global_arguments = self._identify_inputs_as_arguments(node)

        # Lift parameters as inputs.
        for block in node.blocks():
            global_arguments = global_arguments.union(
                self.blocks_to_lifted_attrs[block]
            )

        global_arguments = list(global_arguments)

        subgraph_nodes, subgraph_converters = self._convert_block_to_subgraph(
            node, global_arguments
        )

        assert len(subgraph_nodes) == 1
        subgraph_converter = subgraph_converters[0]
        if not self.is_top_level_graph():
            self.name_update_from_subblock_to_parent = (
                self.name_update_from_subblock_to_parent.union(
                    subgraph_converter.name_update_from_subblock_to_parent
                )
            )

        fx_block_args = [
            self.get_fx_value_by_fqn(name)
            for name in loop_local_arguments + global_arguments
        ]
        for iter_idx in range(num_iterations):
            loop_node = self.fx_graph.call_function(
                execute_subgraph_from_prim_loop,
                # See execute_subgraph_from_prim_loop for the expected argument order.
                (
                    subgraph_nodes[0],
                    iter_idx,
                    len(loop_local_arguments),
                    *fx_block_args,
                ),
                {},
            )

            # Update the value of loop local variables.
            if node.outputsSize() >= 1:
                for i, outp in enumerate(node.outputs()):
                    output_name = outp.debugName()
                    self.name_to_node[output_name] = self.fx_graph.call_function(
                        operator.getitem,
                        (
                            loop_node,
                            i + 1,
                        ),  # + 1 because the 0th element is the condition.
                    )
                    fx_block_args[i] = self.name_to_node[output_name]

            # Update the value of global variables, whose values are modified inplace.
            for i, name in enumerate(
                subgraph_converter.name_update_from_subblock_to_parent
            ):
                self.name_to_node[name] = self.fx_graph.call_function(
                    operator.getitem,
                    (
                        loop_node,
                        i + node.outputsSize() + 1,
                    ),  # + 1 because the 0th element is the condition.
                )
                global_argument_index = global_arguments.index(name)
                fx_block_args[
                    i + node.outputsSize() + global_argument_index
                ] = self.name_to_node[name]

    def _check_set_attr_in_if_block(self, if_node: torch._C.Node):
        for block in if_node.blocks():
            for node in block.nodes():
                if node.kind() == "prim::SetAttr":
                    raise RuntimeError(
                        "When converting prim::If to torch.cond, found prim::SetAttr op"
                        " which is not supported yet. Please file an issue if you come "
                        "across this error."
                    )
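
    # Rough sketch of the prim::If mapping performed below (names are
    # illustrative): a TorchScript node
    #     %y = prim::If(%pred) with two sub-blocks
    # becomes a call_function node
    #     torch.cond(pred, true_subgraph, false_subgraph, (arg0, arg1, ...))
    # where both sub-blocks are converted to subgraphs via
    # _convert_block_to_subgraph and their free and lifted inputs are passed
    # through the operands tuple.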

    def convert_prim_If(self, node: torch._C.Node):
        self._check_set_attr_in_if_block(node)

        inputs = list(node.inputs())
        assert len(inputs) == 1
        predicate = self.get_fx_value_by_ir_value(inputs[0])

        # Find inputs.
        arguments = self._identify_inputs_as_arguments(node)

        # Lift parameters as inputs.
        for block in node.blocks():
            arguments = arguments.union(self.blocks_to_lifted_attrs[block])

        arguments = list(arguments)
        subgraph_nodes, _ = self._convert_block_to_subgraph(node, arguments)

        assert len(subgraph_nodes) == 2

        fx_block_args = [self.get_fx_value_by_fqn(name) for name in arguments]

        args = (
            predicate,
            subgraph_nodes[0],
            subgraph_nodes[1],
            tuple(fx_block_args),
        )

        cond_node = self.fx_graph.call_function(torch.cond, args, {})

        # prim::If can also have zero outputs.
        if node.outputsSize() == 1:
            output_name = node.output().debugName()
            self.name_to_node[output_name] = cond_node
        elif node.outputsSize() > 1:
            for i, output in enumerate(node.outputs()):
                output_name = output.debugName()
                getitem = self.fx_graph.call_function(operator.getitem, (cond_node, i))
                self.name_to_node[output_name] = getitem

    def convert_aten_Bool(self, node: torch._C.Node):
        self._convert_as_noop(node)

    def convert_prim_Enter(self, node: torch._C.Node):
        # Export generally treats prim::Enter as a no-op.
        # The only context manager export supports is aten::enable_grad.
        # Unfortunately, TorchScript does not support aten::enable_grad yet.
        # TODO: support aten::enable_grad in both TorchScript and Converter.
        return

    def convert_prim_Exit(self, node: torch._C.Node):
        # Export treats prim::Exit as a no-op.
        return

    def _convert_as_noop(self, node: torch._C.Node):
        # Converts the node as a no-op by mapping its output to arg[0].

        target = get_op_overload(node)
        schema = target._schema

        args, kwargs = self.get_args_kwargs(node, schema)

        output_name = node.output().debugName()
        self.name_to_node[output_name] = args[0]

    def convert_profiler__record_function_exit(self, node: torch._C.Node):
        # _record_function_exit has a side effect, so we keep it in the fx.graph.
        # Currently, _record_function_enter_new and _record_function_exit are
        # discarded during `retrace_as_exported_program`.
        target = torch.ops.profiler._record_function_exit
        args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs())
        self.fx_graph.call_function(target, args)

    def convert_prim_tolist(self, node: torch._C.Node):
        # prim::tolist cannot be supported by `_convert_standard_operators`
        # since it requires call_method instead of call_function.
        target = "tolist"
        args = (self.get_fx_value_by_ir_value(next(node.inputs())),)
        fx_node = self.fx_graph.call_method(target, args)
        output_name = node.output().debugName()
        self.name_to_node[output_name] = fx_node

    def convert_prim_Uninitialized(self, node: torch._C.Node):
        # `prim::Uninitialized` is inserted by the compiler when it can prove
        # the value will never be used. It can be introduced by exceptions,
        # breaks, continues, and returns.
        # So we add a dummy constant to the graph.
        output_name = node.output().debugName()
        self.name_to_constant[output_name] = torch.Tensor()

    def _convert_standard_operators(self, node: torch._C.Node):
        target = kind_to_standard_operators[node.kind()]
        args = tuple(self.get_fx_value_by_ir_value(input) for input in node.inputs())
        fx_node = self.fx_graph.call_function(target, args)
        output_name = node.output().debugName()
        self.name_to_node[output_name] = fx_node

    def convert_node(self, node: torch._C.Node):
        node_kind = node.kind()

        # Get the handler based on namespace and operator name.
        # Provide a default node handler as well in case we don't find
        # a matching converter for the node.
        handler_func_name = ir_name_to_func_name(node_kind)
        handler_func = getattr(self, handler_func_name, self.convert_call_function_op)

        # str(node) calls a print function implemented in C++. To avoid repeating
        # the entire logic here, we simply keep the first line of the node string
        # (getting rid of sub-block IR prints).
        node_str = "".join(str(node).split("\n")[:1])
        log.debug("[%s] converts [%s]", handler_func.__name__, node_str)
        try:
            handler_func(node)
        except Exception as e:
            raise RuntimeError(f"TS2EPConverter failed for node {node_kind}") from e

    def convert_graph_outputs(self):
        args = []
        outp_name_list = [outp.debugName() for outp in self.ts_graph.outputs()] + list(
            self.name_update_from_subblock_to_parent
        )
        for output_name in outp_name_list:
            if output_name in self.name_to_node:
                fx_node = self.name_to_node[output_name]
                # TODO: Revisit this later after HigherOrderOp design changes.
                # Currently, we cannot directly return an input as an output.
                if (
                    not self.is_top_level_graph()
                    and isinstance(fx_node, torch.fx.Node)
                    and fx_node.op == "placeholder"
                ):
                    fx_node = self.fx_graph.call_function(torch.clone, (fx_node,))
                args.append(fx_node)
                self.output_specs.append(
                    OutputSpec(
                        OutputKind.USER_OUTPUT,
                        arg=TensorArgument(name=output_name),
                        target=output_name,
                    )
                )
            elif output_name in self.name_to_constant:
                args.append(self.name_to_constant[output_name])
                self.output_specs.append(
                    OutputSpec(
                        OutputKind.USER_OUTPUT,
                        arg=ConstantArgument(
                            name=output_name, value=self.name_to_constant[output_name]
                        ),
                        target=output_name,
                    )
                )
            else:
                raise ValueError(f"Output {output_name} not found")

        if len(args) == 0:
            # A sub-block of prim::If can have zero outputs.
            self.fx_graph.output([])
        elif len(args) == 1:
            self.fx_graph.output(
                args[0]
            )  # Get rid of the extra list wrapped around the final output.
        else:
            self.fx_graph.output(
                args
            )  # For prim::Loop and prim::If sub-blocks with multiple outputs.


class ExplainTS2FXGraphConverter(TS2FXGraphConverter):
    """
    Run TS2FXGraphConverter in an explain mode. It collects all failed operator conversions
    and provides that information to users. In order to collect all failed conversions, it
    also mocks some internal attributes (e.g., name_to_node).
    """
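
    # Note on the mock below: _DictMock makes every name lookup succeed by
    # returning a dummy fx.Node for unknown keys, so conversion can keep going
    # past a failed node and collect every unsupported op instead of stopping at
    # the first failure.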

    class _DictMock(dict):
        def __init__(self, dict_data, mock_value):
            super().__init__(dict_data)
            self.mock_value = mock_value

        def __getitem__(self, key):
            # If the original dictionary has the key, return its value.
            # Otherwise, return the mock value.
            if not super().__contains__(key):
                return self.mock_value
            return super().__getitem__(key)

        def __contains__(self, key):
            return True

    def __init__(
        self,
        ts_graph: Union[torch._C.Graph, torch._C.Block],
        name_to_param: Dict[str, torch.Tensor],
        name_to_buffer: Dict[str, torch.Tensor],
        blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]],
        name_to_non_tensor_attribute: Dict[str, Any],
        name_to_constant: Dict[str, Any],
    ):
        super().__init__(
            ts_graph,
            name_to_param,
            name_to_buffer,
            blocks_to_lifted_attrs,
            name_to_non_tensor_attribute,
            name_to_constant,
        )

        # Data to keep track of unsupported nodes.
        self.unsupported_node_list: List[torch._C.Node] = []

        # Add mock to needed attributes.
        self.name_to_node = ExplainTS2FXGraphConverter._DictMock(
            self.name_to_node,
            # Dummy node.
            torch.fx.Node(
                None,  # type: ignore[arg-type]
                "mock",
                "call_function",
                lambda: None,
                (),
                {},
            ),
        )

    def explain(self):
        self.convert_graph_inputs()
        for node in self.ts_graph.nodes():
            self.convert_node(node)
        self.convert_graph_outputs()

    def convert_node(self, node):
        try:
            super().convert_node(node)
        except Exception as e:
            self.unsupported_node_list.append(node)


@contextmanager
def disable_logging(log):
    disabled = log.disabled
    log.disabled = True
    try:
        yield
    finally:
        log.disabled = disabled


class TS2EPConverter:
    # TorchScript model to ExportedProgram converter
    def __init__(
        self,
        ts_model: Union[torch.jit.ScriptModule, torch.jit.ScriptFunction],
        sample_args: Tuple[Any, ...],
        sample_kwargs: Optional[Dict[str, Any]] = None,
    ):
        self.ts_model = ts_model
        self.ts_graph, self.params, _, _ = _create_jit_graph(ts_model, sample_args)

        self.sample_args = sample_args
        self.sample_kwargs = sample_kwargs

        self.name_to_param: Dict[str, torch.Tensor] = {}
        self.name_to_buffer: Dict[str, torch.Tensor] = {}
        param_list = (
            list(self.ts_model.parameters())
            if not isinstance(self.ts_model, torch._C.ScriptFunction)
            else []
        )
        if not isinstance(self.ts_model, torch._C.ScriptFunction):
            for k, tensor in self.ts_model.state_dict().items():  # type: ignore[union-attr]
                # Check if tensor belongs to any parameter.
                if any(
                    (tensor == param).all()
                    for param in param_list
                    if tensor.shape == param.shape
                ):
                    self.name_to_param[k] = tensor
                else:
                    self.name_to_buffer[k] = tensor

        self.name_to_non_tensor_attributes: Dict[str, Any] = {}
        self.name_to_constant: Dict[str, Any] = {}

        self.lift_get_attr()

    def convert(self) -> ExportedProgram:
        log.info(
            """
TS2EPConverter logging starts from here.

INFO: (TORCH_LOGS="export" <cmd>)
    * Log TorchScript IR.

DEBUG: (TORCH_LOGS="+export" <cmd>), additionally
    * Log conversion IR by IR in a format of [<conversion handler name>] converts [<IR>].
    """
        )
        log.info("TorchScript graph\n\n%s\n", self.ts_graph)

        blocks_to_lifted_attrs = get_block_to_lifted_attrs(self.ts_graph)

        graph_converter = TS2FXGraphConverter(
            self.ts_graph,
            self.name_to_param,
            self.name_to_buffer,
            blocks_to_lifted_attrs,
            self.name_to_non_tensor_attributes,
            self.name_to_constant,
        )
        gm = graph_converter.convert()

        # Post-processing step to deal with quantized operators.
        replace_quantized_ops_with_standard_ops(gm)
        log.info("GraphModule: %s", gm.print_readable(print_output=False))

        ep = self.retrace_as_exported_program(
            gm,
            graph_converter.name_to_constant,
        )
        log.info("%s", ep)

        # Post-processing step to ensure the ExportedProgram has the same state_dict as
        # the original TorchScript model. Throw warnings for additionally populated
        # state_dict entries.
        if not isinstance(self.ts_model, torch._C.ScriptFunction):
            for k, tensor in self.ts_model.state_dict().items():  # type: ignore[union-attr]
                if k not in ep.state_dict:
                    warnings.warn(
                        f"Manually populating {k} into the ExportedProgram state_dict, but it is never used by the ExportedProgram."
                    )
                    ep.state_dict[k] = tensor

        return ep

    @disable_logging(log)
    def explain(self, print_output=True):
        blocks_to_lifted_attrs = get_block_to_lifted_attrs(self.ts_graph)

        graph_converter = ExplainTS2FXGraphConverter(
            self.ts_graph,
            self.name_to_param,
            self.name_to_buffer,
            blocks_to_lifted_attrs,
            self.name_to_non_tensor_attributes,
            self.name_to_constant,
        )
        graph_converter.explain()
        if len(graph_converter.unsupported_node_list) > 0:
            explain_str = "Unsupported nodes are found in the following list:"
            for i, n in enumerate(graph_converter.unsupported_node_list):
                node_str = "".join(str(n).split("\n")[:1])
                explain_str += f"\n\n    {i}. {n.kind()} [{node_str}]"
        else:
            explain_str = "Success!"
        if print_output:
            print(explain_str)
        return explain_str

    def retrace_as_exported_program(
        self,
        gm: torch.fx.GraphModule,
        name_to_constant: Dict[str, Any],
    ):
        # TODO: adjust input orders to match GraphSignature convention
        ep = torch.export._trace._export(
            gm,
            self.sample_args,
            strict=False,
            pre_dispatch=True,
        )

        # Post-processing to make sure the ExportedProgram states are correct.
        # Because during conversion, we set tensor constants as GetAttr,
        # retracing cannot recognize them as tensor constants but instead
        # treats them as buffers. We need to set them again here.
        ep._constants.update(
            {
                k: v
                for k, v in name_to_constant.items()
                if isinstance(v, (torch.Tensor, torch.ScriptObject))
            }
        )
        for k in name_to_constant:
            ep.state_dict.pop(k, None)

        for i, spec in enumerate(ep.graph_signature.input_specs):
            # Mark erroneously traced buffers as constant tensors.
            if spec.kind == InputKind.BUFFER and spec.target in name_to_constant:
                assert isinstance(
                    name_to_constant[spec.target], torch.Tensor
                ), f"{type(name_to_constant[spec.target])} has been erroneously marked as buffer"
                spec.kind = InputKind.CONSTANT_TENSOR
        ep.verifier().check(ep)

        return ep

    def lift_get_attr(self):
        # This function lifts multiple data types:

        # 1. Tensor constant attributes (e.g., self.data = torch.tensor([2,3]))
        #    to buffers. Currently, when there are tensor constants, export
        #    would error out and ask users to register tensor constants as buffers.
        #    Since it is hard to manually do so for TorchScript models
        #    (e.g., the source code may be missing), this function automatically
        #    lifts tensor constants to be buffers.
        # 2. ScriptObject attributes to constants. They will then be converted to
        #    get_attr nodes in the fx graph.
        #
        # This function should happen in TS2EPConverter instead of
        # TS2FXGraphConverter since it gets attributes from self.ts_model,
        # which is not accessible in TS2FXGraphConverter. It is similar to where
        # we collect self.name_to_param and self.name_to_buffer.
        name_to_attribute_fqn: Dict[str, str] = {}

        def get_attr(fqn: str):
            name = fqn.split(".")
            v = self.ts_model
            for n in name:
                v = getattr(v, n)
            return v

        def get_fqn(node: torch._C.Node):
            attr_name = node.s("name")
            input_name = node.input().debugName()
            root_attr_name = name_to_attribute_fqn[input_name]
            attr_fqn = f"{root_attr_name}.{attr_name}" if root_attr_name else attr_name
            return attr_fqn

        def _dfs_get_attr(block):
            for node in block.nodes():
                if node.kind() == "prim::CreateObject":
                    output_name = node.output().debugName()
                    name_to_attribute_fqn[output_name] = ""

                if node.kind() == "prim::GetAttr":
                    attr_fqn = get_fqn(node)
                    value = get_attr(attr_fqn)
                    output_name = node.output().debugName()
                    name_to_attribute_fqn[output_name] = attr_fqn
                    if isinstance(value, torch.Tensor):
                        if attr_fqn not in self.name_to_buffer:
                            # Lift tensor constants to be a buffer.
                            self.name_to_buffer[attr_fqn] = value
                    elif isinstance(value, torch.ScriptObject):
                        if attr_fqn not in self.name_to_constant:
                            self.name_to_constant[attr_fqn] = value
                    else:
                        self.name_to_non_tensor_attributes[attr_fqn] = value

                for subblock in node.blocks():
                    _dfs_get_attr(subblock)

        _dfs_get_attr(self.ts_graph)
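

# Illustrative usage of this module (`MyModule` and the sample input are
# hypothetical placeholders):
#
#     scripted = torch.jit.script(MyModule())
#     TS2EPConverter(scripted, sample_args=(torch.randn(2, 3),)).explain()
#     ep = TS2EPConverter(scripted, sample_args=(torch.randn(2, 3),)).convert()
#     print(ep.graph_module.graph)
#
# explain() lists TorchScript nodes the converter cannot handle yet, while
# convert() returns a torch.export.ExportedProgram.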