# mypy: allow-untyped-defs
from typing import Any, Callable, List, TypeVar

import torch


__all__ = [
    "compile",
    "assume_constant_result",
    "reset",
    "allow_in_graph",
    "substitute_in_graph",
    "list_backends",
    "disable",
    "cudagraph_mark_step_begin",
    "wrap_numpy",
    "is_compiling",
    "is_dynamo_compiling",
]


_F = TypeVar("_F", bound=Callable[..., Any])


def compile(*args, **kwargs):
    """
    See :func:`torch.compile` for details on the arguments for this function.
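
    Example::

        # A minimal sketch: this function simply forwards its arguments to
        # torch.compile, so any torch.compile usage pattern applies here.
        def fn(x):
            return x.sin() + x.cos()

        compiled_fn = torch.compiler.compile(fn)
        out = compiled_fn(torch.randn(8))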
    """
    return torch.compile(*args, **kwargs)


def reset() -> None:
    """
    This function clears all compilation caches and restores the system to its initial state.
    It is recommended to call this function after using an operation like :func:`torch.compile`
    to ensure a clean state before another, unrelated compilation.
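
    Example::

        # Illustrative sketch: clear caches between two unrelated compilations.
        @torch.compile
        def foo(x):
            return x + 1

        foo(torch.ones(3))      # the first call compiles and caches the graph
        torch.compiler.reset()  # drop all cached compilation artifacts
        foo(torch.ones(3))      # compiles again from a clean state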
    """
    import torch._dynamo

    torch._dynamo.reset()


def allow_in_graph(fn):
    """
    Tells the compiler frontend (Dynamo) to skip symbolic introspection of the function
    and instead directly write it to the graph when encountered.

    If you are using :func:`torch.compile` (with backend="inductor" (the default)), or
    :func:`torch.export.export`, and trying to black-box a Python function throughout
    all tracing, do not use this API.
    Instead, please create a custom operator (see :ref:`custom-ops-landing-page`).

    .. warning::

        If you're a typical torch.compile user (e.g. you're applying torch.compile to
        a model to make it run faster), you probably don't want to use this function.
        :func:`allow_in_graph` is a footgun because it skips the compiler frontend
        (Dynamo) that is responsible for doing safety checks (graph breaks, handling
        closures, etc). Incorrect usage will lead to difficult-to-debug silent
        incorrectness issues.

    Given a Python function with no allow_in_graph decorator, regular execution
    of torch.compile traces through the function. :func:`allow_in_graph` changes
    it so that the frontend does not trace inside the function, but the compiler
    backend still traces through it. Compare this to custom operators, which
    treat a function as a black box throughout the torch.compile stack. The following
    table compares these mechanisms.

    +------------------------+-----------------------+--------------------------------+
    | Mechanism              | Frontend (Dynamo)     | Backend (AOTAutograd+Inductor) |
    +========================+=======================+================================+
    | no decorator           | trace inside          | trace inside                   |
    +------------------------+-----------------------+--------------------------------+
    | allow_in_graph         | opaque callable       | trace inside                   |
    +------------------------+-----------------------+--------------------------------+
    | custom op              | opaque callable       | opaque callable                |
    +------------------------+-----------------------+--------------------------------+

    One common use case for :func:`allow_in_graph()` is as an escape hatch for the compiler
    frontend: if you know the function works with respect to the downstream components of the
    compilation stack (AOTAutograd and Inductor) but there is a Dynamo bug that prevents it from
    symbolically introspecting the function properly (or if your code is in C/C++ and
    therefore cannot be introspected with Dynamo), then you can decorate said function
    with :func:`allow_in_graph` to bypass Dynamo.

    We require that ``fn`` adhere to the following restrictions. Failure to adhere
    results in undefined behavior:

    - The inputs to ``fn`` must be Proxy-able types in the FX graph. Valid types include:
      Tensor/int/bool/float/None/List[Tensor?]/List[int?]/List[float?]
      Tuple[Tensor?, ...]/Tuple[int?, ...]/Tuple[float?, ...]/torch.dtype/torch.device
    - The outputs to ``fn`` must be Proxy-able types in the FX graph (see previous bullet)
    - All Tensors used inside of ``fn`` must be passed directly as inputs to ``fn``
      (as opposed to being captured variables).

    Args:
        fn: A callable representing the function to be included in the graph.
            If ``fn`` is a list or tuple of callables it recursively applies
            :func:`allow_in_graph()` to each function and returns a new list or
            tuple containing the modified functions.

    Example::

        torch.compiler.allow_in_graph(my_custom_function)

        @torch.compile(...)
        def fn(a):
            x = torch.add(a, 1)
            x = my_custom_function(x)
            x = torch.add(x, 1)
            return x

        fn(...)

    Will capture a single graph containing ``my_custom_function()``.

    """
    import torch._dynamo

    return torch._dynamo.allow_in_graph(fn)


def substitute_in_graph(
    original_fn: _F,
    *,
    can_constant_fold_through: bool = False,
    skip_signature_check: bool = False,
) -> Callable[[_F], _F]:
    """
    Register a polyfill handler for a function, usually a C function from the C extension, to be
    used in place of the original function when inlining the original function in the graph.

    .. note::

        The polyfill handler is only used when inlining the original function. It is not used when
        the original function is called directly. In eager mode, the decorated function calls
        the performant C function rather than the polyfill handler.

    The polyfill handler is a function that will be called in place of the original function when
    inlining the original function. The polyfill handler should have the same signature and the same
    behavior as the original function.

    Args:
        original_fn (callable): The original function, usually a C function, to register a polyfill
            handler for.
        can_constant_fold_through (bool, optional): Whether the polyfill handler can be constant
            folded through. That is, if the polyfill handler is a pure function and its arguments
            are constant, the result of the polyfill handler can be constant folded during the
            compilation. Defaults to ``False``.
        skip_signature_check (bool, optional): Whether to skip the signature check between the
            original function and the polyfill handler. Defaults to ``False``.

    Returns:
        A decorator that registers the polyfill handler for the original function.

    Example::

        >>> import operator
        >>> operator.indexOf([1, 2, 3, 4, 5], 3)
        2
        >>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
        ... # xdoctest: +SKIP("Long tracebacks")
        Traceback (most recent call last):
        ...
        torch._dynamo.exc.Unsupported: ...

        >>> @torch.compiler.substitute_in_graph(operator.indexOf)
        ... def indexOf(a, b, /):
        ...     for i, item in enumerate(a):
        ...         if item is b or item == b:
        ...             return i
        ...     raise ValueError("sequence.index(x): x not in sequence")
        >>>
        >>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
        2
    """
    import torch._dynamo

    return torch._dynamo.substitute_in_graph(
        original_fn,
        can_constant_fold_through=can_constant_fold_through,
        skip_signature_check=skip_signature_check,
    )


def list_backends(exclude_tags=("debug", "experimental")) -> List[str]:
    """
    Return valid strings that can be passed to `torch.compile(..., backend="name")`.

    Args:
        exclude_tags (optional): A tuple of strings representing tags to exclude.
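
    Example::

        # A minimal sketch; the available names depend on the installation
        # ("inductor" is the default torch.compile backend).
        backends = torch.compiler.list_backends()
        print(backends)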
    """
    import torch._dynamo

    return torch._dynamo.list_backends(exclude_tags)


def assume_constant_result(fn):
    """
    This function is used to mark a function `fn` as having a constant result.
    This allows the compiler to optimize away your function.

    Returns:
        The same function `fn`.

    Args:
        fn: The function to be marked as having a constant result.

    .. warning::
        If the constant-result assumption is invalid, `assume_constant_result` can cause safety
        and soundness issues. :func:`torch.compile` will not attempt to validate whether the
        constant assumption is true or not.

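    Example::

        # Illustrative sketch: ``get_mode`` is a hypothetical helper whose result
        # the author knows stays fixed; the compiler bakes its return value into
        # the graph without re-checking it.
        import os

        @torch.compiler.assume_constant_result
        def get_mode():
            return os.environ.get("MY_MODE", "fast")

        @torch.compile
        def fn(x):
            if get_mode() == "fast":
                return x * 2
            return x + 1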
    """
    import torch._dynamo

    return torch._dynamo.assume_constant_result(fn)


def disable(fn=None, recursive=True):
    """
    This function provides both a decorator and a context manager to disable compilation on a function.
    It also provides the option of recursively disabling called functions.

    Args:
        fn (optional): The function to disable.
        recursive (optional): A boolean value indicating whether the disabling should be recursive.
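
    Example::

        # A minimal sketch: ``debug_helper`` is a hypothetical function that is
        # kept eager even when called from compiled code (recursive=True is the default).
        @torch.compiler.disable
        def debug_helper(x):
            print(x.shape)  # runs eagerly; printing would otherwise graph-break
            return x

        @torch.compile
        def fn(x):
            x = x + 1
            x = debug_helper(x)
            return x * 2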
    """
    import torch._dynamo

    return torch._dynamo.disable(fn, recursive)


def cudagraph_mark_step_begin():
    """
    Indicates that a new iteration of inference or training is about to begin.

    CUDA Graphs will free tensors of a prior iteration. A new iteration is started on each invocation
    of a torch.compile'd function, so long as there is not a pending backward that has not been called.

    If that heuristic is wrong, such as in the following example, manually mark a new iteration with this API.

    .. code-block:: python

        @torch.compile(mode="reduce-overhead")
        def rand_foo():
            return torch.rand([4], device="cuda")

        for _ in range(5):
            torch.compiler.cudagraph_mark_step_begin()
            rand_foo() + rand_foo()

    For more details, see `torch.compiler_cudagraph_trees <https://pytorch.org/docs/main/torch.compiler_cudagraph_trees.html>`__
    """
    from torch._inductor import cudagraph_trees

    cudagraph_trees.mark_step_begin()


def wrap_numpy(fn):
    r"""Decorator that turns a function from ``np.ndarray``s to ``np.ndarray``s into a function
    from ``torch.Tensor``s to ``torch.Tensor``s.

    It is designed to be used with :func:`torch.compile` with ``fullgraph=True``. It allows you to
    compile a NumPy function as if it were a PyTorch function. This allows you to run NumPy code
    on CUDA or compute its gradients.

    .. note::

        This decorator does not work without :func:`torch.compile`.

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # Compile a NumPy function as a Tensor -> Tensor function
        >>> @torch.compile(fullgraph=True)
        >>> @torch.compiler.wrap_numpy
        >>> def fn(a: np.ndarray):
        >>>     return np.sum(a * a)
        >>> # Execute the NumPy function using Tensors on CUDA and compute the gradients
        >>> x = torch.arange(6, dtype=torch.float32, device="cuda", requires_grad=True)
        >>> out = fn(x)
        >>> out.backward()
        >>> print(x.grad)
        tensor([ 0.,  2.,  4.,  6.,  8., 10.], device='cuda:0')
    """
    from torch._dynamo.external_utils import wrap_numpy as wrap

    return wrap(fn)


_is_compiling_flag: bool = False


def is_compiling() -> bool:
    """
    Indicates whether a graph is executed/traced as part of torch.compile() or torch.export().

    Note that there are 2 other related flags that should be deprecated eventually:
      * torch._dynamo.external_utils.is_compiling()
      * torch._utils.is_compiling()

    Example::

        >>> def forward(self, x):
        >>>     if not torch.compiler.is_compiling():
        >>>         pass  # ...logic that is not needed in a compiled/traced graph...
        >>>
        >>>     # ...rest of the function...
    """
    if torch.jit.is_scripting():
        return False
    else:
        return _is_compiling_flag


def is_dynamo_compiling() -> bool:
    """
    Indicates whether a graph is traced via TorchDynamo.

    It's stricter than the is_compiling() flag, as it is only set to True when
    TorchDynamo is used.

    Example::

        >>> def forward(self, x):
        >>>     if not torch.compiler.is_dynamo_compiling():
        >>>         pass  # ...logic that is not needed in a TorchDynamo-traced graph...
        >>>
        >>>     # ...rest of the function...
    """
    return False
330