xref: /aosp_15_r20/external/pytorch/torch/_dynamo/backends/cudagraphs.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3import functools
4from collections import defaultdict
5from typing import Dict, List, Optional
6
7import torch
8from torch._dynamo import config
9from torch._dynamo.backends.common import aot_autograd
10from torch._dynamo.backends.debugging import boxed_nop
11from torch._inductor.cudagraph_utils import (
12    BoxedDeviceIndex,
13    check_multiple_devices_or_any_cpu_nodes,
14    format_default_skip_message,
15    get_mutation_stack_trace,
16    get_placeholder_info,
17    log_cudagraph_skip_and_bump_counter,
18)
19from torch._inductor.utils import (
20    BoxedBool,
21    count_tangents,
22    get_first_incompatible_cudagraph_node,
23    num_fw_fixed_arguments,
24    output_node,
25)
26from torch.multiprocessing.reductions import StorageWeakRef
27
28from .registry import register_backend
29
30
def find_input_mutations(g):
    """Return the indices of graph inputs that are written to by some op.

    Placeholders are numbered in the order they appear in ``g.nodes``
    (non-tensor placeholders still consume an index).  Tensor placeholders
    are grouped by their underlying storage, so a write to one input is
    reported for every input that shares the same storage.

    Only ``call_function`` nodes whose target exposes a ``_schema`` are
    inspected; an argument counts as mutated when its schema entry has
    ``alias_info`` with ``is_write`` set.

    Args:
        g: a graph object exposing ``nodes``; each node's ``meta`` must hold
            the traced value under ``"val"`` or ``"fake_result"``.

    Returns:
        A set of integer placeholder indices.
    """

    def fake_value(meta):
        # Prefer "val"; fall back to "fake_result" when it is absent.
        if "val" in meta:
            return meta["val"]
        return meta["fake_result"]

    def storage_key(meta):
        return StorageWeakRef(fake_value(meta)._typed_storage())

    storage_to_input_idxs = defaultdict(set)
    mutated = set()
    placeholder_count = 0

    for node in g.nodes:
        if node.op == "placeholder":
            if isinstance(fake_value(node.meta), torch.Tensor):
                storage_to_input_idxs[storage_key(node.meta)].add(placeholder_count)
            placeholder_count += 1
            continue

        # TODO: error on unrecognized nodes
        if node.op != "call_function" or not hasattr(node.target, "_schema"):
            continue

        for pos, schema_arg in enumerate(node.target._schema.arguments):
            # Resolve the actual argument: positional first, then keyword;
            # arguments that were not supplied at all are skipped.
            if pos < len(node.args):
                value = node.args[pos]
            elif schema_arg.name in node.kwargs:
                value = node.kwargs[schema_arg.name]
            else:
                continue

            alias_info = schema_arg.alias_info
            if not (alias_info and alias_info.is_write):
                continue

            # TODO: not correct for args that contain tensors in a struct
            # like list
            mutated.update(storage_to_input_idxs[storage_key(value.meta)])

    return mutated
68
69
def get_device_node_mapping(gm: torch.fx.GraphModule):
    """Map each device seen in ``gm`` to the first node that lives on it.

    Nodes are visited in graph order; only nodes whose ``meta["val"]`` is a
    ``torch.Tensor`` contribute.  The first node found for a device wins.
    """
    first_node_by_device: Dict[torch.device, torch.fx.Node] = {}
    for node in gm.graph.nodes:
        value = node.meta.get("val")
        if not isinstance(value, torch.Tensor):
            continue
        # setdefault keeps the earliest node recorded for this device.
        first_node_by_device.setdefault(value.device, node)
    return first_node_by_device
77
78
def check_for_mutation_ignore_cuda_graph_managed_tensor(
    aot_model: torch.fx.GraphModule, num_fixed
) -> Optional[str]:
    """Report mutations of inputs other than the first ``num_fixed`` ones.

    Returns ``None`` when every mutated input index is below ``num_fixed``
    (or nothing is mutated); otherwise returns whatever
    ``get_mutation_stack_trace`` produces for the offending placeholders.
    """
    graph = aot_model.graph
    offending = find_input_mutations(graph).difference(range(num_fixed))
    if offending:
        return get_mutation_stack_trace(get_placeholder_info(graph), offending)
    return None
88
89
def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed) -> Optional[str]:
    """Return a reason to skip cudagraphs for ``aot_model``, or ``None``.

    Checks run in order and the first one that yields a message wins:
      1. input mutation (only when
         ``config.cudagraph_backend_support_input_mutation`` is off),
      2. multiple devices / CPU nodes,
      3. the first cudagraph-incompatible node.
    """
    if not config.cudagraph_backend_support_input_mutation:
        mutation_msg = check_for_mutation_ignore_cuda_graph_managed_tensor(
            aot_model, num_fixed
        )
        if mutation_msg:
            return mutation_msg

    device_msg = check_multiple_devices_or_any_cpu_nodes(
        get_device_node_mapping(aot_model)
    )
    if device_msg:
        return device_msg

    bad_node = get_first_incompatible_cudagraph_node(aot_model)
    if bad_node:
        return format_default_skip_message(f"incompatible op ({bad_node.name})")

    return None
106
107
def get_device_index(gm) -> int:
    """Return the index of the first device found in ``gm``.

    The first key of ``get_device_node_mapping(gm)`` is used and asserted to
    be a CUDA device.  Raises ``StopIteration`` if the graph has no tensor
    nodes at all.
    """
    devices = get_device_node_mapping(gm)
    first_device = next(iter(devices))
    assert first_device.type == "cuda"
    return first_device.index
112
113
114def get_stack_traces(gm) -> List[Optional[str]]:
115    output = output_node(gm)
116    assert len(output.args) == 1
117    return [
118        (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
119        for arg in output.args[0]
120    ]
121
122
def cudagraphs(dynamo_model, dynamo_inputs):
    """Compile ``dynamo_model`` with AOT Autograd, wrapping graphs in CUDA graphs.

    Forward, backward and inference graphs are each handed to
    ``cudagraphify_impl`` unless ``check_for_skip`` reports a reason not to,
    in which case the graph runs without cudagraphs and the skip is logged.

    Args:
        dynamo_model: the model/graph handed over by dynamo.
        dynamo_inputs: the example inputs handed over by dynamo; their count
            is used to work out how many forward arguments are fixed.

    Returns:
        The result of calling the ``aot_autograd``-built compiler on
        ``dynamo_model`` and ``dynamo_inputs``.
    """
    from torch._inductor.cudagraph_trees import cudagraphify_impl

    # State shared between the forward and backward compilers: whether the
    # forward pass kept cudagraphs enabled, and which device it ran on.
    do_cudagraphs = BoxedBool(True)
    boxed_device_index = BoxedDeviceIndex(None)

    def forward_cudagraphs(aot_model, aot_inputs, is_inference=False):
        # NOTE(review): `is_inference` is accepted (the inference compiler
        # below binds it to True) but is not forwarded; cudagraphify_impl is
        # always called with is_inference=False here.  Confirm this is
        # intentional before changing it.
        interp = boxed_nop(aot_model, aot_inputs)
        fixed = num_fw_fixed_arguments(len(dynamo_inputs), len(aot_inputs))
        if skip_msg := check_for_skip(aot_model, fixed):
            # Also turn cudagraphs off for the matching backward graph.
            BoxedBool.disable(do_cudagraphs)
            log_cudagraph_skip_and_bump_counter(
                f"skipping cudagraphs due to {skip_msg}"
            )
            return interp

        boxed_device_index.set(get_device_index(aot_model))
        out = cudagraphify_impl(
            interp,
            aot_inputs,
            range(fixed),
            device_index=boxed_device_index.value,
            is_backward=False,
            is_inference=False,
            stack_traces=get_stack_traces(aot_model),
            placeholders=get_placeholder_info(aot_model.graph),
            mutated_input_idxs=find_input_mutations(aot_model.graph),
        )
        out._boxed_call = True
        return out

    def backward_cudagraphs(aot_model, aot_inputs):
        interp = boxed_nop(aot_model, aot_inputs)
        if not do_cudagraphs:
            # The forward compiler disabled cudagraphs; leave backward as is.
            return aot_model

        fixed = count_tangents(aot_model)
        if skip_msg := check_for_skip(aot_model, fixed):
            # Pass a single pre-formatted message, exactly like the forward
            # path does; the helper is not called with logging-style
            # ("%s", arg) arguments anywhere else in this file.
            log_cudagraph_skip_and_bump_counter(
                f"skipping cudagraphs due to {skip_msg}"
            )

            # See [Backward Generation Handling]
            manager = torch._inductor.cudagraph_trees.get_manager(
                boxed_device_index.value, create_if_none_exists=False
            )
            assert manager is not None

            def fn(inputs):
                # Tell the manager a backward is running even though this
                # graph itself is not cudagraphed.
                manager.set_to_running_backward()
                return aot_model(inputs)

            fn._boxed_call = True
            return fn

        out = cudagraphify_impl(
            interp,
            aot_inputs,
            range(fixed),
            device_index=get_device_index(aot_model),
            is_backward=True,
            is_inference=False,
            stack_traces=get_stack_traces(aot_model),
            placeholders=get_placeholder_info(aot_model.graph),
            mutated_input_idxs=find_input_mutations(aot_model.graph),
        )
        out._boxed_call = True
        return out

    aot_cudagraphs = aot_autograd(
        fw_compiler=forward_cudagraphs,
        bw_compiler=backward_cudagraphs,
        inference_compiler=functools.partial(forward_cudagraphs, is_inference=True),
        keep_inference_input_mutations=torch._dynamo.config.cudagraph_backend_keep_input_mutation,
    )
    return aot_cudagraphs(dynamo_model, dynamo_inputs)
199
200
class CudagraphsBackend:
    """Callable backend object that compiles via ``cudagraphs``.

    Both methods are static, so instances carry no state of their own.
    """

    compiler_name = "cudagraphs"

    @staticmethod
    def __call__(model, inputs):
        """Compile ``model`` for ``inputs`` by delegating to ``cudagraphs``."""
        return cudagraphs(model, inputs)

    @staticmethod
    def reset():
        """Reset cudagraph-trees state via ``reset_cudagraph_trees``."""
        # Imported lazily, as in the rest of this module's cudagraph_trees use.
        from torch._inductor.cudagraph_trees import reset_cudagraph_trees

        reset_cudagraph_trees()
213
214
# The "cudagraphs" backend only applies CUDA graphs to the graph.  It is also
# helpful for debugging and can serve as a perf baseline.
217register_backend(name="cudagraphs", compiler_fn=CudagraphsBackend())
218
219
def cudagraphs_inner(model, inputs, copy_outputs=True, copy_inputs=True):
    """Capture ``model`` in a CUDA graph and return a replaying callable.

    This isn't registered as a backend, but is used in some benchmarks.

    Args:
        model: callable invoked as ``model(*inputs)``.
        inputs: list or tuple of example inputs.
        copy_outputs: when true, ``run`` returns clones of the captured
            outputs; otherwise it returns the captured outputs themselves.
        copy_inputs: when true, the graph is recorded on zero-filled copies
            of ``inputs`` and ``run`` copies its arguments into them before
            each replay; otherwise ``inputs`` themselves are recorded on and
            ``run`` does not copy anything in.

    Returns:
        ``run(*new_inputs)``, which replays the captured graph.
    """
    assert isinstance(inputs, (list, tuple))
    static_inputs = (
        [torch.zeros_like(t) for t in inputs] if copy_inputs else list(inputs)
    )

    # warmup: run the model once on a side stream before capturing
    torch.cuda.synchronize()
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        model(*inputs)
    side_stream.synchronize()
    torch.cuda.current_stream().wait_stream(side_stream)
    torch.cuda.synchronize()

    # record
    captured = torch.cuda.CUDAGraph()
    with torch.cuda.graph(captured, stream=side_stream):
        static_outputs = model(*static_inputs)
    if not isinstance(static_outputs, (list, tuple)):
        # Normalize a single output to a 1-tuple so run() can iterate it.
        static_outputs = (static_outputs,)

    def run(*new_inputs):
        assert len(static_inputs) == len(new_inputs)
        if copy_inputs:
            for static, fresh in zip(static_inputs, new_inputs):
                static.copy_(fresh)
        captured.replay()
        if not copy_outputs:
            return static_outputs
        return [out.clone() for out in static_outputs]

    return run
257