xref: /aosp_15_r20/external/pytorch/torch/_dynamo/compiled_autograd.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import contextlib
3import functools
4from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
5
6import torch
7from torch._dynamo.external_utils import (
8    call_backward,
9    call_hook,
10    FakeCompiledAutogradEngine,
11)
12from torch._dynamo.source import GetItemSource, LocalSource
13from torch._dynamo.utils import counters, lazy_format_graph_code, set_locals_to_steal
14from torch._logging import getArtifactLogger, trace_structured
15from torch._prims_common import clone_preserve_strides
16from torch._subclasses import FakeTensorMode
17from torch.fx import GraphModule
18from torch.fx.experimental._backward_state import BackwardState
19from torch.fx.experimental.proxy_tensor import (
20    decompose,
21    disable_autocast_cache,
22    disable_proxy_modes_tracing,
23    fetch_object_proxy,
24    ProxyTorchDispatchMode,
25    PythonKeyTracer,
26    track_tensor_tree,
27)
28from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv
29from torch.fx.traceback import preserve_node_meta, set_stack_trace
30from torch.utils._traceback import CapturedTraceback
31
32
33if TYPE_CHECKING:
34    from torch.fx.proxy import Proxy
35
36
# Artifact logger for the "compiled_autograd" artifact; end_capture logs the captured graph here.
compiled_autograd_log = getArtifactLogger(__name__, "compiled_autograd")
# Artifact logger for the "compiled_autograd_verbose" artifact; used for debug-level messages.
verbose_log = getArtifactLogger(__name__, "compiled_autograd_verbose")
39
40
41def snapshot_verbose_logging_enabled():
42    return torch._logging._internal.log_state.is_artifact_enabled(
43        "compiled_autograd_verbose"
44    )
45
46
def cpp_verbose_log_fn(msg: str) -> None:
    """Emit ``msg`` on the verbose compiled autograd logger at debug level.

    This is the callable that ``enable`` registers through ``set_verbose_logger``.
    """
    verbose_log.debug(msg)
49
50
def snapshot_cudagraph_enabled():
    """Return the current value of the inductor ``triton.cudagraphs`` config flag."""
    triton_config = torch._inductor.config.triton
    return triton_config.cudagraphs
53
54
def maybe_clone(x):
    """Clone ``x`` with ``clone_preserve_strides``; ``None`` is passed through unchanged."""
    return None if x is None else clone_preserve_strides(x)
59
60
61class AutogradCompilerInstance:
62    def __init__(self, compiler_fn) -> None:
63        self.compiler_fn = compiler_fn
64        self.stack = contextlib.ExitStack()
65        self.close = self.stack.close
66        self.shape_env = ShapeEnv()
67        self.fake_tensor_mode = FakeTensorMode(
68            allow_fallback_kernels=True,
69            allow_non_fake_inputs=True,
70            shape_env=self.shape_env,
71        )
72        self.fx_tracer = PythonKeyTracer()
73        self.proxy_mode = ProxyTorchDispatchMode(self.fx_tracer, "symbolic")
74        self.hooks_proxy: Optional[Proxy] = None
75        self.graph_placeholders = ["inputs", "sizes", "scalars", "hooks"]
76
77    def wrap_fake(self, x, source):
78        assert isinstance(x, torch.Tensor)
79        return self.fake_tensor_mode.from_tensor(x, source=source)
80
81    @staticmethod
82    def source(name, idx) -> GetItemSource:
83        return GetItemSource(LocalSource(name), idx)
84
85    def begin_capture(
86        self,
87        inputs: List[torch.Tensor],
88        sizes: List[int],
89        scalars: List[Union[int, float]],
90    ):
91        counters["compiled_autograd"]["captures"] += 1
92        self.aot_graph_cls_name: Optional[str] = None
93        self.aot_graph_infos: Dict[int, Dict[str, Any]] = {}
94        self.fx_tracer.root = torch.nn.Module()
95        self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer)
96        self.fx_tracer.tensor_attrs = {}
97        args_proxy, sizes_proxy, scalars_proxy, self.hooks_proxy = (
98            self.fx_tracer.create_proxy("placeholder", name, (), {})
99            for name in self.graph_placeholders
100        )
101
102        # tensor inputs to fake tensors
103        inputs = [
104            self.wrap_fake(x, self.source("inputs", idx))
105            for idx, x in enumerate(inputs)
106        ]
107        self.bind_tensors_to_proxies(inputs, args_proxy)
108
109        # size inputs to symints
110        sizes = [
111            self.shape_env.create_unspecified_symint_and_symbol(
112                val,
113                self.source("sizes", idx),
114                DimDynamic.DYNAMIC,
115            )
116            for idx, val in enumerate(sizes)
117        ]
118        self.bind_tensors_to_proxies(sizes, sizes_proxy)
119
120        for idx, val in enumerate(scalars):
121            source = self.source("scalars", idx)
122            if isinstance(val, int):
123                scalars[idx] = self.shape_env.create_unspecified_symint_and_symbol(
124                    val,
125                    source,
126                    DimDynamic.DYNAMIC,
127                )
128            elif isinstance(val, float):
129                scalars[idx] = self.shape_env.create_symfloatnode(
130                    self.shape_env.create_unspecified_symbol(
131                        val,
132                        source=source,
133                        dynamic_dim=DimDynamic.DYNAMIC,
134                    ),
135                    hint=val,
136                    source=source,
137                )
138            else:
139                raise AssertionError("Unexpected scalar type: ", type(val))
140        self.bind_tensors_to_proxies(scalars, scalars_proxy)
141
142        # TODO(jansel): are all these modes needed?
143        self.stack.enter_context(decompose({}))
144        self.stack.enter_context(self.fake_tensor_mode)
145        self.stack.enter_context(self.proxy_mode)
146        self.stack.enter_context(disable_autocast_cache())
147        self.stack.enter_context(preserve_node_meta())
148        return inputs, sizes, scalars
149
150    def proxy_call_backward(
151        self,
152        inputs,
153        output_metadatas,
154        saved_tensors,
155        backward_idx: int,
156    ):
157        assert self.hooks_proxy is not None
158        backward_c_function = self.hooks_proxy[backward_idx]  # type: ignore[index]
159        proxies = self.fx_tracer.create_proxy(
160            kind="call_function",
161            target=call_backward,
162            args=(
163                backward_c_function,
164                self.to_proxy(saved_tensors),
165                *self.to_proxy(inputs),
166            ),
167            kwargs={},
168        )
169
170        with disable_proxy_modes_tracing():
171            # create fake Tensors
172            grad_ins: List[Optional[torch.Tensor]] = []
173            for output_metadata in output_metadatas:
174                if output_metadata is None:
175                    grad_ins.append(None)
176                    continue
177
178                layout, device, dtype, size = output_metadata
179                grad_ins.append(
180                    torch.empty(size=size, dtype=dtype, layout=layout, device=device)
181                )
182            self.bind_tensors_to_proxies(grad_ins, proxies)
183        return tuple(grad_ins)
184
185    def proxy_call_hook(self, hook, *args, **kwargs):
186        return self.fx_tracer.create_proxy(
187            "call_function",
188            call_hook,
189            (
190                hook,
191                *[self.to_proxy(x) for x in args],
192            ),
193            kwargs,
194        )
195
196    def tensor_pre_hook(self, inputs, hook_id, i: int):
197        assert self.hooks_proxy is not None
198        hook = self.hooks_proxy[hook_id]  # type: ignore[index]
199        proxy = self.proxy_call_hook(
200            hook,
201            inputs[i],
202            hook_type="tensor_pre_hook",
203        )
204        with disable_proxy_modes_tracing():
205            inputs[i] = maybe_clone(inputs[i])
206            self.bind_tensors_to_proxies([inputs[i]], [proxy])
207        return inputs
208
209    def pre_hook(self, inputs, hook_id):
210        assert self.hooks_proxy is not None
211        hook = self.hooks_proxy[hook_id]  # type: ignore[index]
212        proxies = self.proxy_call_hook(
213            hook,
214            inputs,
215            hook_type="pre_hook",
216        )
217        with disable_proxy_modes_tracing():
218            inputs = [maybe_clone(x) for x in inputs]
219            self.bind_tensors_to_proxies(inputs, proxies)
220        return inputs
221
222    def post_hook(self, outputs, inputs, hook_id):
223        assert self.hooks_proxy is not None
224        hook = self.hooks_proxy[hook_id]  # type: ignore[index]
225        proxies = self.proxy_call_hook(
226            hook,
227            outputs,
228            inputs,
229            hook_type="post_hook",
230        )
231        with disable_proxy_modes_tracing():
232            outputs = [maybe_clone(x) for x in outputs]
233            self.bind_tensors_to_proxies(outputs, proxies)
234        return outputs
235
236    def post_acc_grad_hook(self, input, hook_id):
237        assert isinstance(input, torch.Tensor)
238        assert self.hooks_proxy is not None
239        hook = self.hooks_proxy[hook_id]  # type: ignore[index]
240        proxy = self.proxy_call_hook(
241            hook,
242            input,
243            hook_type="post_acc_grad_hook",
244        )
245        with disable_proxy_modes_tracing():
246            input = [maybe_clone(input)]
247            self.bind_tensors_to_proxies(input, [proxy])
248        return input
249
250    # Note: [Compiled autograd and cudagraphs]
251    # Eager autograd backward implements scalars as 0-dim tensors, see DivBackward0::other_.
252    # When compiled autograd traces those nodes, it lifts the scalar tensors, resulting in a graph
253    # with some cpu 0-dim tensor inputs. To prevent the entire graph from skipping cudagraph, we move the
    # scalar tensors to cuda. This works because ATen/prims ops will accept cuda 0-dim tensors too.
255    def move_graph_nodes_to_cuda(self, graph) -> List[int]:
256        to_move: Dict[int, torch.fx.Node] = {}
257        has_cuda_inputs = False
258        nodes = list(graph.nodes)
259        assert nodes[0].target == "inputs"
260        inputs = nodes[0]
261        inputs_users = list(inputs.users.keys())
262        # input access nodes should immediately follow placeholder nodes
263        first_getitem_idx = len(self.graph_placeholders)
264        assert nodes[first_getitem_idx] == inputs_users[0]
265        last_getitem_idx = first_getitem_idx + len(inputs_users) - 1
266        assert nodes[last_getitem_idx] == inputs_users[-1]
267        for i, node in enumerate(inputs_users):
268            if not has_cuda_inputs and node.meta["val"].device.type == "cuda":
269                has_cuda_inputs = True
270                continue
271
272            is_cpu = node.meta["val"].device.type == "cpu"
273            is_scalar = len(node.meta["val"].size()) == 0
274            if is_cpu and is_scalar:
275                node_users = list(node.users.keys())
276                if all(
277                    isinstance(user.target, torch._ops.OpOverload)
278                    and user.target.namespace in ("prims", "aten")
279                    for user in node_users
280                ):
281                    # all users are prims/aten, can move safely
282                    to_move[i] = node
283
284        # only move cpu scalars to cuda if there were cuda activations in this graph,
285        # this is to handle the case where cudagraphs is enabled on a cpu-only graph
286        if has_cuda_inputs:
287            for node in to_move.values():
288                node.meta["val"] = node.meta["val"].cuda()
289
290            # return runtime indices we need to move to cuda
291            return list(to_move.keys())
292
293        return []
294
295    def end_capture(self, outputs):
296        self.fx_tracer.create_proxy(
297            "call_function",
298            FakeCompiledAutogradEngine._exec_final_callbacks_stub,
299            (),
300            {},
301        )
302        self.stack.close()
303        self.fx_tracer.create_node(
304            "output",
305            "output",
306            (self.fx_tracer.create_arg(self.to_proxy(outputs)),),
307            {},
308        )
309        self.rename_aot_dispatcher_nodes()
310        self.reorder_accumulate_grad_nodes()
311        runtime_inputs_to_move: List[int] = []
312        if snapshot_cudagraph_enabled():
313            runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph)
314
315        graph = GraphModule(
316            self.fx_tracer.root, self.fx_tracer.graph, "CompiledAutograd"
317        )
318        set_locals_to_steal(graph, ["inputs"])
319        lazy_graph_code = lazy_format_graph_code(
320            "Compiled autograd graph",
321            graph,
322            include_device=True,
323            include_stride=True,
324            colored=True,
325        )
326        compiled_autograd_log.info("%s", lazy_graph_code)
327        verbose_log.debug("%s", lazy_graph_code)
328        trace_structured(
329            "compiled_autograd_graph",
330            payload_fn=lambda: graph.print_readable(print_output=False),
331        )
332
333        def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks):
334            global in_compiled_autograd_region
335            try:
336                in_compiled_autograd_region = True
337                for i in runtime_inputs_to_move:
338                    inputs[i] = inputs[i].pin_memory().cuda(non_blocking=True)
339
340                return compiled_fn(inputs, sizes, scalars, hooks)
341            finally:
342                in_compiled_autograd_region = False
343
344        return runtime_wrapper, self.compiler_fn(graph)
345
    def rename_aot_dispatcher_nodes(self):
        """
        Renames nodes as they appear in the AOTDispatcher backward graphs, prefixed by AOT id
        e.g. AOTDispatcher backward graph X's `sin_Y` -> `aotX_sin_Y`

        Does nothing when no AOT graph was recorded (``self.aot_graph_cls_name`` is None).
        For every entry of ``self.aot_graph_infos`` the AOT graph and the matching segment
        of the traced graph are walked in lockstep; when the two walks stop lining up, the
        failure is only reported on the verbose logger and names set so far are kept.
        """
        if self.aot_graph_cls_name is None:
            return

        def is_similar(a: torch.fx.node.Node, b: torch.fx.node.Node):
            # Two nodes match when their targets are equal (or share the same
            # `__name__`), and their op, type and number of input nodes agree.
            target_match = a.target == b.target
            if not target_match:
                target_match = (
                    hasattr(a.target, "__name__")
                    and hasattr(b.target, "__name__")
                    and a.target.__name__ == b.target.__name__
                )
            return (
                target_match
                and a.op == b.op
                and a.type == b.type
                and len(a.all_input_nodes) == len(b.all_input_nodes)
            )

        for nodecall_index, info in self.aot_graph_infos.items():
            # Index into the traced graph at which this AOT segment starts
            # (recorded by set_node_origin).
            ca_node_start_idx = info["ca_node_start_idx"]
            aot_id = info["aot_id"]
            aot_graph = info["aot_gm"].graph

            # 1. Find the first op from user code in the AOT graph
            aot_it = iter(aot_graph.nodes)
            aot_node = next(aot_it)
            assert aot_node is not None
            try:
                while aot_node.op != "call_function":
                    aot_node = next(aot_it)
            except StopIteration:
                # The AOT graph has no call_function node: nothing to rename.
                continue

            try:
                # 2. Find the first op in the compiled autograd graph segment
                ca_it = iter(self.fx_tracer.graph.nodes)
                for _ in range(ca_node_start_idx):
                    next(ca_it)
                ca_node = next(ca_it)

                # Graphs should all end with output node
                while ca_node.op != "output" and not is_similar(ca_node, aot_node):
                    # The compiled autograd graph may contain lazily inserted ops
                    # We skip those when aligning nodes
                    ca_node = next(ca_it)

                # 3. Keep aligned and rename nodes
                while aot_node.op != "output" and ca_node.op != "output":
                    if not ca_node.users:
                        # TODO: DCE for compiled autograd graph
                        ca_node = next(ca_it)
                        continue

                    if not is_similar(aot_node, ca_node):
                        # There should be no lazily inserted ops in the middle of a match
                        # So any deviation is an error
                        raise StopIteration

                    ca_node.name = f"aot{aot_id}_{aot_node.name}"
                    # Inputs are renamed positionally after the AOT node's inputs.
                    for i, inp in enumerate(aot_node.all_input_nodes):
                        ca_node.all_input_nodes[i].name = f"aot{aot_id}_{inp.name}"

                    aot_node = next(aot_it)
                    ca_node = next(ca_it)
            except StopIteration:
                # Reached both when an iterator is exhausted and when a mismatch
                # is raised explicitly above.
                verbose_log.debug(
                    "Failed to match %s%s (NodeCall %s) nodes with AOT backward graph %s nodes",
                    self.aot_graph_cls_name,
                    aot_id,
                    nodecall_index,
                    aot_id,
                )
422                )
423
424    def reorder_accumulate_grad_nodes(self):
425        """
426        Usage of AOTAutograd causes all the accumulate_grad_ nodes to get pushed to the end of
427        the graph.  This differs from eager mode, which schedules them as soon as possible. This
428        pass attempts to reorder the graph to mimic eager behavior.
429        """
430        for node in self.fx_tracer.graph.find_nodes(
431            op="call_function", target=torch.ops.inductor.accumulate_grad_.default
432        ):
433            arg = max(node.args)  # last arg
434            if arg is not node.prev and arg.op != "placeholder":
435                arg.append(node)
436
437    def to_proxy(self, t):
438        if t is None:
439            return None
440        if isinstance(t, list):
441            return [self.to_proxy(x) for x in t]
442        if isinstance(t, tuple):
443            return tuple(self.to_proxy(x) for x in t)
444        # can it be torch.SymInt as the code used to imply?
445        assert isinstance(t, torch.Tensor)
446        proxy_tensor = fetch_object_proxy(self.fx_tracer, t)
447        assert isinstance(proxy_tensor, torch.fx.experimental.proxy_tensor._ProxyTensor)
448        return proxy_tensor.proxy
449
450    def bind_tensors_to_proxies(self, tensors, proxies):
451        if isinstance(proxies, torch.fx.Proxy):
452            proxies = [proxies[i] for i in range(len(tensors))]  # type: ignore[index]
453        assert len(tensors) == len(proxies)
454        track_tensor_tree(tensors, proxies, constant=None, tracer=self.fx_tracer)
455
456    def bind_backward_state(self, index: int):
457        assert self.hooks_proxy is not None
458        proxy = self.hooks_proxy[index]  # type: ignore[index]
459        bw_state = BackwardState()
460        track_tensor_tree(bw_state, proxy, constant=None, tracer=self.fx_tracer)
461        return bw_state
462
    def set_node_origin(
        self,
        node_name: str,
        nodecall_index: int,
        pyobj: Optional[torch.autograd.Function],
    ):
        """Set the stack trace attached to subsequently traced nodes to name their origin.

        Args:
            node_name: Name of the autograd node being traced.
            nodecall_index: Index of the node call, shown in the label and used
                as the key into ``self.aot_graph_infos``.
            pyobj: Optional object whose ``_forward_cls`` is inspected; when that
                class has an ``_aot_id`` attribute, the AOT id, the current graph
                length and the backward module are recorded for
                ``rename_aot_dispatcher_nodes``.
        """
        maybe_aot_id = ""
        if pyobj is not None:
            forward_cls = pyobj._forward_cls  # type: ignore[attr-defined]
            if hasattr(forward_cls, "_aot_id"):
                # backward was created by AOT Dispatcher
                self.aot_graph_cls_name = node_name
                maybe_aot_id = forward_cls._aot_id
                self.aot_graph_infos[nodecall_index] = {
                    # Nodes traced from here on belong to this AOT backward.
                    "ca_node_start_idx": len(self.fx_tracer.graph.nodes),
                    "aot_id": maybe_aot_id,
                    "aot_gm": forward_cls._lazy_backward_info.bw_module,
                }

        new_code = f"{node_name}{maybe_aot_id} (NodeCall {nodecall_index})"
        # NOTE: the replace() below searches for the exact source text of the next
        # statement (presumably the last formatted traceback entry quotes that very
        # line - confirm). Renaming `raw_stack_trace` or reformatting that statement
        # without updating the string literal makes the replacement a silent no-op.
        raw_stack_trace = CapturedTraceback.extract().format()[-1]
        new_stack_trace = raw_stack_trace.replace(
            "raw_stack_trace = CapturedTraceback.extract().format()[-1]", new_code
        )
        set_stack_trace(new_stack_trace)
488
489
# State of the autograd engine dispatch, kept in sync by the enable/disable context managers:
# True while an autograd compiler is installed through `enable`.
compiled_autograd_enabled = False

# Global flag to check if we are processing graphs produced from a compiled autograd graph.
# Set to True by the runtime wrapper built in `end_capture` for the duration of the call.
in_compiled_autograd_region = False
495
496
@contextlib.contextmanager
def enable(compiler_fn):
    """Context manager that installs compiled autograd with ``compiler_fn``.

    Registers a factory producing ``AutogradCompilerInstance(compiler_fn)`` as
    the autograd compiler, optionally hooks up the verbose logger, and runs the
    body with autograd multithreading disabled. On exit the previously installed
    compiler is restored; the module flag ``compiled_autograd_enabled`` is
    cleared only when there was no previous compiler.
    """
    global compiled_autograd_enabled
    ca_bindings = torch._C._dynamo.compiled_autograd
    instance_factory = functools.partial(AutogradCompilerInstance, compiler_fn)
    prior = ca_bindings.set_autograd_compiler(instance_factory)
    if snapshot_verbose_logging_enabled():
        ca_bindings.set_verbose_logger(cpp_verbose_log_fn)
    compiled_autograd_enabled = True
    try:
        with torch.autograd.set_multithreading_enabled(False):
            yield
    finally:
        if not prior:
            compiled_autograd_enabled = False
        ca_bindings.set_autograd_compiler(prior)
513
514
515@contextlib.contextmanager
516def disable():
517    prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler(None)
518    global compiled_autograd_enabled
519    compiled_autograd_enabled = False
520    try:
521        yield
522    finally:
523        if prior:
524            compiled_autograd_enabled = True
525        torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior)
526
527
528# return to starting state of a new process
529def reset() -> None:
530    compiled_autograd_enable = False
531    assert not in_compiled_autograd_region
532    torch._C._dynamo.compiled_autograd.set_autograd_compiler(None)
533    torch._C._dynamo.compiled_autograd.set_verbose_logger(None)
534