# mypy: allow-untyped-defs
# This module contains functions that *will be allowed* by dynamo

import functools
import warnings
from typing import List

import torch
import torch.utils._pytree as pytree


try:
    import numpy as np
except ModuleNotFoundError:
    np = None  # type: ignore[assignment]


def is_compiling() -> bool:
    """
    Indicates whether we are tracing/compiling with torch.compile() or torch.export().

    If need to check specifically that TorchDynamo is used, then use
    torch.compiler.is_dynamo_compiling().

    TODO(khabinov): we should deprecate this function and use one of these two:
    * torch.compiler.is_compiling(),
    * torch.compiler.is_dynamo_compiling().
    It will depend on the context where to use what.
    """
    return torch.compiler.is_compiling()


def wrap_inline(fn):
    """
    Create an extra frame around fn that is not in skipfiles.

    The returned wrapper forwards all positional and keyword arguments to
    ``fn`` unchanged and returns its result; ``functools.wraps`` copies
    ``fn``'s metadata (``__name__``, ``__doc__``, ...) onto the wrapper.
    """

    @functools.wraps(fn)
    def inner(*args, **kwargs):
        return fn(*args, **kwargs)

    return inner


def call_hook(hook, *args, **kwargs):
    """
    Used by compiled autograd to handle hook returning None.

    Calls ``hook(*args)``. Keyword arguments are NOT forwarded to the hook;
    they only carry metadata (currently ``hook_type``).

    Returns:
        ``args[0]`` if the hook returned ``None``, otherwise the hook's result.

    Raises:
        RuntimeError: if ``hook_type == "post_acc_grad_hook"`` and the hook
            returned something other than ``None``.
    """
    result = hook(*args)
    if result is None:
        return args[0]
    # ``.get`` so a caller that omits ``hook_type`` receives the hook's result
    # instead of an unrelated KeyError.
    if kwargs.get("hook_type") == "post_acc_grad_hook":
        raise RuntimeError("Tensor post accumulate grad hooks should return None.")
    return result


def wrap_numpy(f):
    r"""Decorator that turns a function from ``np.ndarray``s to ``np.ndarray``s into a function
    from ``torch.Tensor``s to ``torch.Tensor``s.

    Every ``torch.Tensor`` found (at any nesting depth, via pytree) in the
    positional or keyword arguments is converted with ``Tensor.numpy()``
    before calling ``f``; every ``np.ndarray`` in ``f``'s (possibly nested)
    output is converted back with ``torch.as_tensor``. If NumPy is not
    installed, ``f`` is returned unchanged.
    """
    if not np:
        return f

    @functools.wraps(f)
    def wrap(*args, **kwargs):
        args, kwargs = pytree.tree_map_only(
            torch.Tensor, lambda x: x.numpy(), (args, kwargs)
        )
        out = f(*args, **kwargs)
        return pytree.tree_map_only(np.ndarray, lambda x: torch.as_tensor(x), out)

    return wrap


class FakeBackwardCFunction:
    """
    Stand-in ``ctx`` object handed to ``backward`` by ``call_backward``.

    Holds ``saved_tensors`` explicitly. Every other attribute lookup is
    forwarded to the wrapped real ``BackwardCFunction``.
    """

    def __init__(
        self,
        real: torch.autograd.function.BackwardCFunction,
        saved_tensors: List[torch.Tensor],
    ) -> None:
        self.real = real
        self.saved_tensors = saved_tensors

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails.
        if name == "saved_variables":
            warnings.warn(
                "'saved_variables' is deprecated; use 'saved_tensors'",
                DeprecationWarning,
                stacklevel=2,  # attribute the warning to the caller's line
            )
            return self.saved_tensors

        if name == "real":
            # ``real`` was never assigned (e.g. instance created via __new__
            # by copy/pickle). Without this guard, ``self.real`` below would
            # re-enter __getattr__ and recurse until RecursionError.
            raise AttributeError(name)

        # route any attribute that isn't defined on this obj
        return getattr(self.real, name)


# This function corresponds to the "eager" implementation of a lifted autograd.Function.backward
def call_backward(backward_c_function, saved_tensors, *args):
    """
    Invoke ``backward`` of the ``autograd.Function`` behind ``backward_c_function``.

    A ``FakeBackwardCFunction`` wrapping ``backward_c_function`` and
    ``saved_tensors`` is passed as ``ctx``. ``args`` are forwarded
    positionally; presumably these are the incoming gradients.

    Returns:
        The result of ``backward`` as a tuple. Any result whose exact type is
        not ``tuple`` (including tuple subclasses) is wrapped as ``(result,)``.
    """
    fake = FakeBackwardCFunction(backward_c_function, saved_tensors)
    # ``_forward_cls`` is resolved on the real object via __getattr__.
    grads = fake._forward_cls.backward(fake, *args)  # type: ignore[attr-defined]

    # in eager, we wrap in a tuple when there's only one grad output
    if type(grads) is not tuple:
        grads = (grads,)

    return grads


def untyped_storage_size(x: torch.Tensor):
    """Return ``x.untyped_storage().size()``: the byte size of the whole storage backing ``x``."""
    return x.untyped_storage().size()


class FakeCompiledAutogradEngine:
    """Python-level helpers for queueing and running final callbacks (used with compiled autograd)."""

    @staticmethod
    def queue_callback(final_callbacks, cb):
        """Append ``cb`` to the ``final_callbacks`` list."""
        final_callbacks.append(cb)

    @staticmethod
    def exec_final_callbacks(final_callbacks):
        """Run every queued callback in order, then clear the list.

        ``len`` is re-checked on each iteration, so callbacks queued by a
        running callback are executed in the same pass.
        """
        i = 0
        while i < len(final_callbacks):
            cb = final_callbacks[i]
            cb()
            i += 1
        final_callbacks.clear()

    @staticmethod
    def _exec_final_callbacks_stub():
        # Intentionally a no-op; presumably a placeholder for tracing — confirm at call sites.
        pass


def call_hook_from_backward_state(*args, bw_state, hook_name: str, **kwargs):
    """Call the hook stored as attribute ``hook_name`` on ``bw_state`` with all args/kwargs."""
    return getattr(bw_state, hook_name)(*args, **kwargs)


def call_module_hooks_from_backward_state(
    _, result, *args, bw_state, hooks_name: str, module_name: str
):
    """
    Apply module hooks stored on ``bw_state`` in sequence.

    Each hook in ``getattr(bw_state, hooks_name)`` is called as
    ``hook(module, result, *args)``, where ``module`` is
    ``getattr(bw_state, module_name)``. A non-``None`` return value replaces
    ``result`` for the following hooks. The first positional argument is
    ignored.

    Returns:
        The final ``result`` after all hooks ran.
    """
    module = getattr(bw_state, module_name)
    hooks = getattr(bw_state, hooks_name)
    for hook in hooks:
        new_result = hook(module, result, *args)
        if new_result is not None:
            result = new_result
    return result