# mypy: allow-untyped-defs
import abc
import copy
import operator
from collections import defaultdict
from contextlib import contextmanager
from copy import deepcopy
from enum import Enum
from typing import Any, cast, Dict, List, Optional, Set, Tuple, Union

import torch
import torch.fx._pytree as fx_pytree
import torch.utils._pytree as pytree
from torch._library.fake_class_registry import FakeScriptObject
from torch.export._tree_utils import reorder_kwargs
from torch.export.exported_program import (
    ConstantArgument,
    ExportedProgram,
    InputKind,
    ModuleCallSignature,
    SymIntArgument,
    TensorArgument,
)
from torch.fx._symbolic_trace import is_fx_tracing
from torch.fx.graph_module import _print_readable
from torch.utils._pytree import GetAttrKey, SequenceKey

from ._remove_effect_tokens_pass import _remove_effect_tokens


__all__ = ["InterpreterModule", "UnflattenedModule", "unflatten", "FlatArgsAdapter"]


class _AttrKind(Enum):
    PARAMETER = "parameter"
    BUFFER = "buffer"
    CONSTANT = "constant"


RUN_WITH_INTERPRETER = True


@contextmanager
def _disable_interpreter():
    global RUN_WITH_INTERPRETER
    old_flag = RUN_WITH_INTERPRETER
    RUN_WITH_INTERPRETER = False
    try:
        yield
    finally:
        RUN_WITH_INTERPRETER = old_flag


# Assign attribute 'from_obj' to the qualified name 'target' on 'to_module'.
# This installs empty Modules where none exist yet if they are subpaths of 'target'.
def _assign_attr(
    from_obj: Union[torch.Tensor, torch.ScriptObject],
    to_module: torch.nn.Module,
    target: str,
    attr_kind: _AttrKind,
    persistent: bool = True,
):
    *prefix, field = target.split(".")
    for item in prefix:
        t = getattr(to_module, item, None)

        if t is None:
            t = torch.nn.Module()
            setattr(to_module, item, t)
        to_module = t

    if attr_kind == _AttrKind.PARAMETER:
        assert isinstance(from_obj, torch.nn.Parameter)
        to_module.register_parameter(field, from_obj)
    elif attr_kind == _AttrKind.BUFFER:
        assert isinstance(from_obj, torch.Tensor)
        to_module.register_buffer(field, from_obj, persistent=persistent)
    elif attr_kind == _AttrKind.CONSTANT:
        assert not isinstance(
            from_obj, FakeScriptObject
        ), "FakeScriptObject should only exist during tracing."
        assert isinstance(
            from_obj,
            (
                torch.Tensor,
                torch.ScriptObject,
            ),
        )
        setattr(to_module, field, from_obj)
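
# For illustration, on a hypothetical empty `root` module,
#
#     _assign_attr(
#         torch.nn.Parameter(torch.ones(2)), root, "sub.block.weight",
#         attr_kind=_AttrKind.PARAMETER,
#     )
#
# creates empty Modules `root.sub` and `root.sub.block` on the fly and then
# registers the parameter as `root.sub.block.weight`.
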
class InterpreterModule(torch.nn.Module):
    """A module that uses torch.fx.Interpreter to execute instead of the usual
    codegen that GraphModule uses. This provides better stack trace information
    and makes it easier to debug execution.
    """

    def __init__(
        self,
        graph: torch.fx.Graph,
    ):
        super().__init__()
        self.graph = graph
        self.graph.owning_module = self
        self._run_with_interpreter = RUN_WITH_INTERPRETER

    def forward(self, *args, **kwargs):
        assert self.graph_module is not None, "Didn't finalize this InterpreterModule"
        if not is_fx_tracing() and (
            torch.compiler.is_dynamo_compiling() or not self._run_with_interpreter
        ):
            # Dynamo cannot trace through torch.fx.Interpreter, so fall back to
            # GraphModule codegen in this instance.
            # Patch the codegened forward to run with this InterpreterModule,
            # so attribute accesses, etc. are on this module instead.
            return type(self.graph_module).forward(self, *args, **kwargs)
        else:
            if kwargs:
                # Handle **kwargs. FX only natively supports positional
                # arguments (through placeholders). So in order to pass in
                # kwargs, we must match the names of the placeholders to
                # the keys in the kwarg dict.
                arg_list = list(args)
                kwarg_names = self.arg_names[len(arg_list) :]
                for kwarg_name in kwarg_names:
                    if kwarg_name in kwargs:
                        arg_list.append(kwargs[kwarg_name])

                # Assert that the kwargs passed in exactly match the positional
                # arguments specified by the GraphModule. This should be
                # guaranteed by the unflattening process.
                assert len(kwarg_names) == len(kwargs)
                assert len(arg_list) == len(self.arg_names)
                args = tuple(arg_list)

            return torch.fx.Interpreter(self, graph=self.graph).run(
                *args, enable_io_processing=False
            )

    def finalize(self):
        # We need to "finalize" because GraphModule populates its own state_dict
        # based on the get_attrs observed in the graph. So we need to fully
        # construct the graph and call _sink_params before generating this
        # GraphModule.

        # need to set `graph_module` directly on the dict to avoid it getting
        # registered as a submodule.
        self.__dict__["graph_module"] = torch.fx.GraphModule(self, self.graph)
        self.graph.lint()

        # Cache arg names for kwarg handling (see forward())
        self.arg_names = []
        for node in self.graph.nodes:
            if node.op == "placeholder":
                self.arg_names.append(node.target)

    def print_readable(
        self,
        print_output=True,
        include_stride=False,
        include_device=False,
        colored=False,
    ):
        return _print_readable(
            self,
            "InterpreterModule",
            print_output,
            include_stride,
            include_device,
            colored,
        )


class FlatArgsAdapter(abc.ABC):
    """
    Adapts input arguments with ``input_spec`` to align with ``target_spec``.
    """

    @abc.abstractmethod
    def adapt(
        self,
        target_spec: pytree.TreeSpec,
        input_spec: pytree.TreeSpec,
        input_args: List[Any],
    ) -> List[Any]:
        """NOTE: This adapter may mutate the given ``input_args``."""
        ...
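
# A minimal sketch of a FlatArgsAdapter implementation (the truncation policy
# below is made up; a real adapter should map the ``input_spec`` leaves onto
# the ``target_spec`` leaves in whatever way suits its calling convention):
#
#     class TruncatingAdapter(FlatArgsAdapter):
#         def adapt(self, target_spec, input_spec, input_args):
#             # keep only as many leaves as the exported module expects
#             return input_args[: target_spec.num_leaves]
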
class UnflattenedModule(torch.nn.Module):
    def __init__(
        self,
        export_module: ExportedProgram,
        flat_args_adapter: Optional[FlatArgsAdapter] = None,
    ):
        super().__init__()
        if export_module.graph_signature.backward_signature is not None:
            raise ValueError("Unflattening on JointExportModule NYI")

        fqn_list = [entry.fqn for entry in export_module.module_call_graph]
        assert fqn_list[0] == ""
        export_graph = deepcopy(export_module.graph)
        self.graph_signature = deepcopy(export_module.graph_signature)
        self.graph = torch.fx.Graph()
        self.module_call_graph = deepcopy(export_module.module_call_graph)
        self.flat_args_adapter = flat_args_adapter
        # Flag to indicate whether args have been adapted.
        self.adapted = False
        self._run_with_interpreter = RUN_WITH_INTERPRETER

        _inplace_buffer_mutations(export_graph, self.graph_signature)
        _outline_submodules(export_graph, self)

        self.range_constraints = export_module.range_constraints
        self.equality_constraints: List = []

        # aliasing/unused param or buffer issues:
        # in strict-mode export, dynamo export will deduplicate aliased tensors,
        # and ignore unused tensors. For aliasing, this causes issues when some aliases
        # are unused, and we're unable to match the placeholder node to the correct FQN.
        # This leads to the graph signature potentially having the wrong target FQN,
        # and downstream issues where parameters are assigned to the wrong target attribute,
        # mismatching the relevant placeholder node in the unflattened module.
        # To resolve this we restore (_assign_attr) all aliased/unused tensors in
        # the state_dict as module attributes, but only keep the used tensors in the
        # graph's forward pass (_sink_params).
        state_dict = export_module.state_dict
        assigned_params: Set[str] = set()  # tracking unused params
        id_to_param: Dict[int, torch.nn.Parameter] = {}  # handling weight-sharing
        for name in self.graph_signature.parameters:  # this loop adds used params
            param = state_dict[name]
            if id(param) not in id_to_param:
                id_to_param[id(param)] = torch.nn.Parameter(
                    param.clone(), requires_grad=param.requires_grad
                )

            _assign_attr(
                id_to_param[id(param)],
                self,
                name,
                attr_kind=_AttrKind.PARAMETER,
            )
            assigned_params.add(name)

        non_persistent_buffers = set(self.graph_signature.non_persistent_buffers)
        assigned_buffers: Set[str] = set()  # tracking unused buffers
        id_to_buffer: Dict[
            int, Tuple[torch.nn.Parameter, bool]
        ] = {}  # handle weight-sharing
        for name in self.graph_signature.buffers:  # this loop adds used buffers
            if name in non_persistent_buffers:
                persistent = False
                buffer = export_module.constants[name]
            else:
                persistent = True
                buffer = state_dict[name]

            if id(buffer) not in id_to_buffer:
                id_to_buffer[id(buffer)] = (buffer.clone(), persistent)

            _assign_attr(
                id_to_buffer[id(buffer)][0],
                self,
                name,
                attr_kind=_AttrKind.BUFFER,
                persistent=persistent,
            )
            assigned_buffers.add(name)

        # restore aliased/unused params and buffers
        # these appear in state dict but not graph signature
        for name, tensor in state_dict.items():
            if name in assigned_params or name in assigned_buffers:  # already assigned
                continue

            is_buffer = False
            if id(tensor) in id_to_buffer or not isinstance(
                tensor, torch.nn.Parameter
            ):  # aliased buffer
                is_buffer = True

            if is_buffer:
                if (
                    id(tensor) not in id_to_buffer
                ):  # this is completely unused (not weight-sharing)
                    id_to_buffer[id(tensor)] = (
                        tensor,
                        True,
                    )  # assign to respect original model
                _assign_attr(
                    id_to_buffer[id(tensor)][0],
                    self,
                    name,
                    attr_kind=_AttrKind.BUFFER,
                    persistent=True,
                )
            else:
                if id(tensor) not in id_to_param:  # this is unused
                    id_to_param[id(tensor)] = tensor
                _assign_attr(
                    id_to_param[id(tensor)],
                    self,
                    name,
                    attr_kind=_AttrKind.PARAMETER,
                )

        # use id map so we don't double-clone aliased constants
        id_to_const: Dict[int, Union[torch.Tensor, torch._C.ScriptObject]] = {}
        for fqn, constant in export_module.constants.items():
            if id(constant) not in id_to_const:
                if isinstance(constant, torch.Tensor):
                    constant = constant.clone()
                id_to_const[id(constant)] = constant
            _constant = id_to_const[id(constant)]
            _assign_attr(
                _constant,
                self,
                fqn,
                attr_kind=_AttrKind.CONSTANT,
            )

        # This is to handle parameters/buffers that point to the same tensor
        # object id -> list of (node_name, target_name)
        consts_map: Dict[int, List[Tuple[str, str]]] = defaultdict(list)
        consts_targets: Set[str] = set()

        def add_to_consts_map(obj_id, node_name, target_name):
            name_list = consts_map[obj_id]
            name_list.append((node_name, target_name))

        added_params_buffers: Set[str] = set()  # track aliased/unused params, buffers
        for s in self.graph_signature.input_specs:
            if s.kind == InputKind.PARAMETER or (
                s.kind == InputKind.BUFFER and s.persistent
            ):
                assert hasattr(s.arg, "name")
                assert isinstance(s.target, str)
                add_to_consts_map(
                    id(export_module.state_dict[s.target]), s.arg.name, s.target
                )
                consts_targets.add(s.target)
                added_params_buffers.add(s.target)
            elif (
                (s.kind == InputKind.BUFFER and not s.persistent)
                or s.kind == InputKind.CONSTANT_TENSOR
                or s.kind == InputKind.CUSTOM_OBJ
            ):
                assert hasattr(s.arg, "name")
                assert isinstance(s.target, str)
                add_to_consts_map(
                    id(export_module.constants[s.target]), s.arg.name, s.target
                )
                consts_targets.add(s.target)

        # add constants that are aliased and don't appear in graph signature
        for const_name, const in export_module.constants.items():
            if const_name not in consts_targets:
                assert (
                    id(const) in consts_map
                ), "Constants should be either aliased or appear in graph signature"
                ph_name, _ = consts_map[id(const)][0]
                add_to_consts_map(id(const), ph_name, const_name)
                added_params_buffers.add(s.target)

        # add aliased/unused params and buffers that don't appear in graph signature
        for fqn, tensor in export_module.state_dict.items():
            if fqn not in added_params_buffers:
                if id(tensor) not in consts_map:
                    # completely unused (no weight-sharing), ignore.
                    # this weight doesn't appear in graph module,
                    # so won't cause FQN assignment issues
                    continue
                ph_name, _ = consts_map[id(tensor)][0]
                add_to_consts_map(id(tensor), ph_name, fqn)

        # node name -> list of possible targets
        inputs_to_state: Dict[str, List[str]] = {}
        for node_target in consts_map.values():
            targets = [t[1] for t in node_target]
            for n, _ in node_target:
                inputs_to_state[n] = targets

        _sink_params(self, inputs_to_state, [])

        # Helper function to check that the input nodes of `module` have been processed.
        def check_module_inputs(module, scope):
            if hasattr(module, "graph"):
                for node in module.graph.nodes:
                    # sink_params() should turn placeholders into get_attr nodes
                    # for attributes that are within scope of the current
                    # module. We allow attributes to remain as placeholders if
                    # they are inputs in the original module signature, meaning
                    # they are a parent module's attribute, and therefore out of
                    # scope of the current module.
                    if (
                        node.op == "placeholder"
                        and node.name in inputs_to_state
                        and any(
                            fqn.split(".")[: len(scope)] == scope
                            for fqn in inputs_to_state[node.name]
                        )  # matching scope to avoid wrong assert
                    ):
                        raise AssertionError(
                            f"{node.name} was not sunk into the module {scope} which has the graph: {module.graph}"
                        )
            # Recursively check the submodules.
            for name, submod in module.named_children():
                scope.append(name)
                check_module_inputs(submod, scope)

        # Recursively check that all input nodes have been processed.
        check_module_inputs(self, [])

        # Cache so we don't have to compute this every time.
        # NOTE: this needs to be kept in sync with the placeholders in
        # self.graph, but currently we have no way to guarantee that.
        self.input_placeholders = [
            node for node in self.graph.nodes if node.op == "placeholder"
        ]
        self.check_input_constraints = True
        # TODO(zhxchen17) We can register modules ahead of time instead of reordering later.
        fqn_order = {fqn: i for i, fqn in enumerate(fqn_list)}
        # In the case of legacy IR, we might be missing some modules from metadata.
        for name, _ in self.named_modules(remove_duplicate=False):
            if name not in fqn_order:
                fqn_order[name] = len(fqn_order)
        _reorder_submodules(self, fqn_order)
        assert [fqn for fqn, _ in self.named_modules(remove_duplicate=False)] == list(
            fqn_order.keys()
        )
        self.graph.lint()
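
    # Illustration of the aliasing bookkeeping above (all names are
    # hypothetical): if `self.alias` and `self.linear.weight` share one tensor,
    # consts_map groups both targets under that tensor's id, e.g.
    #
    #     consts_map[id(t)] == [("p_linear_weight", "linear.weight"),
    #                           ("p_linear_weight", "alias")]
    #
    # so inputs_to_state becomes {"p_linear_weight": ["linear.weight", "alias"]},
    # and _sink_params later picks whichever target FQN matches the scope of the
    # submodule it is rewriting.
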
    def _print_graph(self):
        for fqn, mod in self.named_modules():
            print(fqn + ":")
            if hasattr(mod, "graph") and isinstance(mod.graph, torch.fx.Graph):
                print(mod.graph)

    def forward(self, *args, **kwargs):
        signature = self.module_call_graph[0].signature

        reordered_kwargs = reorder_kwargs(kwargs, signature.in_spec)

        flat_args_with_path, in_spec = pytree.tree_flatten_with_path(
            (args, reordered_kwargs)
        )
        flat_args = [x[1] for x in flat_args_with_path]
        if is_fx_tracing():
            return_val = torch.fx.Interpreter(self, graph=self.graph).run(
                *flat_args, enable_io_processing=False
            )
            # For a scalar return value, fx.Graph wraps it in a tuple
            if isinstance(return_val, tuple) and len(return_val) == 1:
                return return_val[0]
            return return_val

        if in_spec != signature.in_spec:
            if not self.adapted:
                print(
                    "Input treespec does not match the exported module's: \n"
                    f"Input treespec: {in_spec}. ",
                    f"Exported module treespec: {signature.in_spec}",
                )
            if self.flat_args_adapter is None:
                raise TypeError(
                    "There is no flat args adapter specified. "
                    "Are you sure you are calling this with the right arguments? "
                )
            else:
                if not self.adapted:
                    print("Adapting flat args to match the exported module's treespec")
                flat_args = self.flat_args_adapter.adapt(
                    target_spec=signature.in_spec,
                    input_spec=in_spec,
                    input_args=flat_args,
                )
                self.adapted = True
                if len(flat_args) != signature.in_spec.num_leaves:
                    raise TypeError(
                        f"Flat args adaptation failed, number of args mismatch. "
                        f"Adapted: {len(flat_args)} \n"
                        f"Exported module: {signature.in_spec.num_leaves}"
                    )

        if self.check_input_constraints:
            # Import here to avoid an unfortunate circular dependency.
            # TODO(suo): untangle this.
            from torch._export.utils import _check_input_constraints_for_graph

            if self.adapted is True:
                # TODO(suo): The FlatArgsAdapter returns a list of flat args,
                # which we don't have keypaths for. For now, just create a dummy
                # keypath to associate with the arg.
                new_flat_args_with_path = [  # type: ignore[var-annotated]
                    ((SequenceKey(idx=0), GetAttrKey(name="<unknown location>")), arg)
                    for arg in flat_args
                ]
            else:
                new_flat_args_with_path = flat_args_with_path  # type: ignore[assignment]

            _check_input_constraints_for_graph(
                self.input_placeholders, new_flat_args_with_path, self.range_constraints
            )
        if torch.compiler.is_dynamo_compiling() and not self._run_with_interpreter:
            tree_out = torch.fx.GraphModule(self, self.graph)(*flat_args)
        else:
            tree_out = torch.fx.Interpreter(self, graph=self.graph).run(
                *flat_args, enable_io_processing=False
            )
        return pytree.tree_unflatten(tree_out, signature.out_spec)

    def print_readable(
        self,
        print_output=True,
        include_stride=False,
        include_device=False,
        colored=False,
    ):
        return _print_readable(
            self,
            "UnflattenedModule",
            print_output,
            include_stride,
            include_device,
            colored,
        )


def unflatten(
    module: ExportedProgram, flat_args_adapter: Optional[FlatArgsAdapter] = None
) -> UnflattenedModule:
    """Unflatten an ExportedProgram, producing a module with the same module
    hierarchy as the original eager module. This can be useful if you are trying
    to use :mod:`torch.export` with another system that expects a module
    hierarchy instead of the flat graph that :mod:`torch.export` usually produces.

    .. note:: The args/kwargs of unflattened modules will not necessarily match
        the eager module, so doing a module swap (e.g. :code:`self.submod =
        new_mod`) will not necessarily work. If you need to swap a module out, you
        need to set the :code:`preserve_module_call_signature` parameter of
        :func:`torch.export.export`.

    Args:
        module (ExportedProgram): The ExportedProgram to unflatten.
        flat_args_adapter (Optional[FlatArgsAdapter]): Adapt flat args if the input TreeSpec does not match the exported module's.

    Returns:
        An instance of :class:`UnflattenedModule`, which has the same module
        hierarchy as the original eager module pre-export.
    """
    module = _remove_effect_tokens(module)
    return UnflattenedModule(module, flat_args_adapter)
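
# A minimal usage sketch (the eager module `M` and its `sub` submodule are
# hypothetical):
#
#     ep = torch.export.export(M(), (torch.randn(2, 3),))
#     unflat = unflatten(ep)
#     unflat.sub                       # submodule hierarchy mirrors eager `M`
#     out = unflat(torch.randn(2, 3))
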
def _inplace_buffer_mutations(graph: torch.fx.Graph, graph_signature) -> None:
    """Transform buffer mutations from their functionalized form into a copy_
    node in the graph.

    Functionalization represents buffer mutation by passing the buffer as an
    input and output. So for example, the eager code:
        def forward(self, x):
            self.buffer += x
            return x * x

    will become a graph that looks like:
        def forward(self, buffer, x):
            mutated_buffer = aten.add(buffer, x)
            mul = aten.mul(x, x)
            return (mutated_buffer, mul)

    We want to rewrite this into something that looks like the original eager code:
        def forward(self, buffer, x):
            mutated_buffer = aten.add(buffer, x)
            buffer.copy_(mutated_buffer)
            mul = aten.mul(x, x)
            return (mul,)
    """
    output_node = next(iter(reversed(graph.nodes)))
    assert output_node.op == "output" and len(output_node.args) == 1
    return_args = output_node.args[0]

    mutation_node_to_buffer = graph_signature.buffers_to_mutate
    mutations = return_args[: len(mutation_node_to_buffer)]
    buffers_to_inputs = {v: k for k, v in graph_signature.inputs_to_buffers.items()}
    input_name_to_node = {
        node.name: node for node in graph.nodes if node.op == "placeholder"
    }

    for mutation in mutations:
        buffer_name = mutation_node_to_buffer[mutation.name]
        input_name = buffers_to_inputs[buffer_name]
        input_node = input_name_to_node[input_name]

        with graph.inserting_after(mutation):
            new_node = graph.create_node(
                "call_function", torch.ops.aten.copy_, (input_node, mutation)
            )
            for k, v in mutation.meta.items():
                new_node.meta[k] = v
        # Replace all uses of the previously functional mutation with our copy_ output.
        mutation.replace_all_uses_with(new_node, lambda x: x is not new_node)

    # Remove the mutated buffer from the graph outputs, since we don't need to
    # thread it through anymore. We don't need to handle the inputs, which will
    # be handled by _sink_params.
    user_outputs = tuple(
        return_args[len(mutation_node_to_buffer) :],
    )
    output_node.args = (user_outputs,)


def _is_prefix(candidate, target):
    """Check whether `candidate` is a prefix of `target`."""
    return len(candidate) < len(target) and target[: len(candidate)] == candidate


def _compute_accessor(parent_fqn: str, child_fqn: str) -> str:
    if parent_fqn == "":
        # Handle the root module correctly.
        return child_fqn

    parent_split = parent_fqn.split(".")
    child_split = child_fqn.split(".")

    # TODO: support skip connection by inlining the child module.
    if child_split[: len(parent_split)] != parent_split:
        raise RuntimeError(
            f"Child module '{child_fqn}' is not a descendant of parent module '{parent_fqn}'. "
            "This is currently unsupported. "
            "Please try to make the child module attach to the parent module directly."
        )
    return ".".join(child_split[len(parent_split) :])
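
# For reference, what `_compute_accessor` returns for a few sample inputs:
#
#     _compute_accessor("", "foo.bar")        -> "foo.bar"   (root parent)
#     _compute_accessor("foo", "foo.bar.baz") -> "bar.baz"
#     _compute_accessor("foo", "qux.baz")     -> RuntimeError (not a descendant)
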
def _verify_graph_equivalence(x: torch.nn.Module, y: torch.nn.Module):
    def graph_dump(graph: torch.fx.Graph) -> str:
        ret = []
        nodes_idx: Dict[int, int] = {}

        def arg_dump(arg) -> str:
            if isinstance(arg, torch.fx.Node):
                return "%" + str(nodes_idx[id(arg)])
            return str(arg)

        for i, node in enumerate(graph.nodes):
            args_dump = [str(arg) for arg in pytree.tree_map(arg_dump, node.args)]
            args_dump += [
                f"{key}={value}"
                for key, value in pytree.tree_map(arg_dump, node.kwargs).items()
            ]
            target = node.target if node.op == "call_function" else ""
            ret.append(f"{i}: {node.op}[{target}]({', '.join(args_dump)})")
            nodes_idx[id(node)] = i
        return "\n".join(ret)

    assert graph_dump(x.graph) == graph_dump(y.graph)


def _add_spec(gm: torch.nn.Module, spec) -> str:
    i = 0
    while hasattr(gm, f"_spec_{i}"):
        i += 1
    name = f"_spec_{i}"
    setattr(gm, name, spec)
    return name


def _generate_flatten(gm: torch.nn.Module, node, spec) -> torch.fx.Node:
    name = _add_spec(gm, spec)
    spec_node = gm.graph.get_attr(name)
    return gm.graph.call_function(fx_pytree.tree_flatten_spec, (node, spec_node))


def _generate_unflatten(gm: torch.nn.Module, nodes, spec) -> torch.fx.Node:
    name = _add_spec(gm, spec)
    spec_node = gm.graph.get_attr(name)
    return gm.graph.call_function(pytree.tree_unflatten, (nodes, spec_node))


def _get_submodule(mod: torch.nn.Module, target: str):
    *prefix, field = target.split(".")

    for item in prefix:
        submod = getattr(mod, item, None)

        if submod is None:
            return None

        if not isinstance(submod, torch.nn.Module):
            return None

        mod = submod

    return getattr(mod, field, None)


def _add_submodule(mod: torch.nn.Module, target: str, module_to_add: torch.nn.Module):
    *prefix, field = target.split(".")

    for item in prefix:
        submod = getattr(mod, item, None)

        if submod is None:
            submod = torch.nn.Module()
            setattr(mod, item, submod)

        if not isinstance(submod, torch.nn.Module):
            return False

        mod = submod

    mod.add_module(field, module_to_add)
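
# Small illustration of the two helpers above (the module tree is hypothetical):
#
#     root = torch.nn.Module()
#     _add_submodule(root, "a.b.relu", torch.nn.ReLU())
#     # installs empty Modules at `root.a` and `root.a.b`, then registers
#     # `root.a.b.relu`
#     _get_submodule(root, "a.b.relu")   # -> the ReLU instance
#     _get_submodule(root, "a.x.relu")   # -> None (missing path component)
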
class _ModuleFrame:
    def __init__(
        self,
        flat_graph: torch.fx.Graph,
        nodes: Tuple[torch.fx.Node, ...],
        seen_nodes,
        seen_modules,
        parent,
        module_stack: List[str],
        module_id,
        module_call_graph: Dict[str, ModuleCallSignature],
        module: Optional[torch.nn.Module] = None,
    ):
        self.flat_graph = flat_graph
        self.nodes = nodes
        self.seen_nodes = seen_nodes
        self.seen_modules = seen_modules
        self.parent = parent
        self.module_stack = module_stack
        self.module_id = module_id

        self.module_call_graph = module_call_graph
        self.verbose = False

        self.fqn = self.module_stack[-1]
        if module is not None:
            self.module = module
        else:
            self.module = InterpreterModule(torch.fx.Graph())
        if self.module_id in self.seen_modules:
            self.cached_graph_module = self.seen_modules[self.module_id]
        else:
            self.cached_graph_module = None
            self.seen_modules[self.module_id] = self.module

        self.graph = self.module.graph

        # Mapping of nodes in the flat graph to nodes in this graph.
        self.node_map: Dict[torch.fx.Node, torch.fx.Node] = {}
        self.node_to_placeholder = {}

        self.parent_call_module: Optional[torch.fx.Node] = None
        if parent is not None:
            accessor = _compute_accessor(parent.fqn, self.fqn)
            _add_submodule(
                parent.module,
                accessor,
                (
                    self.module
                    if self.cached_graph_module is None
                    else self.cached_graph_module
                ),
            )
            self.parent_call_module = parent.graph.call_module(accessor)

        signature = module_call_graph.get(self.fqn)
        if signature is not None and self.parent is not None:
            assert signature.in_spec.num_children == 2
            args_spec = signature.in_spec.children_specs[0]
            kwargs_spec = signature.in_spec.children_specs[1]
            assert args_spec.context is None
            assert kwargs_spec.context is not None

            with self.graph.inserting_after(None):
                arg_nodes = []
                for idx in range(args_spec.num_children):
                    arg_nodes.append(self.graph.placeholder(f"_positional_arg_{idx}"))
                kwarg_nodes = {}
                for name in kwargs_spec.context:
                    kwarg_nodes[name] = self.graph.placeholder(name)
                flat_args = _generate_flatten(
                    self.module,
                    (tuple(arg_nodes), kwarg_nodes),
                    signature.in_spec,
                )
                for idx, arg in enumerate(signature.inputs):
                    flat_arg_node = self.graph.create_node(
                        op="call_function",
                        target=operator.getitem,
                        args=(flat_args, idx),
                        name=(
                            arg.name
                            if not isinstance(arg, ConstantArgument)
                            else f"_constant_{idx}"
                        ),
                    )
                    if isinstance(arg, ConstantArgument):
                        continue

                    if arg.name in self.seen_nodes:
                        flat_arg_node.meta = copy.copy(self.seen_nodes[arg.name].meta)
                        self.node_to_placeholder[
                            self.seen_nodes[arg.name]
                        ] = flat_arg_node

            with self.parent.graph.inserting_before(self.parent_call_module):
                input_nodes: List[Optional[torch.fx.Node]] = []
                for input in signature.inputs:
                    if isinstance(input, ConstantArgument) and input.value is None:
                        input_nodes.append(None)
                    elif input.name not in self.seen_nodes:
                        input_nodes.append(None)
                    else:
                        assert isinstance(input, (TensorArgument, SymIntArgument))
                        input_nodes.append(
                            self.parent.remap_input(self.seen_nodes[input.name])
                        )

                inputs_node = _generate_unflatten(
                    self.parent.module,
                    input_nodes,
                    signature.in_spec,
                )

                args_node = self.parent.graph.call_function(
                    operator.getitem, (inputs_node, 0)
                )
                kwargs_node = self.parent.graph.call_function(
                    operator.getitem, (inputs_node, 1)
                )
                arg_nodes = [
                    self.parent.graph.call_function(operator.getitem, (args_node, i))
                    for i in range(args_spec.num_children)
                ]
                kwarg_nodes = {
                    k: self.parent.graph.call_function(
                        operator.getitem, (kwargs_node, k)
                    )
                    for k in kwargs_spec.context
                }
            assert self.parent_call_module is not None
            self.parent_call_module.args = tuple(arg_nodes)
            self.parent_call_module.kwargs = kwarg_nodes

    def add_placeholder(self, x):
        assert self.fqn != "", f"Cannot add placeholder {x} to root module"
        assert x.graph is self.flat_graph
        # x is not in subgraph, create a new placeholder for subgraph
        with self.graph.inserting_before(None):
            placeholder_node = self.graph.placeholder(x.name, type_expr=x.type)
        # copy all meta fields, even if some fields might be irrelevant for
        # the placeholder node
        placeholder_node.meta = copy.copy(x.meta)
        self.node_to_placeholder[x] = placeholder_node

    def copy_sym_call_function(self, x):
        # This only exists because we deduplicate sym_size nodes in the flat export graph,
        # and if preserve_module_call_signature is set, we may not be able to pass sym_size
        # nodes, or their downstream users, as inputs to submodule calls.
        # To avoid this we copy these call_function nodes with sym_type results.
        # This should however only be done for sym_type nodes - call_function nodes on tensors
        # should not be deduplicated in the first place.
        args = pytree.tree_map_only(torch.fx.Node, self.remap_input, x.args)
        kwargs = pytree.tree_map_only(torch.fx.Node, self.remap_input, x.kwargs)
        node = self.graph.call_function(x.target, args, kwargs)
        node.meta = copy.copy(x.meta)
        self.node_map[x] = node
        return node

    def remap_input(self, x):
        assert x.graph is self.flat_graph
        if x in self.node_map:
            return self.node_map[x]
        self.print(f"remap_input({x})")
        if x in self.node_to_placeholder:
            return self.node_to_placeholder[x]
        elif (
            x.op == "placeholder"
            or self.module_call_graph.get(self.fqn) is None
            # allow placeholder creation if we are not preserving module call signature
        ):
            self.add_placeholder(x)
            if self.parent_call_module is not None:
                # Important to *prepend* the output to match how we are
                # inserting placeholder nodes.
                with self.parent.graph.inserting_before(self.parent_call_module):
                    self.parent_call_module.insert_arg(0, self.parent.remap_input(x))
            return self.node_to_placeholder[x]
        elif x.op == "call_function" and (
            x.target
            in (
                torch.ops.aten.sym_size.int,
                torch.ops.aten.item.default,
                torch.ops.aten.unbind.int,
                torch.ops.aten.sum.dim_IntList,
                torch.ops.aten.view.default,
                torch.ops.aten.diff.default,
            )
            or (hasattr(x.target, "__module__") and x.target.__module__ == "_operator")
        ):
            # export deduplicates sym_size nodes, and may need to re-copy them
            # if module call signature needs to be preserved
            self.copy_sym_call_function(x)
            return self.node_map[x]
        else:
            raise RuntimeError(
                f"Could not run remap_input() on op type: {x.op} for node {x}"
            )

    def finalize_outputs(self):
        orig_outputs = []

        signature = self.module_call_graph.get(self.fqn)
        if signature is not None and self.parent is not None:
            for output in signature.outputs:
                if isinstance(output, (TensorArgument, SymIntArgument)):
                    if output.name in self.seen_nodes:
                        orig_outputs.append(self.seen_nodes[output.name])
                    else:
                        orig_outputs.append(None)
                else:
                    raise RuntimeError(
                        f"Unsupported data type for output node: {output}"
                    )

            def get_actual_output_node(output):
                if output is None:
                    return None

                seen_node = self.seen_nodes[output.name]
                if seen_node in self.node_map:
                    return self.node_map[seen_node]
                elif seen_node in self.node_to_placeholder:
                    return self.node_to_placeholder[seen_node]
                else:
                    raise RuntimeError(
                        f"Could not find output node {output}. Graph: {self.graph}"
                    )

            tree_out_node = _generate_unflatten(
                self.module,
                tuple(get_actual_output_node(output) for output in orig_outputs),
                signature.out_spec,
            )
            parent_out: Optional[torch.fx.Node] = _generate_flatten(
                self.parent.module, self.parent_call_module, signature.out_spec
            )
            graph_outputs: Union[torch.fx.Node, List[torch.fx.Node]] = tree_out_node
        else:
            graph_outputs = []
            # Iterate through nodes we have copied into self.graph.
            for orig_node in self.node_map.keys():
                for user_node in orig_node.users:
                    if user_node.name not in self.seen_nodes:
                        # external user node, need to expose as an output
                        orig_outputs.append(orig_node)
                        graph_outputs.append(self.node_map[orig_node])
                        break

            parent_out = self.parent_call_module
            if len(graph_outputs) == 1:
                graph_outputs = graph_outputs[0]

        assert isinstance(graph_outputs, (list, torch.fx.Node))

        self.graph.output(graph_outputs)

        # Rewrite outputs in parent module
        if parent_out is None:
            return

        parent_out.meta["val"] = (
            graph_outputs.meta.get("val")
            if isinstance(graph_outputs, torch.fx.Node)
            else [o.meta.get("val") for o in graph_outputs]
        )

        if len(orig_outputs) == 1 and signature is None:
            self.parent.node_map[orig_outputs[0]] = parent_out
        else:
            for i, orig_output in enumerate(orig_outputs):
                if orig_output is None:
                    continue
                # Use Proxy to record getitem access.
                proxy_out = torch.fx.Proxy(parent_out)[i].node  # type: ignore[index]
                proxy_out.meta["val"] = orig_output.meta.get("val")
                self.parent.node_map[orig_output] = proxy_out

        if self.cached_graph_module is not None:
            _verify_graph_equivalence(self.cached_graph_module, self.module)

    def copy_node(self, node):
        self.print("copying", node.format_node())
        self.node_map[node] = self.graph.node_copy(node, self.remap_input)
        self.seen_nodes[node.name] = node

    def run_outer(self):
        i = 0
        for node in self.flat_graph.nodes:
            self.print(i, node.meta.get("nn_module_stack"), node.format_node())
            i += 1

        # Copy all graph inputs
        node_idx: int = 0
        node = self.nodes[node_idx]
        while node.op == "placeholder":
            self.copy_node(node)
            node_idx += 1
            node = self.nodes[node_idx]

        self.run_from(node_idx)

        # Copy graph outputs
        for node in self.flat_graph.nodes:
            if node.op == "output":
                self.copy_node(node)

    def print(self, *args, **kwargs):
        if self.verbose:
            print(*args, **kwargs)

    def run_from(self, node_idx):
        module_idx = 0
        # Walk through the graph, building up a new graph with the right submodules
        while node_idx < len(self.nodes):
            node = self.nodes[node_idx]
            assert node.op != "placeholder"

            self.print()
            self.print("STEP", node_idx, node.format_node())
            self.print(self.module_stack)
            if node.op == "output":
                if len(self.module_stack) == 1:
                    # We want the output node of the original graph to be handled
                    # specially by the outermost stack frame (in run_outer). So
                    # skip finalization here.
                    return node_idx

                # We've reached the end of the graph. Wrap up all the existing stack frames.
                self.finalize_outputs()
                return node_idx

            if len(node.meta.get("nn_module_stack", {})) == 0:
                raise RuntimeError(f"Unable to find nn_module_stack for node {node}")

            nn_module_stack = node.meta["nn_module_stack"]
            from torch._export.passes._node_metadata_hook import (
                _EMPTY_NN_MODULE_STACK_KEY,
            )

            if (
                len(nn_module_stack) == 1
                and _EMPTY_NN_MODULE_STACK_KEY in nn_module_stack
            ):
                # Empty case from the node_metadata_hook
                node_module_stack = self.module_stack
            else:
                node_module_stack = [
                    path for path, ty in node.meta["nn_module_stack"].values()
                ]

            if node_module_stack[: len(self.module_stack)] != self.module_stack:
                # This means that the current module is done executing and the
                # current node is the beginning of a new module.
                #
                # In this case, we should finalize this module and return without
                # incrementing the node counter.
                self.finalize_outputs()
                self.print("outlining", self.fqn)
                self.print(self.graph)
                return node_idx

            assert node_module_stack is not None

            if _is_prefix(self.module_stack, node_module_stack):
                # This means that the current node represents the execution of a new
                # module.
                next_module = node_module_stack[len(self.module_stack)]
                self.print("Creating new stack frame for", next_module)
                # Run a nested version of module outliner from the current node
                # counter. Once it is complete, continue from that point.
                node_idx = _ModuleFrame(
                    self.flat_graph,
                    self.nodes,
                    self.seen_nodes,
                    self.seen_modules,
                    self,
                    self.module_stack + [next_module],
                    list(node.meta["nn_module_stack"].keys())[len(self.module_stack)],
                    self.module_call_graph,
                ).run_from(node_idx)
                module_idx += 1
                continue

            # The only remaining possibility is that we are in the right stack
            # frame. Copy the node into this frame's graph and increment the node counter.
            assert node_module_stack == self.module_stack
            self.copy_node(node)
            node_idx += 1


def _outline_submodules(orig_graph: torch.fx.Graph, root_module: UnflattenedModule):
    seen_nodes: Dict[str, torch.fx.Node] = {}
    seen_modules: Dict[int, torch.nn.Module] = {}
    _ModuleFrame(
        orig_graph,
        tuple(orig_graph.nodes),
        seen_nodes,
        seen_modules,
        None,
        [""],
        "",
        {
            entry.fqn: entry.signature
            for entry in root_module.module_call_graph
            if entry.signature
        },
        module=root_module,
    ).run_outer()
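
# Rough illustration of how _ModuleFrame consumes nn_module_stack metadata
# (the names are hypothetical): a flat-graph node whose meta contains
#
#     nn_module_stack = {key0: ("", RootType), key1: ("sub", SubType)}
#
# produces node_module_stack == ["", "sub"], so run_from() opens a nested stack
# frame for "sub" and copies the node into that submodule's graph; its inputs
# and outputs are threaded through the parent's call_module node via
# remap_input() and finalize_outputs().
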
def _reorder_submodules(
    parent: torch.nn.Module, fqn_order: Dict[str, int], prefix: str = ""
):
    # TODO Can be optimized by adding submodules ahead of time.
    if prefix == "":
        for fqn in list(fqn_order.keys())[1:]:
            if _get_submodule(parent, fqn) is None:
                _add_submodule(parent, fqn, torch.nn.Module())

    children = []
    for name, child in list(parent._modules.items()):
        if child is None:
            continue
        fqn = prefix + name
        _reorder_submodules(child, fqn_order, prefix=fqn + ".")
        delattr(parent, name)
        children.append((fqn_order[fqn], name, child))
    children.sort(key=operator.itemgetter(0))
    for _, name, child in children:
        parent.register_module(name, child)
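
# Sketch of the rewrite _sink_params (below) applies to each submodule graph,
# for a hypothetical `linear` submodule:
#
#   before (functional): def forward(self, p_linear_weight, x):
#                            return torch.ops.aten.linear.default(x, p_linear_weight)
#   after  (sunk):       def forward(self, x):
#                            weight = self.weight  # get_attr node inside `linear`
#                            return torch.ops.aten.linear.default(x, weight)
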
def _sink_params(
    module: torch.nn.Module,
    inputs_to_state: Dict[str, List[str]],
    scope: List[str],
):
    """Sink params, buffers, and constants from graph inputs into get_attr nodes.

    Exported modules are purely functional, so they pass their parameters and
    buffers in as inputs to the graph.

    To replicate eager's semantics, we need to get them from the module state
    via get_attr instead.

    module: GraphModule, potentially containing nested submodules.
    inputs_to_state: mapping graph input names to the corresponding key in the state_dict.
    scope: tracks where we are in the module hierarchy, so that we can emit the
        right `getattr(self, "foo.bar")` calls, etc.
    """
    # This dict records inputs removed by child modules.
    # Maps the module object id to the list of placeholder node names
    # in the child module that were removed.
    module_id_to_inputs_removed: Dict[int, List[str]] = defaultdict(list)

    # We need to use _modules here instead of named_children(), because we
    # explicitly want duplicate modules to show up in the traversal.
    for name, submodule in module._modules.items():
        submod_id_to_inputs_removed = _sink_params(
            cast(torch.nn.Module, submodule), inputs_to_state, scope + [name]
        )
        for k, v in submod_id_to_inputs_removed.items():
            module_id_to_inputs_removed[k].extend(v)

    if not hasattr(module, "graph"):
        # Not all modules have graphs defined; empty modules with no
        # operations (like ParameterList) do not.
        return module_id_to_inputs_removed

    graph = module.graph
    inputs = list(filter(lambda n: n.op == "placeholder", graph.nodes))
    the_last_input = inputs[-1]

    # Also remove from call_module nodes
    call_module_nodes = filter(lambda n: n.op == "call_module", graph.nodes)
    for node in call_module_nodes:
        submodule = _recursive_getattr(module, node.target.split("."))
        # remove placeholder from call_module node arguments, only if we've
        # erased the placeholder node in the corresponding _sink_params() call
        if submodule is not None and id(submodule) in module_id_to_inputs_removed:
            node.args = tuple(
                filter(
                    lambda n: n.name not in module_id_to_inputs_removed[id(submodule)],
                    node.args,
                )
            )

    # Filter out inputs_to_state corresponding to current scope.
    inputs_to_state_of_scope: Dict[torch.fx.Node, List[str]] = {}
    for node in inputs:
        if node.name not in inputs_to_state:
            continue

        state_name = None
        for sn in inputs_to_state[node.name]:
            sn_split = sn.split(".")
            if sn_split[: len(scope)] == scope:
                state_name = sn_split
                break

        # If there's a mismatch between the scope name and the state name, then
        # there must be multiple scopes pointing to the same state name, meaning
        # some modules are shared. In that case, we can simply skip updating the
        # current node, because a later iteration will take care of this input
        # node when the unique match between scope and state name occurs. To
        # make sure this always happens, we should enforce the invariant that no
        # placeholder node in the unflattened graph appears in the
        # inputs_to_state dict, which means all the extra input nodes have been
        # handled.
        if state_name is None:
            continue

        inputs_to_state_of_scope[node] = state_name

    # Record the names of removed inputs for the return value.
    inputs_removed: List[str] = []

    for node, state_name in inputs_to_state_of_scope.items():
        if len(node.users) > 0:
            attr_path = state_name[len(scope) :]
            state_attr = _recursive_getattr(module, attr_path)
            assert isinstance(state_attr, (torch.Tensor, torch.ScriptObject))

            # Make sure the newly created get_attr node is placed after the last placeholder node
            with graph.inserting_after(the_last_input):
                new_node = graph.create_node("get_attr", ".".join(attr_path))

            node.replace_all_uses_with(new_node, propagate_meta=True)

        graph.erase_node(node)
        inputs_removed.append(node.name)

    if isinstance(module, InterpreterModule):
        module.finalize()

    return {id(module): inputs_removed}


def _recursive_getattr(obj, attr_path):
    for attr in attr_path:
        if not hasattr(obj, attr):
            return None
        obj = getattr(obj, attr)

    return obj