# torch/_dynamo/external_utils.py
# mypy: allow-untyped-defs
# This module contains functions that *will be allowed* by dynamo

import functools
import warnings
from typing import List

import torch
import torch.utils._pytree as pytree


try:
    import numpy as np
except ModuleNotFoundError:
    np = None  # type: ignore[assignment]


def is_compiling() -> bool:
    """
    Indicates whether we are tracing/compiling with torch.compile() or torch.export().

    If you need to check specifically whether TorchDynamo is being used, use
    torch.compiler.is_dynamo_compiling().

    TODO(khabinov): we should deprecate this function and use one of these two:
    * torch.compiler.is_compiling(),
    * torch.compiler.is_dynamo_compiling().
    Which one to use will depend on the context.
    """
    return torch.compiler.is_compiling()
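
# A minimal usage sketch of ``is_compiling`` (the function below is illustrative,
# not part of this module): user code can branch on whether it is currently being
# traced by torch.compile()/torch.export().
#
#     def add_or_sub(x):
#         if is_compiling():
#             return x + 1  # branch taken while dynamo traces
#         return x - 1      # branch taken in plain eager mode
#
#     torch.compile(add_or_sub)(torch.zeros(3))  # traces the `x + 1` branch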


def wrap_inline(fn):
    """
    Create an extra frame around fn that is not in skipfiles
    """

    @functools.wraps(fn)
    def inner(*args, **kwargs):
        return fn(*args, **kwargs)

    return inner
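
# A minimal usage sketch of ``wrap_inline`` (illustrative only; ``helper`` is a
# hypothetical function defined in a file dynamo would otherwise skip): wrapping
# it gives dynamo a fresh frame, defined in this module, that it can trace into.
#
#     traceable = wrap_inline(helper)
#     torch.compile(traceable)(torch.ones(3))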


def call_hook(hook, *args, **kwargs):
    """
    Used by compiled autograd to handle hooks that return None.
    """
    result = hook(*args)
    if result is None:
        return args[0]
    elif kwargs["hook_type"] == "post_acc_grad_hook":
        raise RuntimeError("Tensor post accumulate grad hooks should return None.")
    return result
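
# A minimal usage sketch of ``call_hook`` (the hook and the ``hook_type`` value
# are illustrative only): a hook that returns None leaves the gradient unchanged,
# so the first positional argument is passed straight through.
#
#     def log_grad(grad):
#         print(grad.shape)  # side effect only; implicitly returns None
#
#     grad = torch.ones(3)
#     assert call_hook(log_grad, grad, hook_type="tensor_pre_hook") is grad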


def wrap_numpy(f):
    r"""Decorator that turns a function from ``np.ndarray``s to ``np.ndarray``s into a function
    from ``torch.Tensor``s to ``torch.Tensor``s.
    """
    if not np:
        return f

    @functools.wraps(f)
    def wrap(*args, **kwargs):
        args, kwargs = pytree.tree_map_only(
            torch.Tensor, lambda x: x.numpy(), (args, kwargs)
        )
        out = f(*args, **kwargs)
        return pytree.tree_map_only(np.ndarray, lambda x: torch.as_tensor(x), out)

    return wrap
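
# A minimal usage sketch of ``wrap_numpy`` (the decorated function is illustrative
# only): Tensor arguments are converted to ndarrays on the way in, and ndarray
# outputs are converted back to Tensors on the way out.
#
#     @wrap_numpy
#     def numpy_add(a, b):
#         return a + b  # `a` and `b` are np.ndarrays here
#
#     numpy_add(torch.ones(3), torch.ones(3))  # -> tensor([2., 2., 2.])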


class FakeBackwardCFunction:
    def __init__(
        self,
        real: torch.autograd.function.BackwardCFunction,
        saved_tensors: List[torch.Tensor],
    ) -> None:
        self.real = real
        self.saved_tensors = saved_tensors

    def __getattr__(self, name):
        if name == "saved_variables":
            warnings.warn(
                "'saved_variables' is deprecated; use 'saved_tensors'",
                DeprecationWarning,
            )
            return self.saved_tensors

        # route any attribute that isn't defined on this object to the real
        # BackwardCFunction
        return getattr(self.real, name)


# This function corresponds to the "eager" implementation of a lifted autograd.Function.backward
def call_backward(backward_c_function, saved_tensors, *args):
    fake = FakeBackwardCFunction(backward_c_function, saved_tensors)
    grads = fake._forward_cls.backward(fake, *args)  # type: ignore[attr-defined]

    # in eager, we wrap in a tuple when there's only one grad output
    if type(grads) is not tuple:
        grads = (grads,)

    return grads
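
# A minimal usage sketch of ``call_backward`` (the autograd.Function and variable
# names below are illustrative only): compiled autograd wraps the real
# BackwardCFunction (the ``ctx``) plus the lifted saved tensors in a
# FakeBackwardCFunction, then dispatches to the user-defined backward. The result
# is always a tuple, even when backward returns a single gradient.
#
#     class Square(torch.autograd.Function):
#         @staticmethod
#         def forward(ctx, x):
#             ctx.save_for_backward(x)
#             return x * x
#
#         @staticmethod
#         def backward(ctx, grad_out):
#             (x,) = ctx.saved_tensors
#             return 2 * x * grad_out
#
#     # conceptually: grads = call_backward(ctx, saved_tensors, grad_out)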


def untyped_storage_size(x: torch.Tensor):
    return x.untyped_storage().size()


class FakeCompiledAutogradEngine:
    @staticmethod
    def queue_callback(final_callbacks, cb):
        final_callbacks.append(cb)

    @staticmethod
    def exec_final_callbacks(final_callbacks):
        # re-evaluate len() on every iteration so callbacks queued by earlier
        # callbacks also run
        i = 0
        while i < len(final_callbacks):
            cb = final_callbacks[i]
            cb()
            i += 1
        final_callbacks.clear()

    @staticmethod
    def _exec_final_callbacks_stub():
        pass
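
# A minimal usage sketch of ``FakeCompiledAutogradEngine`` (illustrative only):
# callbacks queued during the compiled backward are collected in a plain list
# and executed once at the end.
#
#     final_callbacks: list = []
#     FakeCompiledAutogradEngine.queue_callback(final_callbacks, lambda: print("done"))
#     FakeCompiledAutogradEngine.exec_final_callbacks(final_callbacks)  # prints "done"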


def call_hook_from_backward_state(*args, bw_state, hook_name: str, **kwargs):
    return getattr(bw_state, hook_name)(*args, **kwargs)
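
# A minimal usage sketch of ``call_hook_from_backward_state`` (``_State`` is a
# stand-in object, not the real backward state): the hook is looked up by name on
# the backward-state object and called with the forwarded arguments.
#
#     class _State:
#         def my_hook(self, grad):
#             return grad * 2
#
#     call_hook_from_backward_state(torch.ones(3), bw_state=_State(), hook_name="my_hook")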


def call_module_hooks_from_backward_state(
    _, result, *args, bw_state, hooks_name: str, module_name: str
):
    module = getattr(bw_state, module_name)
    hooks = getattr(bw_state, hooks_name)
    for hook in hooks:
        new_result = hook(module, result, *args)
        if new_result is not None:
            result = new_result
    return result
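
# A minimal usage sketch of ``call_module_hooks_from_backward_state`` (the state
# object and hook below are illustrative only): each module hook stored on the
# backward state is applied in order, and a hook that returns None keeps the
# previous result.
#
#     class _State:
#         pass
#
#     state = _State()
#     state.mod = torch.nn.Identity()
#     state.mod_hooks = [lambda module, result, *args: result + 1]
#
#     call_module_hooks_from_backward_state(
#         None, torch.zeros(3), bw_state=state, hooks_name="mod_hooks", module_name="mod"
#     )  # -> tensor([1., 1., 1.])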