# mypy: allow-untyped-defs
import copy
import warnings
from itertools import chain
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.utils._pytree as pytree
from torch._export.utils import _check_input_constraints_for_graph
from torch.export.unflatten import _assign_attr, _AttrKind, _recursive_getattr
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo

from ._remove_effect_tokens_pass import _remove_effect_tokens
from .exported_program import (
    ExportedProgram,
    ExportGraphSignature,
    InputKind,
    OutputKind,
)


@torch._dynamo.disable
def _check_input_constraints_pre_hook(self, *args, **kwargs):
    """
    Forward pre-hook that validates user inputs against the exported program.

    The hook is registered with ``with_kwargs=True`` (see
    ``_create_stateful_graph_module``), so the positional payload collected in
    ``args`` is the ``(args, kwargs)`` pair handed over by ``nn.Module``. That
    pair is flattened and its tree structure compared against the exported
    input spec; the flattened leaves are then checked against the graph's
    placeholders and ``self.range_constraints``.

    Raises:
        ValueError: if the received inputs do not have the tree structure
            recorded in ``self._in_spec``.
    """
    flat_args_with_path, received_spec = pytree.tree_flatten_with_path(args)

    if received_spec != self._in_spec:
        raise ValueError(  # noqa: B904
            "Trying to flatten user inputs with exported input tree spec: \n"
            f"{self._in_spec}\n"
            "but actually got inputs with tree spec of: \n"
            f"{received_spec}"
        )

    return _check_input_constraints_for_graph(
        [node for node in self.graph.nodes if node.op == "placeholder"],
        flat_args_with_path,
        self.range_constraints,
    )


def _unlift_inputs_as_getattr(
    gm: torch.fx.GraphModule,
    lifted_inputs: List[Optional[str]],
) -> Tuple[Dict[str, torch.fx.Node], Dict[str, torch.fx.Node]]:
    """
    Unlift inputs referring to params/buffers/constants as getattr nodes in the
    graph.

    Args:
        gm: graph module whose placeholders are rewritten in place.
        lifted_inputs: one entry per placeholder, in placeholder order. A
            string is the attribute fqn the placeholder should be replaced
            with (via a ``get_attr`` node); ``None`` keeps it as a user input.

    Returns:
        A pair ``(unlifted_name_to_node, input_name_to_node)``: the new
        ``get_attr`` nodes keyed by attribute fqn, and the remaining user
        placeholder nodes keyed by node name.
    """
    unlifted_name_to_node = {}
    input_name_to_node = {}

    placeholder_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
    assert len(lifted_inputs) == len(placeholder_nodes), (
        f"Expected {len(placeholder_nodes)} lifted-input entries (one per "
        f"placeholder), got {len(lifted_inputs)}"
    )
    for input_node, lifted_node in zip(placeholder_nodes, lifted_inputs):
        if lifted_node is None:
            input_name_to_node[input_node.name] = input_node

        else:
            with gm.graph.inserting_after(input_node):
                getattr_node = gm.graph.get_attr(lifted_node)
                input_node.replace_all_uses_with(getattr_node)
                # Preserve the placeholder's metadata on its replacement.
                metadata = input_node.meta
                gm.graph.erase_node(input_node)
                getattr_node.meta = metadata
                unlifted_name_to_node[lifted_node] = getattr_node

    return unlifted_name_to_node, input_name_to_node


def _insert_copy_for_mutations(
    gm: torch.fx.GraphModule,
    mutated_outputs: List[Optional[str]],
    unlifted_name_to_node: Dict[str, torch.fx.Node],
    input_name_to_node: Dict[str, torch.fx.Node],
) -> None:
    """
    Find all the buffers and inputs that were mutated and insert copy_
    operators to reflect mutations.

    Each graph output whose ``mutated_outputs`` entry names a buffer or user
    input gets an ``aten.copy_`` into that target, and is dropped from the
    graph's outputs; the output node is rebuilt to return only the user
    outputs (entries that are ``None``).

    Raises:
        RuntimeError: if a mutated name is found neither among the unlifted
            attribute nodes nor among the user input nodes.
    """
    output_node = next(
        (node for node in gm.graph.nodes if node.op == "output"), None
    )
    assert output_node is not None, "Graph has no output node"
    outputs = pytree.tree_flatten(output_node.args)[0]
    assert len(outputs) == len(mutated_outputs), (
        f"Expected {len(outputs)} mutated-output entries (one per graph "
        f"output), got {len(mutated_outputs)}"
    )

    user_output_nodes = []
    return_nodes_to_copy = {}
    for return_node, mutated_node_name in zip(outputs, mutated_outputs):
        if mutated_node_name is None:
            user_output_nodes.append(return_node)
            continue

        if mutated_node_name in unlifted_name_to_node:
            mutated_node = unlifted_name_to_node[mutated_node_name]
        elif mutated_node_name in input_name_to_node:
            mutated_node = input_name_to_node[mutated_node_name]
        else:
            raise RuntimeError(
                f"Could not find {mutated_node_name} in either buffer or input nodes"
            )

        with gm.graph.inserting_before(output_node):
            copy_node = gm.graph.call_function(
                torch.ops.aten.copy_.default, (mutated_node, return_node)
            )
            return_nodes_to_copy[return_node] = copy_node

    # If a user output is the same node that was copied into a mutated target,
    # return the copy_ node instead so the copy is ordered before the output.
    output_args = [
        return_nodes_to_copy[node] if node in return_nodes_to_copy else node
        for node in user_output_nodes
    ]
    with gm.graph.inserting_before(output_node):
        # Only return user outputs
        new_output = gm.graph.output(tuple(output_args))
        new_output.meta.update(output_node.meta)
        output_node.replace_all_uses_with(new_output)
        gm.graph.erase_node(output_node)


def _get_codegen(
    in_spec: pytree.TreeSpec,
    out_spec: Optional[pytree.TreeSpec],
    forward_arg_names: Optional[List[str]] = None,
) -> _PyTreeCodeGen:
    """
    Create the codegen for the graph module based on the in/out specs.

    If ``forward_arg_names`` is non-empty it is used verbatim. Otherwise names
    are synthesized: when ``in_spec`` is a ``(tuple, dict)`` pair (positional
    args and kwargs), positional slots become ``arg_<i>`` followed by the
    kwarg keys; for any other spec, one ``arg_<i>`` per top-level child.
    """
    if forward_arg_names:
        names = forward_arg_names
    else:
        if (
            in_spec.type is tuple
            and in_spec.num_children == 2
            and in_spec.children_specs[0].type is tuple
            and in_spec.children_specs[1].type is dict
        ):
            # if in_spec contains the args (tuple) and kwargs (dict)
            names = [f"arg_{i}" for i in range(in_spec.children_specs[0].num_children)]
            # add kwarg names
            names.extend(in_spec.children_specs[1].context)
        else:
            names = [f"arg_{i}" for i in range(in_spec.num_children)]

    return _PyTreeCodeGen(
        _PyTreeInfo(
            names,
            in_spec,
            out_spec,
        )
    )


def _unlift(
    gm: torch.fx.GraphModule,
    lifted_inputs: List[Optional[str]],
    mutated_outputs: List[Optional[str]],
    in_spec: pytree.TreeSpec,
    out_spec: Optional[pytree.TreeSpec],
    state_dict: Dict[str, Any],
    constants: Dict[str, Any],
    forward_arg_names: Optional[List[str]] = None,
):
    """
    Rewrite ``gm`` in place so lifted state is accessed via attributes and
    mutations are applied with ``copy_``, then recompile it.

    Args:
        lifted_inputs: A list matching the graph module's input nodes. For
        an input node that is referring to a lifted parameter/buffer, this
        list will contain the fqn the corresponding attribute. Otherwise, this
        list will contain None. This is used to unlift the lifted parameters as
        get_attr nodes.

        mutated_outputs: A list matching the graph module's output nodes. For
        an output node that is referring to a mutated buffer or user input, this
        list will contain the name of the corresponding buffer or user input
        that needs to be mutated. Otherwise, this list will contain None. This
        is used to re-insert an inplace copy_ operator to copy the mutated
        values back to the original node.

        state_dict, constants: not referenced by this function's body.

    Returns:
        The same ``gm`` object, after ``graph.lint()`` and ``recompile()``.
    """
    unlifted_name_to_node, input_name_to_node = _unlift_inputs_as_getattr(
        gm, lifted_inputs
    )
    _insert_copy_for_mutations(
        gm, mutated_outputs, unlifted_name_to_node, input_name_to_node
    )
    gm.graph._codegen = _get_codegen(in_spec, out_spec, forward_arg_names)
    gm.graph.lint()
    gm.recompile()
    return gm


def _register_attrs_to_new_gm(
    new_gm: torch.fx.GraphModule,
    graph_signature: ExportGraphSignature,
    state_dict: Dict[str, Any],
    constants: Dict[str, Any],
) -> None:
    """
    Attach buffers, parameters, custom objects and tensor constants named in
    ``graph_signature`` to ``new_gm``.

    Non-persistent buffers are looked up in ``constants`` and registered with
    ``persistent=False``; persistent buffers and parameters come from
    ``state_dict``; lifted custom objects and tensor constants come from
    ``constants``.
    """
    non_persistent_buffers = set(graph_signature.non_persistent_buffers)
    for name in graph_signature.buffers:
        if name in non_persistent_buffers:
            persistent = False
            value = constants[name]
        else:
            persistent = True
            value = state_dict[name]
        _assign_attr(
            value, new_gm, name, attr_kind=_AttrKind.BUFFER, persistent=persistent
        )
    for name in graph_signature.parameters:
        value = state_dict[name]
        _assign_attr(
            value,
            new_gm,
            name,
            attr_kind=_AttrKind.PARAMETER,
        )

    for name in chain(
        graph_signature.lifted_custom_objs, graph_signature.lifted_tensor_constants
    ):
        value = constants[name]
        _assign_attr(
            value,
            new_gm,
            name,
            attr_kind=_AttrKind.CONSTANT,
        )


class _StatefulGraphModuleFactory(type):
    """
    Metaclass that ensures a private constructor for _StatefulGraphModule
    """

    def __call__(cls, *args, **kwargs):
        # Direct ``_StatefulGraphModule(...)`` calls are rejected.
        raise TypeError(
            f"{cls.__module__}.{cls.__qualname__} has no public constructor. "
        )

    def _create(cls, root, graph, range_constraints=None):
        # The sanctioned construction path: bypasses the overridden __call__
        # above by delegating to ``type.__call__``.
        return super().__call__(
            root,
            graph,
            range_constraints=range_constraints,
        )


class _StatefulGraphModule(torch.fx.GraphModule, metaclass=_StatefulGraphModuleFactory):
    """
    GraphModule that carries ``range_constraints``; construct it through
    ``_StatefulGraphModule._create``.
    """

    def __init__(self, root, graph, range_constraints=None):
        super().__init__(root, graph)
        # Read by _check_input_constraints_pre_hook when validating inputs.
        self.range_constraints = range_constraints or []


def _create_stateful_graph_module(
    plain_graph_module: torch.fx.GraphModule,
    range_constraints,
    # TODO(suo) this should not be optional, but is since we still have
    # capture_pre_autograd_graph grr
    graph_signature: Optional[ExportGraphSignature] = None,
):
    """
    Wrap ``plain_graph_module`` in a ``_StatefulGraphModule`` with an
    input-validation forward pre-hook.

    When ``graph_signature`` is given, lifted tensor constants are
    re-registered as constants (detached) and non-persistent buffers are
    re-registered with ``persistent=False``.
    """
    stateful_gm = _StatefulGraphModule._create(
        plain_graph_module,
        plain_graph_module.graph,
        range_constraints=range_constraints,
    )

    stateful_gm.register_forward_pre_hook(
        _check_input_constraints_pre_hook, with_kwargs=True
    )

    if graph_signature is None:
        return stateful_gm

    # Fix up lifted tensor constants.
    # fx.GraphModule() constructor silently turns a constant attribute of plain_graph_module
    # into a buffer in stateful_gm and creates an inconsistency with graph_signature.
    # We fix this by de-registering these buffers in lifted_tensor_constants
    # and call _assign_attr(attr_kind=CONSTANT) to register them as constants.
    for constant_fqn in graph_signature.lifted_tensor_constants:
        # Sometimes, the constant can require gradient, this is probably a bug in user code,
        # e.g. `self.const = torch.randn(2, 2, requires_grad=True)`.
        # We call detach on the constant_val since they're tensor constants and we don't need to
        # compute their gradients anyway.
        # Users should properly register it as parameter if they want it to require gradient.
        buffer = stateful_gm.get_buffer(constant_fqn)
        if buffer.requires_grad:
            warnings.warn(
                f"A model attribute `{constant_fqn}` requires gradient, "
                "but it's not properly registered as a parameter. "
                "torch.export will detach it and treat it as a constant tensor "
                "but please register it as parameter instead."
            )
            buffer = buffer.detach()
        # Split on every dot: ``prefix`` is the list of path atoms leading to
        # the owning submodule (empty for a top-level attribute).
        *prefix, field = constant_fqn.rsplit(".")
        submod = _recursive_getattr(stateful_gm, prefix)
        delattr(submod, field)
        _assign_attr(buffer, stateful_gm, constant_fqn, attr_kind=_AttrKind.CONSTANT)

    # Fix up non-persistent buffers. torch.fx does not distinguish between
    # persistent and non-persistent buffers, so we must restore that distinction
    # here.
    for buffer in graph_signature.non_persistent_buffers:
        _assign_attr(
            plain_graph_module.get_buffer(buffer),
            stateful_gm,
            buffer,
            attr_kind=_AttrKind.BUFFER,
            persistent=False,
        )

    return stateful_gm


def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.nn.Module:
    """
    Build a stateful ``nn.Module`` from ``ep`` whose parameters, buffers and
    constants are module attributes rather than graph inputs.

    Effect tokens are removed first, the graph is deep-copied into a new
    GraphModule, lifted state is registered on it, the graph is unlifted
    (``get_attr`` for state, ``copy_`` for mutations), and the result is
    wrapped by ``_create_stateful_graph_module``. ``ep.graph_module.meta`` is
    copied onto the returned module.
    """
    ep = _remove_effect_tokens(ep)
    new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph))
    _register_attrs_to_new_gm(new_gm, ep.graph_signature, ep.state_dict, ep.constants)
    forward_arg_names = ep.graph_module.meta.get("forward_arg_names")

    lifted_inputs: List[Optional[str]] = [
        (
            input_spec.target
            if input_spec.kind
            in (
                InputKind.BUFFER,
                InputKind.CONSTANT_TENSOR,
                InputKind.PARAMETER,
                InputKind.CUSTOM_OBJ,
            )
            else None
        )
        for input_spec in ep.graph_signature.input_specs
    ]

    mutated_outputs: List[Optional[str]] = [
        (
            output_spec.target
            if output_spec.kind
            in (OutputKind.BUFFER_MUTATION, OutputKind.USER_INPUT_MUTATION)
            else None
        )
        for output_spec in ep.graph_signature.output_specs
    ]

    new_gm = _unlift(
        new_gm,
        lifted_inputs,
        mutated_outputs,
        ep.call_spec.in_spec,
        ep.call_spec.out_spec,
        ep.state_dict,
        ep.constants,
        forward_arg_names=forward_arg_names,
    )
    unlift_gm = _create_stateful_graph_module(
        new_gm, ep.range_constraints, ep.graph_signature
    )
    unlift_gm.meta.update(ep.graph_module.meta)
    return unlift_gm