# mypy: allow-untyped-defs
import os
from collections import namedtuple
from typing import Any

import torch

from .grad_mode import _DecoratorContextManager


__all__ = [
    "UnpackedDualTensor",
    "enter_dual_level",
    "exit_dual_level",
    "make_dual",
    "unpack_dual",
    "dual_level",
]

# Global variable used to make the Python API simpler to use
_current_level = -1


def enter_dual_level():
    r"""Enter a new forward grad level.

    This level can be used to make and unpack dual Tensors to compute
    forward gradients.

    This function also updates the current level that is used by default
    by the other functions in this API.
    """
    global _current_level
    new_level = torch._C._enter_dual_level()
    if new_level != _current_level + 1:
        raise RuntimeError(
            "Entering a new forward AD level but the current level "
            "is not valid. Make sure you did not modify it directly."
        )
    _current_level = new_level
    return new_level


def exit_dual_level(*, level=None):
    r"""Exit a forward grad level.

    This function deletes all the gradients associated with this
    level. Only deleting the latest entered level is allowed.

    This function also updates the current level that is used by default
    by the other functions in this API.
    """
    global _current_level
    if level is None:
        level = _current_level
    if level != _current_level:
        raise RuntimeError(
            "Trying to exit a forward AD level that was not the last one "
            "that was created. This is not supported."
        )
    torch._C._exit_dual_level(level=level)
    _current_level = level - 1


def _maybe_load_decompositions():
    if os.environ.get("PYTORCH_JIT", "1") == "1" and __debug__:
        from torch._decomp import decompositions_for_jvp  # noqa: F401


def make_dual(tensor, tangent, *, level=None):
    r"""Associate a tensor value with its tangent to create a "dual tensor" for forward AD gradient computation.

    The result is a new tensor aliased to :attr:`tensor` with :attr:`tangent` embedded
    as an attribute, as-is if it has the same storage layout or copied otherwise.
    The tangent attribute can be recovered with :func:`unpack_dual`.

    This function is backward differentiable.

    Given a function `f` whose Jacobian is `J`, it allows one to compute the Jacobian-vector product (`jvp`)
    between `J` and a given vector `v` as follows.

    Example::

        >>> # xdoctest: +SKIP("Undefined variables")
        >>> with dual_level():
        ...     inp = make_dual(x, v)
        ...     out = f(inp)
        ...     y, jvp = unpack_dual(out)

    Please see the `forward-mode AD tutorial <https://pytorch.org/tutorials/intermediate/forward_ad_usage.html>`__
    for detailed steps on how to use this API.
    """
    # See NOTE: [forward-mode AD decompositions mechanism]
    #
    # Import from torch._decomp import decompositions_for_jvp to register
    # decompositions for jvp to the jit registry
    #
    # FIXME: We specify that __debug__ must be True because
    # if python is run with -OO or -O flags (i.e., __debug__ is False), we encounter the
    # following error:
    #
    # Return value was annotated as having type Tuple[NoneType, NoneType] but is actually of
    # type Tuple[Tensor, Tensor]:
    #   File ".../torch/_decomp/__init__.py", line 1585
    #     else:
    #       buffer = z
    #     return min - torch.log1p(z), buffer
    #            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    _maybe_load_decompositions()

    if level is None:
        level = _current_level

    if level < 0:
        raise RuntimeError(
            "Trying to create a dual Tensor for forward AD but no level "
            "exists, make sure to enter_dual_level() first."
        )
    if not (tensor.is_floating_point() or tensor.is_complex()):
        raise ValueError(
            f"Expected primal to be floating point or complex, but got: {tensor.dtype}"
        )
    if not (tangent.is_floating_point() or tangent.is_complex()):
        raise ValueError(
            f"Expected tangent to be floating point or complex, but got: {tangent.dtype}"
        )

    return torch._VF._make_dual(tensor, tangent, level=level)


_UnpackedDualTensor = namedtuple("_UnpackedDualTensor", ["primal", "tangent"])


class UnpackedDualTensor(_UnpackedDualTensor):
    r"""Namedtuple returned by :func:`unpack_dual` containing the primal and tangent components of the dual tensor.

    See :func:`unpack_dual` for more details.

    """


def unpack_dual(tensor, *, level=None):
    r"""Unpack a "dual tensor" to get both its Tensor value and its forward AD gradient.

    The result is a namedtuple ``(primal, tangent)`` where ``primal`` is a view of
    :attr:`tensor`'s primal and ``tangent`` is :attr:`tensor`'s tangent as-is.
    Neither of these tensors can be a dual tensor of level :attr:`level`.

    This function is backward differentiable.

    Example::

        >>> # xdoctest: +SKIP("Undefined variables")
        >>> with dual_level():
        ...     inp = make_dual(x, x_t)
        ...     out = f(inp)
        ...     y, jvp = unpack_dual(out)
        ...     jvp = unpack_dual(out).tangent

    Please see the `forward-mode AD tutorial <https://pytorch.org/tutorials/intermediate/forward_ad_usage.html>`__
    for detailed steps on how to use this API.
    """
    if level is None:
        level = _current_level

    if level < 0:
        return UnpackedDualTensor(tensor, None)

    primal, dual = torch._VF._unpack_dual(tensor, level=level)

    return UnpackedDualTensor(primal, dual)


class dual_level(_DecoratorContextManager):
    r"""Context-manager for forward AD, where all forward AD computation must occur within the ``dual_level`` context.

    .. Note::

        The ``dual_level`` context appropriately enters and exits the dual level to
        control the current forward AD level, which is used by default by the other
        functions in this API.

        We currently don't plan to support nested ``dual_level`` contexts, so
        only a single forward AD level is supported. To compute higher-order
        forward grads, one can use :func:`torch.func.jvp`.

    Example::

        >>> # xdoctest: +SKIP("Undefined variables")
        >>> x = torch.tensor([1])
        >>> x_t = torch.tensor([1])
        >>> with dual_level():
        ...     inp = make_dual(x, x_t)
        ...     # Do computations with inp
        ...     out = your_fn(inp)
        ...     _, grad = unpack_dual(out)
        >>> grad is None
        False
        >>> # After exiting the level, the grad is deleted
        >>> _, grad_after = unpack_dual(out)
        >>> grad_after is None
        True

    Please see the `forward-mode AD tutorial <https://pytorch.org/tutorials/intermediate/forward_ad_usage.html>`__
    for detailed steps on how to use this API.
    """

    def __enter__(self):
        return enter_dual_level()

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        exit_dual_level()


# Private helper functions
_is_fwd_grad_enabled = torch._C._is_fwd_grad_enabled


# Private helper function to enable or disable fwd grad.
# If you're a user and want to use this, please file an issue to discuss the use case.
class _set_fwd_grad_enabled(_DecoratorContextManager):
    def __init__(self, mode: bool) -> None:
        self.prev = _is_fwd_grad_enabled()
        torch._C._set_fwd_grad_enabled(mode)

    def __enter__(self) -> None:
        pass

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch._C._set_fwd_grad_enabled(self.prev)
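

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the public API of this module):
# a minimal Jacobian-vector product computed with the helpers defined above.
# The function ``f`` and the inputs below are made up for the example, and the
# demo is guarded by ``__main__`` so importing this module stays side-effect
# free.
if __name__ == "__main__":

    def f(x):
        # A toy elementwise function; its jvp at x along v is cos(x) * v.
        return x.sin()

    x = torch.randn(3)
    v = torch.randn(3)

    with dual_level():
        # Attach the tangent v to the primal x, run the function on the dual
        # tensor, and read the forward-mode gradient back off the output.
        out = f(make_dual(x, v))
        y, jvp = unpack_dual(out)

    # The computed jvp matches the analytical directional derivative cos(x) * v.
    torch.testing.assert_close(jvp, x.cos() * v)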