# mypy: allow-untyped-defs
"""
This module dispatches graphs to either the forward-only or the joint (forward +
backward) compilation pathway, taking into account the AOTConfig and the
collected ViewAndMutationMeta.
"""

import dataclasses
from typing import Any, List, Optional, Tuple

import torch
import torch.utils._pytree as pytree
import torch.utils.dlpack
from torch import Tensor
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.utils import lazy_format_graph_code
from torch._logging import getArtifactLogger, trace_structured
from torch._subclasses.functional_tensor import FunctionalTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.utils._python_dispatch import _detect_infra_mode

from .. import config
from .functional_utils import (
    assert_functional_graph,
    propagate_input_mutation_stacktraces,
)
from .schemas import AOTConfig, SubclassMeta, ViewAndMutationMeta
from .traced_function_transforms import (
    aot_dispatch_subclass,
    create_functionalized_fn,
    create_joint,
    fn_input_mutations_to_outputs,
    fn_prepped_for_autograd,
    handle_effect_tokens_fn,
)
from .utils import (
    copy_fwd_metadata_to_bw_nodes,
    root_module_when_exporting_non_strict,
    unlift_tokens,
)


aot_graphs_log = getArtifactLogger(__name__, "aot_graphs")


def _create_graph(f, args, *, aot_config: AOTConfig) -> torch.fx.GraphModule:
    # FunctionalTensorMode must be enabled here.
    # See Note [Accessing .grad_fn on FunctionalTensor]
    with enable_python_dispatcher(), FunctionalTensorMode(
        pre_dispatch=aot_config.pre_dispatch,
        export=aot_config.is_export,
        # Allow token discovery for joint fn tracing as tokens can be used in backward.
        _allow_token_discovery=True,
    ):
        fx_g = make_fx(
            f,
            decomposition_table=aot_config.decompositions,
            record_module_stack=True,
            pre_dispatch=aot_config.pre_dispatch,
        )(*args)

    return fx_g
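
# For intuition, a hedged sketch of what tracing under functionalization does.
# This uses the public torch.func.functionalize wrapper for illustration; the
# function above drives FunctionalTensorMode directly instead:
#
#     import torch
#     from torch.fx.experimental.proxy_tensor import make_fx
#
#     def f(x):
#         x.add_(1)        # in-place mutation...
#         return x.sin()
#
#     g = make_fx(torch.func.functionalize(f))(torch.randn(3))
#     # g.graph now contains only out-of-place aten.add / aten.sin calls.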


def aot_dispatch_base_graph(
    flat_fn,
    flat_args: List[Tensor],
    aot_config: AOTConfig,
    *,
    fw_metadata: ViewAndMutationMeta,
) -> Tuple[torch.fx.GraphModule, List[Any], Optional[SubclassMeta]]:
    # aot_dispatch_base requires functionalization, but doesn't need to handle
    # as many cases as the autograd path. Cases it does *not* need to handle:
    # - outputs that are aliases of graph intermediates
    # - outputs that are aliases of graph inputs
    # Cases it *does* need to handle:
    # - input mutations (including when inputs are aliases of each other)
    # - input metadata mutations
    fn_to_trace = fn_input_mutations_to_outputs(
        flat_fn,
        fw_metadata,
        keep_data_input_mutations=aot_config.keep_inference_input_mutations,
    )
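
    # Roughly (a hedged sketch, not this wrapper's exact calling convention):
    # given a user function that mutates an input, e.g.
    #
    #     def f(x):
    #         x.mul_(2)
    #         return x + 1
    #
    # the wrapped function also returns the post-mutation input value, so the
    # traced graph can express the mutation functionally:
    #
    #     def wrapped(x):
    #         x_updated = x * 2
    #         return x_updated, x_updated + 1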

    fn_to_trace, updated_flat_args = create_functionalized_fn(
        fn_to_trace,
        flat_args,
        meta=fw_metadata,
        aot_config=aot_config,
        trace_joint=False,
    )

    # TODO: replace with AOTDispatchSubclassWrapper once we refactor
    # fn_input_mutations_to_outputs and create_functionalized_fn
    # into CompilerWrappers.
    (
        fn_to_trace,
        updated_flat_args_subclasses_desugared,
        maybe_subclass_meta,
    ) = aot_dispatch_subclass(
        fn_to_trace,
        updated_flat_args,
        is_joint_structure=False,
        meta=fw_metadata,
        fw_only=flat_fn,
    )

    (fn_to_trace, updated_flat_args_subclasses_desugared) = handle_effect_tokens_fn(
        fn_to_trace,
        updated_flat_args_subclasses_desugared,
        meta=fw_metadata,
        trace_joint=False,
    )
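
    # Effect tokens (see Note [Side-Effectful Tokens in AOTAutograd]) are extra
    # dummy values threaded through side-effectful ops as inputs and outputs,
    # so the graph stays functional while the ordering of effects is preserved.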

    aot_graphs_log.debug(
        "aot_config id: %s, fw_metadata=%s, subclass_metadata=%s",
        str(aot_config.aot_id),
        str(fw_metadata),
        str(maybe_subclass_meta),
    )

    # We track buffer assignments when exporting in non-strict mode.
    # (In contrast, strict mode errors on any attribute assignment.)
    mod_when_exporting_non_strict = root_module_when_exporting_non_strict(flat_fn)
    if aot_config.is_export and mod_when_exporting_non_strict is not None:
        # For any buffer that is assigned, we want to associate it to the final proxy node
        # that it is assigned to. This node can then be added as a buffer mutation output.
        assigned_buffers = {}

        def _map_assigned_buffer_to_proxy(_mod, name, buffer):
            # We intercept buffer assignments on the root module through this hook.
            if _mod._buffers is mod_when_exporting_non_strict._buffers:
                # The value assigned to a buffer is a functional tensor, which wraps a fake tensor.
                assert isinstance(
                    buffer, torch._subclasses.functional_tensor.FunctionalTensor
                )
                fake = buffer.from_functional()
                # The fake tensor in turn is associated with a proxy node.
                proxy_mode = _detect_infra_mode(torch._C._TorchDispatchModeKey.PROXY)
                assert proxy_mode is not None
                proxy = torch.fx.experimental.proxy_tensor.get_proxy_slot(
                    fake, proxy_mode.tracer
                ).proxy.node
                # We map the assigned buffer to this proxy node.
                assigned_buffers[name] = proxy.name
            return buffer

        handle = torch.nn.modules.module.register_module_buffer_registration_hook(
            _map_assigned_buffer_to_proxy
        )
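
        # Illustration (a hedged sketch): the hook above fires when a module
        # re-assigns one of its registered buffers during forward, e.g.
        #
        #     class M(torch.nn.Module):
        #         def __init__(self):
        #             super().__init__()
        #             self.register_buffer("counter", torch.zeros(1))
        #
        #         def forward(self, x):
        #             self.counter = self.counter + 1  # buffer assignment
        #             return x + self.counter
        #
        # While tracing, the assigned value is a FunctionalTensor, and "counter"
        # gets mapped to the proxy node that produces its final value.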

    # Save a detached copy of the args to hand back to the caller: tracing in
    # _create_graph below may mutate the metadata of the inputs, but callers
    # expect to get the original flat args back. (See the analogous comment in
    # aot_dispatch_autograd_graph.)
    saved_updated_flat_args_subclasses_desugared = pytree.tree_map_only(
        torch.Tensor, lambda t: t.detach(), updated_flat_args_subclasses_desugared
    )
    fw_module = _create_graph(
        fn_to_trace,
        updated_flat_args_subclasses_desugared,
        aot_config=aot_config,
    )

    if aot_config.is_export and mod_when_exporting_non_strict is not None:
        # We update metadata to consider any assigned buffers as buffer mutations.
        i = len(dict(mod_when_exporting_non_strict.named_parameters()))
        for name, _ in mod_when_exporting_non_strict.named_buffers():
            if name in assigned_buffers and not fw_metadata.input_info[i].mutates_data:  # type: ignore[possibly-undefined]
                fw_metadata.input_info[i] = dataclasses.replace(
                    fw_metadata.input_info[i], mutates_data=True
                )
                fw_metadata.num_mutated_inp_runtime_indices += 1
            i += 1

        # We add nodes corresponding to buffer assignments as output nodes in the graph.
        add_nodes = []
        output_node = list(fw_module.graph.nodes)[-1]
        for name in assigned_buffers.values():  # type: ignore[possibly-undefined]
            for node in fw_module.graph.nodes:
                if node.name == name:
                    add_nodes.append(node)
                    node.users[output_node] = None
        output_node.args = ((*add_nodes, *output_node.args[0]),)

        handle.remove()  # type: ignore[possibly-undefined]

    # As long as we opted to remove input mutations, there should be *NO*
    # mutating ops in the graph at this point.
    copy_count = assert_functional_graph(fw_module.graph)
    fw_module.graph.eliminate_dead_code()
    fw_module.recompile()

    copy_count2 = assert_functional_graph(fw_module.graph)
    propagate_input_mutation_stacktraces(fw_module.graph)

    # See Note [Side-Effectful Tokens in AOTAutograd]
    num_tokens = len(fw_metadata.tokens)
    if num_tokens != 0 and config.unlift_effect_tokens:
        unlift_tokens(fw_module, fw_metadata, aot_config)
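        # unlift_tokens removes the token placeholders from the graph's inputs
        # and outputs, so drop the corresponding leading entries from the saved
        # args as well.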
        saved_updated_flat_args_subclasses_desugared = (
            saved_updated_flat_args_subclasses_desugared[num_tokens:]
        )

    assert copy_count == copy_count2

    if aot_config.enable_log:
        aot_graphs_log.info(
            "%s",
            lazy_format_graph_code(
                "Forward graph",
                fw_module,
                aot_config.aot_id,
                include_stride=True,
                include_device=True,
                colored=True,
            ),
        )
        trace_structured(
            "aot_forward_graph",
            payload_fn=lambda: fw_module.print_readable(
                print_output=False, include_stride=True, include_device=True
            ),
        )

    # TODO: factor this into a separate export-only function that returns just the graph.
    if aot_config.is_export:
        assert (
            maybe_subclass_meta is None
        ), "aot_export_module does not support tensor subclass inputs for now."
    return fw_module, saved_updated_flat_args_subclasses_desugared, maybe_subclass_meta


# Precondition: there are no duplicate arguments in flat_args (i.e., the same
# Tensor object never shows up twice). However, two tensor inputs MAY alias
# the same storage, as long as they have separate TensorImpls.
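#
# Illustration (a hedged sketch):
#
#     base = torch.randn(4)
#     a, b = base[:2], base[2:]  # distinct TensorImpls sharing one storage: OK
#     aot_dispatch_autograd_graph(f, [a, b], ...)  # allowed
#     aot_dispatch_autograd_graph(f, [a, a], ...)  # violates the precondition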
def aot_dispatch_autograd_graph(
    flat_fn,
    flat_args: List[Any],
    aot_config: AOTConfig,
    *,
    fw_metadata: ViewAndMutationMeta,
) -> Tuple[torch.fx.GraphModule, Tuple[List[Any], List[Any]], Optional[SubclassMeta]]:
    # traced_tangents corresponds to the set of outputs in the traced forward that should get grad_outputs in the traced backward.
    # It includes outputs of the original forward, *and* any updated inputs due to input mutations.
    # However, it does *not* include any outputs that are aliases of inputs or intermediates, or any metadata-only input mutations.
    joint_inputs = (flat_args, fw_metadata.traced_tangents)

    fn_prepared_for_autograd = fn_prepped_for_autograd(
        flat_fn,
        fw_metadata,
    )
    joint_fn_to_trace = create_joint(fn_prepared_for_autograd, aot_config=aot_config)
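
    # Conceptually (a hedged sketch), the joint function looks like
    #
    #     def joint(primals, tangents):
    #         outs = fw(*primals)
    #         grads = torch.autograd.grad(outs, primals, grad_outputs=tangents)
    #         return outs, grads
    #
    # so a single trace captures both the forward and the backward computation.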

    joint_fn_to_trace, updated_joint_inputs = create_functionalized_fn(
        joint_fn_to_trace,
        joint_inputs,
        meta=fw_metadata,
        aot_config=aot_config,
        trace_joint=True,
    )

    # TODO: replace with AOTDispatchSubclassWrapper once we refactor
    # fn_input_mutations_to_outputs and create_functionalized_fn
    # into CompilerWrappers.
    subclass_tracing_info = aot_dispatch_subclass(
        joint_fn_to_trace,
        updated_joint_inputs,
        is_joint_structure=True,
        meta=fw_metadata,
        fw_only=flat_fn,
    )

    joint_fn_to_trace = subclass_tracing_info.plain_tensor_trace_fn
    updated_joint_inputs = subclass_tracing_info.plain_tensor_args

    (joint_fn_to_trace, updated_joint_inputs) = handle_effect_tokens_fn(
        joint_fn_to_trace,
        updated_joint_inputs,
        meta=fw_metadata,
        trace_joint=True,
    )

    # When we call _create_graph, this may mutate the metadata of joint
    # inputs.  But callers are expecting to get the original joint inputs.  So
    # we make aliases of all the inputs to make sure we have a copy that
    # doesn't get modified.
    #
    # This destroys requires_grad/grad_fn information.  However, backends
    # beneath AOTAutograd are indifferent to this information, so it doesn't
    # matter.
    saved_updated_joint_inputs = pytree.tree_map_only(
        torch.Tensor, lambda t: t.detach(), updated_joint_inputs
    )
    maybe_subclass_meta = subclass_tracing_info.maybe_subclass_meta

    fx_g = _create_graph(joint_fn_to_trace, updated_joint_inputs, aot_config=aot_config)

    # There should be *NO* mutating ops in the graph at this point.
    assert_functional_graph(fx_g.graph)

    # Redundant with the check above, but worth having in case tracing introduced
    # a fake tensor. Unlikely.
    # See Note: [Fake Modules and AOTAutograd]
    torch._dynamo.utils.assert_no_fake_params_or_buffers(fx_g)
    fx_g.graph.eliminate_dead_code()
    copy_fwd_metadata_to_bw_nodes(fx_g)
    fx_g.recompile()

    # TODO: in AOTAutograd, we create metadata like _indices_of_inps_to_detach to detect
    # when we need to manually detach() some inputs in the forward.
    # Higher order ops might eventually need to do the same.
    if aot_config.is_export:
        assert (
            maybe_subclass_meta is None
        ), "aot_export_module does not support tensor subclass inputs for now."
    return fx_g, saved_updated_joint_inputs, maybe_subclass_meta