# mypy: allow-untyped-defs
"""
This module dispatches the graphs to either the forward-only or joint compilation
pathways, taking into account the AOTConfig and the collected ViewAndMutationMetadata.
"""

import dataclasses
from typing import Any, List, Optional, Tuple

import torch
import torch.utils._pytree as pytree
import torch.utils.dlpack
from torch import Tensor
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.utils import lazy_format_graph_code
from torch._logging import getArtifactLogger, trace_structured
from torch._subclasses.functional_tensor import FunctionalTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.utils._python_dispatch import _detect_infra_mode

from .. import config
from .functional_utils import (
    assert_functional_graph,
    propagate_input_mutation_stacktraces,
)
from .schemas import AOTConfig, SubclassMeta, ViewAndMutationMeta
from .traced_function_transforms import (
    aot_dispatch_subclass,
    create_functionalized_fn,
    create_joint,
    fn_input_mutations_to_outputs,
    fn_prepped_for_autograd,
    handle_effect_tokens_fn,
)
from .utils import (
    copy_fwd_metadata_to_bw_nodes,
    root_module_when_exporting_non_strict,
    unlift_tokens,
)


aot_graphs_log = getArtifactLogger(__name__, "aot_graphs")


def _create_graph(f, args, *, aot_config: AOTConfig) -> torch.fx.GraphModule:
    # FunctionalTensorMode must be enabled here.
    # See Note [Accessing .grad_fn on FunctionalTensor]
    with enable_python_dispatcher(), FunctionalTensorMode(
        pre_dispatch=aot_config.pre_dispatch,
        export=aot_config.is_export,
        # Allow token discovery for joint fn tracing as tokens can be used in backward.
        _allow_token_discovery=True,
    ):
        fx_g = make_fx(
            f,
            decomposition_table=aot_config.decompositions,
            record_module_stack=True,
            pre_dispatch=aot_config.pre_dispatch,
        )(*args)

    return fx_g


def aot_dispatch_base_graph(
    flat_fn,
    flat_args: List[Tensor],
    aot_config: AOTConfig,
    *,
    fw_metadata: ViewAndMutationMeta,
) -> Tuple[torch.fx.GraphModule, List[Any], Optional[SubclassMeta]]:
    # aot_dispatch_base requires functionalization, but doesn't need to handle
    # as many cases as the autograd case. Cases it does NOT need to handle:
    # - outputs that are aliases of graph intermediates
    # - outputs that are aliases of graph inputs
    # Cases it DOES need to handle:
    # - input mutations (including when inputs are aliases of each other)
    # - input metadata mutations
    fn_to_trace = fn_input_mutations_to_outputs(
        flat_fn,
        fw_metadata,
        keep_data_input_mutations=aot_config.keep_inference_input_mutations,
    )

    fn_to_trace, updated_flat_args = create_functionalized_fn(
        fn_to_trace,
        flat_args,
        meta=fw_metadata,
        aot_config=aot_config,
        trace_joint=False,
    )

    # TODO: replace with AOTDispatchSubclassWrapper once we refactor
    # fn_input_mutations_to_outputs and create_functionalized_fn
    # into CompilerWrappers.
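    # aot_dispatch_subclass desugars any traceable tensor subclass inputs into
    # their underlying plain tensors, so the graph traced below only ever sees
    # plain tensors; maybe_subclass_meta records how to rewrap outputs at
    # runtime (it is None when no subclass inputs are present).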
    (
        fn_to_trace,
        updated_flat_args_subclasses_desugared,
        maybe_subclass_meta,
    ) = aot_dispatch_subclass(
        fn_to_trace,
        updated_flat_args,
        is_joint_structure=False,
        meta=fw_metadata,
        fw_only=flat_fn,
    )

    (fn_to_trace, updated_flat_args_subclasses_desugared) = handle_effect_tokens_fn(
        fn_to_trace,
        updated_flat_args_subclasses_desugared,
        meta=fw_metadata,
        trace_joint=False,
    )

    aot_graphs_log.debug(
        "aot_config id: %s, fw_metadata=%s, subclass_metadata=%s",
        str(aot_config.aot_id),
        str(fw_metadata),
        str(maybe_subclass_meta),
    )

    # We track buffer assignments when exporting in non-strict mode.
    # (In contrast, strict mode errors on any attribute assignment.)
    mod_when_exporting_non_strict = root_module_when_exporting_non_strict(flat_fn)
    if aot_config.is_export and mod_when_exporting_non_strict is not None:
        # For any buffer that is assigned, we want to associate it to the final proxy node
        # that it is assigned to. This node can then be added as a buffer mutation output.
        assigned_buffers = {}

        def _map_assigned_buffer_to_proxy(_mod, name, buffer):
            # We intercept buffer assignments on the root module through this hook.
            if _mod._buffers is mod_when_exporting_non_strict._buffers:
                # The value assigned to a buffer is a functional tensor, which wraps a fake tensor.
                assert isinstance(
                    buffer, torch._subclasses.functional_tensor.FunctionalTensor
                )
                fake = buffer.from_functional()
                # The fake tensor in turn is associated with a proxy node.
                proxy_mode = _detect_infra_mode(torch._C._TorchDispatchModeKey.PROXY)
                assert proxy_mode is not None
                proxy = torch.fx.experimental.proxy_tensor.get_proxy_slot(
                    fake, proxy_mode.tracer
                ).proxy.node
                # We map the assigned buffer to this proxy node.
                assigned_buffers[name] = proxy.name
            return buffer

        handle = torch.nn.modules.module.register_module_buffer_registration_hook(
            _map_assigned_buffer_to_proxy
        )

    saved_updated_flat_args_subclasses_desugared = pytree.tree_map_only(
        torch.Tensor, lambda t: t.detach(), updated_flat_args_subclasses_desugared
    )
    fw_module = _create_graph(
        fn_to_trace,
        updated_flat_args_subclasses_desugared,
        aot_config=aot_config,
    )

    if aot_config.is_export and mod_when_exporting_non_strict is not None:
        # We update metadata to consider any assigned buffers as buffer mutations.
        i = len(dict(mod_when_exporting_non_strict.named_parameters()))
        for name, _ in mod_when_exporting_non_strict.named_buffers():
            if name in assigned_buffers and not fw_metadata.input_info[i].mutates_data:  # type: ignore[possibly-undefined]
                fw_metadata.input_info[i] = dataclasses.replace(
                    fw_metadata.input_info[i], mutates_data=True
                )
                fw_metadata.num_mutated_inp_runtime_indices += 1
            i += 1

        # We add nodes corresponding to buffer assignments as output nodes in the graph.
        add_nodes = []
        output_node = list(fw_module.graph.nodes)[-1]
        for name in assigned_buffers.values():  # type: ignore[possibly-undefined]
            for node in fw_module.graph.nodes:
                if node.name == name:
                    add_nodes.append(node)
                    node.users[output_node] = None
        output_node.args = ((*add_nodes, *output_node.args[0]),)

        handle.remove()  # type: ignore[possibly-undefined]

    # As long as we opted to remove input mutations,
    # there should be *NO* mutating ops in the graph at this point.
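    # assert_functional_graph also returns the number of copy_ ops in the graph
    # (the input mutations we chose to keep); we count again after dead code
    # elimination and assert below that DCE did not remove any of them.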
    copy_count = assert_functional_graph(fw_module.graph)
    fw_module.graph.eliminate_dead_code()
    fw_module.recompile()

    copy_count2 = assert_functional_graph(fw_module.graph)
    propagate_input_mutation_stacktraces(fw_module.graph)

    # See Note [Side-Effectful Tokens in AOTAutograd]
    num_tokens = len(fw_metadata.tokens)
    if num_tokens != 0 and config.unlift_effect_tokens:
        unlift_tokens(fw_module, fw_metadata, aot_config)
        saved_updated_flat_args_subclasses_desugared = (
            saved_updated_flat_args_subclasses_desugared[num_tokens:]
        )

    assert copy_count == copy_count2

    if aot_config.enable_log:
        aot_graphs_log.info(
            "%s",
            lazy_format_graph_code(
                "Forward graph",
                fw_module,
                aot_config.aot_id,
                include_stride=True,
                include_device=True,
                colored=True,
            ),
        )
        trace_structured(
            "aot_forward_graph",
            payload_fn=lambda: fw_module.print_readable(
                print_output=False, include_stride=True, include_device=True
            ),
        )

    # TODO: should factor this into a separate function for export that always returns just the graph.
    if aot_config.is_export:
        assert (
            maybe_subclass_meta is None
        ), "aot_export_module does not support tensor subclass inputs for now."
    return fw_module, saved_updated_flat_args_subclasses_desugared, maybe_subclass_meta


# Has the precondition that there are no duplicate arguments in flat_args
# (i.e., the same Tensor object never shows up twice). However, two tensor
# inputs MAY alias the same storage, as long as they have separate TensorImpls.
def aot_dispatch_autograd_graph(
    flat_fn,
    flat_args: List[Any],
    aot_config: AOTConfig,
    *,
    fw_metadata: ViewAndMutationMeta,
) -> Tuple[torch.fx.GraphModule, Tuple[List[Any], List[Any]], Optional[SubclassMeta]]:
    # traced_tangents corresponds to the set of outputs in the traced forward
    # that should get grad_outputs in the traced backward. It includes outputs
    # of the original forward, *and* any updated inputs due to input mutations.
    # However, it does *not* include any outputs that are aliases of inputs or
    # intermediates, or any metadata-only input mutations.
    joint_inputs = (flat_args, fw_metadata.traced_tangents)

    fn_prepared_for_autograd = fn_prepped_for_autograd(
        flat_fn,
        fw_metadata,
    )
    joint_fn_to_trace = create_joint(fn_prepared_for_autograd, aot_config=aot_config)

    joint_fn_to_trace, updated_joint_inputs = create_functionalized_fn(
        joint_fn_to_trace,
        joint_inputs,
        meta=fw_metadata,
        aot_config=aot_config,
        trace_joint=True,
    )

    # TODO: replace with AOTDispatchSubclassWrapper once we refactor
    # fn_input_mutations_to_outputs and create_functionalized_fn
    # into CompilerWrappers.
    subclass_tracing_info = aot_dispatch_subclass(
        joint_fn_to_trace,
        updated_joint_inputs,
        is_joint_structure=True,
        meta=fw_metadata,
        fw_only=flat_fn,
    )

    joint_fn_to_trace = subclass_tracing_info.plain_tensor_trace_fn
    updated_joint_inputs = subclass_tracing_info.plain_tensor_args

    (joint_fn_to_trace, updated_joint_inputs) = handle_effect_tokens_fn(
        joint_fn_to_trace,
        updated_joint_inputs,
        meta=fw_metadata,
        trace_joint=True,
    )
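    # By now the joint fn takes (primals, tangents) as inputs, with any effect
    # tokens threaded through as additional inputs/outputs.
    # See Note [Side-Effectful Tokens in AOTAutograd].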
    # When we call _create_graph, this may mutate the metadata of joint
    # inputs. But callers are expecting to get the original joint inputs. So
    # we make aliases of all the inputs to make sure we have a copy that
    # doesn't get modified.
    #
    # This destroys requires_grad/grad_fn information. However, backends
    # beneath AOTAutograd are indifferent to this information, so it doesn't
    # matter.
    saved_updated_joint_inputs = pytree.tree_map_only(
        torch.Tensor, lambda t: t.detach(), updated_joint_inputs
    )
    maybe_subclass_meta = subclass_tracing_info.maybe_subclass_meta

    fx_g = _create_graph(joint_fn_to_trace, updated_joint_inputs, aot_config=aot_config)

    # There should be *NO* mutating ops in the graph at this point.
    assert_functional_graph(fx_g.graph)

    # Redundant with the check above, but worth having in case tracing introduced
    # a fake tensor. Unlikely.
    # See Note: [Fake Modules and AOTAutograd]
    torch._dynamo.utils.assert_no_fake_params_or_buffers(fx_g)
    fx_g.graph.eliminate_dead_code()
    copy_fwd_metadata_to_bw_nodes(fx_g)
    fx_g.recompile()

    # TODO: in AOTAutograd, we create metadata like _indices_of_inps_to_detach to detect
    # when we need to manually detach() some inputs in the forward.
    # Higher order ops might eventually need to do the same.
    if aot_config.is_export:
        assert (
            maybe_subclass_meta is None
        ), "aot_export_module does not support tensor subclass inputs for now."
    return fx_g, saved_updated_joint_inputs, maybe_subclass_meta