# mypy: allow-untyped-defs
from typing import Any, Callable, List, TypeVar

import torch


__all__ = [
    "compile",
    "assume_constant_result",
    "reset",
    "allow_in_graph",
    "substitute_in_graph",
    "list_backends",
    "disable",
    "cudagraph_mark_step_begin",
    "wrap_numpy",
    "is_compiling",
    "is_dynamo_compiling",
]


_F = TypeVar("_F", bound=Callable[..., Any])


def compile(*args, **kwargs):
    """
    See :func:`torch.compile` for details on the arguments for this function.
    """
    return torch.compile(*args, **kwargs)


def reset() -> None:
    """
    This function clears all compilation caches and restores the system to its initial state.
    It is recommended to call this function after using an operation like :func:`torch.compile`
    to ensure a clean state before another, unrelated compilation.
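
    Example (a minimal sketch; the lambdas stand in for real workloads)::

        fn = torch.compile(lambda x: x + 1)
        fn(torch.randn(3))  # the first call triggers compilation

        torch.compiler.reset()  # clear all caches before an unrelated compilation

        gn = torch.compile(lambda x: x * 2)
        gn(torch.randn(3))  # compiles again from a clean state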
    """
    import torch._dynamo

    torch._dynamo.reset()


def allow_in_graph(fn):
    """
    Tells the compiler frontend (Dynamo) to skip symbolic introspection of the function
    and instead directly write it to the graph when encountered.

    If you are using :func:`torch.compile` (with the default backend="inductor"), or
    :func:`torch.export.export`, and trying to black-box a Python function throughout
    all tracing, do not use this API.
    Instead, please create a custom operator (see :ref:`custom-ops-landing-page`).

    .. warning::

        If you're a typical torch.compile user (e.g. you're applying torch.compile to
        a model to make it run faster), you probably don't want to use this function.
        :func:`allow_in_graph` is a footgun because it skips the compiler frontend
        (Dynamo) that is responsible for doing safety checks (graph breaks, handling
        closures, etc). Incorrect usage will lead to difficult-to-debug silent
        incorrectness issues.

    Given a Python function with no allow_in_graph decorator, regular execution
    of torch.compile traces through the function. :func:`allow_in_graph` changes
    it so that the frontend does not trace inside the function, but the compiler
    backend still traces through it. Compare this to custom operators, which
    treat a function as a black box throughout the torch.compile stack. The following
    table compares these mechanisms.

    +------------------------+-----------------------+--------------------------------+
    | Mechanism              | Frontend (Dynamo)     | Backend (AOTAutograd+Inductor) |
    +========================+=======================+================================+
    | no decorator           | trace inside          | trace inside                   |
    +------------------------+-----------------------+--------------------------------+
    | allow_in_graph         | opaque callable       | trace inside                   |
    +------------------------+-----------------------+--------------------------------+
    | custom op              | opaque callable       | opaque callable                |
    +------------------------+-----------------------+--------------------------------+

    One common use case for :func:`allow_in_graph()` is as an escape hatch for the compiler
    frontend: if you know the function works with respect to the downstream components of the
    compilation stack (AOTAutograd and Inductor) but there is a Dynamo bug that prevents it from
    symbolically introspecting the function properly (or if your code is in C/C++ and
    therefore cannot be introspected with Dynamo), then one can decorate said function
    with :func:`allow_in_graph` to bypass Dynamo.

    We require that ``fn`` adhere to the following restrictions. Failure to adhere
    results in undefined behavior:

    - The inputs to ``fn`` must be Proxy-able types in the FX graph. Valid types include:
      Tensor/int/bool/float/None/List[Tensor?]/List[int?]/List[float?]
      Tuple[Tensor?, ...]/Tuple[int?, ...]/Tuple[float?, ...]/torch.dtype/torch.device
    - The outputs of ``fn`` must be Proxy-able types in the FX graph (see previous bullet)
    - All Tensors used inside of ``fn`` must be passed directly as inputs to ``fn``
      (as opposed to being captured variables).

    Args:
        fn: A callable representing the function to be included in the graph.
            If ``fn`` is a list or tuple of callables it recursively applies
            :func:`allow_in_graph()` to each function and returns a new list or
            tuple containing the modified functions.

    Example::

        torch.compiler.allow_in_graph(my_custom_function)

        @torch.compile(...)
        def fn(a):
            x = torch.add(a, 1)
            x = my_custom_function(x)
            x = torch.add(x, 1)
            return x

        fn(...)

    This will capture a single graph containing ``my_custom_function()``.

    """
    import torch._dynamo

    return torch._dynamo.allow_in_graph(fn)


def substitute_in_graph(
    original_fn: _F,
    *,
    can_constant_fold_through: bool = False,
    skip_signature_check: bool = False,
) -> Callable[[_F], _F]:
    """
    Register a polyfill handler for a function, usually a C function from a C extension, to be
    used in place of the original function when inlining the original function in the graph.

    .. note::

        The polyfill handler is only used when inlining the original function. It is not used when
        the original function is called directly. In eager mode, the decorated function calls
        the performant C function rather than the polyfill handler.

    The polyfill handler is a function that will be called in place of the original function when
    inlining the original function. The polyfill handler should have the same signature and the same
    behavior as the original function.

    Args:
        original_fn (callable): The original function, usually a C function, to register a polyfill
            handler for.
        can_constant_fold_through (bool, optional): Whether the polyfill handler can be constant
            folded through. That is, if the polyfill handler is a pure function and its arguments
            are constant, the result of the polyfill handler can be constant folded during
            compilation. Defaults to ``False``.
        skip_signature_check (bool, optional): Whether to skip the signature check between the
            original function and the polyfill handler. Defaults to ``False``.

    Returns:
        A decorator that registers the polyfill handler for the original function.

    Example::

        >>> import operator
        >>> operator.indexOf([1, 2, 3, 4, 5], 3)
        2
        >>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
        ... # xdoctest: +SKIP("Long tracebacks")
        Traceback (most recent call last):
        ...
        torch._dynamo.exc.Unsupported: ...

        >>> @torch.compiler.substitute_in_graph(operator.indexOf)
        ... def indexOf(a, b, /):
        ...     for i, item in enumerate(a):
        ...         if item is b or item == b:
        ...             return i
        ...     raise ValueError("sequence.index(x): x not in sequence")
        >>>
        >>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
        2
    """
    import torch._dynamo

    return torch._dynamo.substitute_in_graph(
        original_fn,
        can_constant_fold_through=can_constant_fold_through,
        skip_signature_check=skip_signature_check,
    )


def list_backends(exclude_tags=("debug", "experimental")) -> List[str]:
    """
    Return valid strings that can be passed to ``torch.compile(..., backend="name")``.

    Args:
        exclude_tags (optional): A tuple of strings representing tags to exclude.
    """
    import torch._dynamo

    return torch._dynamo.list_backends(exclude_tags)


def assume_constant_result(fn):
    """
    This function is used to mark a function ``fn`` as having a constant result.
    This allows the compiler to optimize away your function.

    Args:
        fn: The function to be marked as having a constant result.

    Returns:
        The same function ``fn``.

    .. warning::
        If the constant assumption is invalid, ``assume_constant_result`` can cause safety and
        soundness issues; :func:`torch.compile` will not attempt to validate whether the constant
        assumption is true.
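
    Example (an illustrative sketch; ``config`` and ``get_mode`` stand in for state that the
    caller guarantees does not change while the compiled function is in use)::

        config = {"mode": 1}

        @torch.compiler.assume_constant_result
        def get_mode():
            # not traced; called once at compile time and its result is
            # treated as a constant by the compiler
            return config["mode"]

        @torch.compile(fullgraph=True)
        def fn(x):
            if get_mode() == 1:  # the branch is resolved at compile time
                return x + 1
            return x - 1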
    """
    import torch._dynamo

    return torch._dynamo.assume_constant_result(fn)


def disable(fn=None, recursive=True):
    """
    This function provides both a decorator and a context manager to disable compilation on a function.
    It also provides the option of recursively disabling called functions.

    Args:
        fn (optional): The function to disable.
        recursive (optional): A boolean value indicating whether the disabling should be recursive.
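
    Example (an illustrative sketch; ``debug_print`` stands in for code that should always
    run eagerly)::

        @torch.compiler.disable
        def debug_print(x):
            # never traced; runs eagerly even when called from compiled code
            print(x.shape)

        @torch.compile
        def fn(x):
            x = x + 1
            debug_print(x)  # Dynamo graph-breaks around this call
            return x * 2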
    """
    import torch._dynamo

    return torch._dynamo.disable(fn, recursive)


def cudagraph_mark_step_begin():
    """
    Indicates that a new iteration of inference or training is about to begin.

    CUDA Graphs will free tensors of a prior iteration. A new iteration is started on each invocation of
    torch.compile, so long as there is no pending backward that has not yet been called.

    If that heuristic is wrong, such as in the following example, manually mark it with this API.

    .. code-block:: python

        @torch.compile(mode="reduce-overhead")
        def rand_foo():
            return torch.rand([4], device="cuda")

        for _ in range(5):
            torch.compiler.cudagraph_mark_step_begin()
            rand_foo() + rand_foo()

    For more details, see `torch.compiler_cudagraph_trees <https://pytorch.org/docs/main/torch.compiler_cudagraph_trees.html>`__
    """
    from torch._inductor import cudagraph_trees

    cudagraph_trees.mark_step_begin()


def wrap_numpy(fn):
    r"""Decorator that turns a function from ``np.ndarray``s to ``np.ndarray``s into a function
    from ``torch.Tensor``s to ``torch.Tensor``s.

    It is designed to be used with :func:`torch.compile` with ``fullgraph=True``. It allows you to
    compile a NumPy function as if it were a PyTorch function. This allows you to run NumPy code
    on CUDA or compute its gradients.

    .. note::

        This decorator does not work without :func:`torch.compile`.

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # Compile a NumPy function as a Tensor -> Tensor function
        >>> @torch.compile(fullgraph=True)
        >>> @torch.compiler.wrap_numpy
        >>> def fn(a: np.ndarray):
        >>>     return np.sum(a * a)
        >>> # Execute the NumPy function using Tensors on CUDA and compute the gradients
        >>> x = torch.arange(6, dtype=torch.float32, device="cuda", requires_grad=True)
        >>> out = fn(x)
        >>> out.backward()
        >>> print(x.grad)
        tensor([ 0.,  2.,  4.,  6.,  8., 10.], device='cuda:0')
    """
    from torch._dynamo.external_utils import wrap_numpy as wrap

    return wrap(fn)


_is_compiling_flag: bool = False


def is_compiling() -> bool:
    """
    Indicates whether a graph is executed/traced as part of torch.compile() or torch.export().

    Note that there are 2 other related flags that should be deprecated eventually:
      * torch._dynamo.external_utils.is_compiling()
      * torch._utils.is_compiling()

    Example::

        >>> def forward(self, x):
        >>>     if not torch.compiler.is_compiling():
        >>>         pass  # ...logic that is not needed in a compiled/traced graph...
        >>>
        >>>     # ...rest of the function...
    """
    if torch.jit.is_scripting():
        return False
    else:
        return _is_compiling_flag


def is_dynamo_compiling() -> bool:
    """
    Indicates whether a graph is traced via TorchDynamo.

    It's stricter than the is_compiling() flag, as it is only set to True when
    TorchDynamo is used.

    Example::

        >>> def forward(self, x):
        >>>     if not torch.compiler.is_dynamo_compiling():
        >>>         pass  # ...logic that is not needed in a TorchDynamo-traced graph...
        >>>
        >>>     # ...rest of the function...
    """
    return False