# mypy: allow-untyped-defs
import abc
import collections
import contextlib
import functools
import logging
import threading
import weakref
from collections import defaultdict, namedtuple
from typing import (
    Any,
    Callable,
    cast,
    Deque,
    Dict,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)

import torch
from torch.autograd.variable import Variable
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.hooks import RemovableHandle

log = logging.getLogger(__name__)


__all__ = [
    "saved_tensors_hooks",
    "save_on_cpu",
    "disable_saved_tensors_hooks",
    "register_multi_grad_hook",
    "allow_mutation_on_saved_tensors",
    "Node",
    "GradientEdge",
    "get_gradient_edge",
    "increment_version",
]


class Node(abc.ABC):
    @abc.abstractmethod
    def name(self) -> str:
        r"""Return the name.

        Example::

            >>> import torch
            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> print(b.grad_fn.name())
            CloneBackward0
        """
        ...

    @property
    @abc.abstractmethod
    def next_functions(self) -> Tuple[Tuple[Optional["Node"], int], ...]:
        ...

    @abc.abstractmethod
    def metadata(self) -> dict:
        r"""Return the metadata."""
        ...

    @abc.abstractmethod
    def _register_hook_dict(self, tensor: torch.Tensor) -> None:
        ...

    @abc.abstractmethod
    def register_hook(self, fn: Callable[..., Any]) -> RemovableHandle:
        r"""Register a backward hook.

        The hook will be called every time a gradient with respect to the
        Node is computed. The hook should have the following signature::

            hook(grad_inputs: Tuple[Tensor], grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None

        The hook should not modify its argument, but it can optionally return
        a new gradient which will be used in place of :attr:`grad_inputs`.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the node.

        .. note::
            See :ref:`backward-hooks-execution` for more information on when this hook
            is executed, and how its execution is ordered relative to other hooks.

        Example::

            >>> import torch
            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> handle = b.grad_fn.register_hook(lambda gI, gO: (gO[0] * 2,))
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([2., 2., 2.])
            >>> handle.remove()  # Removes the hook
            >>> a.grad = None
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([1., 1., 1.])
        """
        ...

    @abc.abstractmethod
    def register_prehook(self, fn: Callable[..., Any]) -> RemovableHandle:
        r"""Register a backward pre-hook.

        The hook will be called every time a gradient with respect to the
        Node is computed. The hook should have the following signature::

            hook(grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None

        The hook should not modify its argument, but it can optionally return
        a new gradient which will be used in place of :attr:`grad_outputs`.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the node.

        .. note::
            See :ref:`backward-hooks-execution` for more information on when this hook
            is executed, and how its execution is ordered relative to other hooks.

        Example::

            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> handle = b.grad_fn.register_prehook(lambda gI: (gI[0] * 2,))
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([2., 2., 2.])
            >>> handle.remove()
            >>> a.grad = None
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([1., 1., 1.])
        """
        ...

    @classmethod
    def __subclasshook__(cls, C):
        if cls is Node:
            if (
                C is not None and C is getattr(torch._C._functions, C.__name__, None)
            ) or issubclass(C, torch.autograd.function.BackwardCFunction):
                return True
        return NotImplemented


def _get_grad_fn_or_grad_acc(t):
    if t.requires_grad and t.grad_fn is None:
        with torch.enable_grad():
            return t.view_as(t).grad_fn.next_functions[0][0]
    else:
        return t.grad_fn


GradientEdge = namedtuple("GradientEdge", ("node", "output_nr"))
GradientEdge.__doc__ = """\
Object representing a given gradient edge within the autograd graph.
To get the gradient edge where a given Tensor gradient will be computed,
you can do ``edge = autograd.graph.get_gradient_edge(tensor)``.
"""


def get_gradient_edge(tensor):
    """Get the gradient edge for computing the gradient of the given Tensor.

    In particular, calling ``g = autograd.grad(loss, input)`` is equivalent to
    calling ``g = autograd.grad(loss, get_gradient_edge(input))``.
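
    Example (a minimal sketch; the values and variable names are illustrative only)::

        >>> # xdoctest: +SKIP
        >>> x = torch.tensor([1.0, 2.0], requires_grad=True)
        >>> loss = (x * 2).sum()
        >>> edge = torch.autograd.graph.get_gradient_edge(x)
        >>> # equivalent to torch.autograd.grad(loss, x), per the equivalence above
        >>> (g,) = torch.autograd.grad(loss, edge)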
    """
    if not tensor.requires_grad:
        raise RuntimeError(
            "It is not possible to get the gradient edge for a Tensor that does not require gradients"
        )
    grad_fn = _get_grad_fn_or_grad_acc(tensor)

    # Note that output_nr defaults to 0, which is the right value
    # for the AccumulateGrad node.
    return GradientEdge(grad_fn, tensor.output_nr)


def increment_version(tensor):
    """Update autograd metadata tracking whether the given Tensor was modified in place.

    This is to enable more accurate error checking within the autograd engine.
    It is already done automatically by PyTorch functions and within custom Functions
    when mark_dirty() is called appropriately, so you only need to call this explicitly
    if you are performing an in-place operation on the Tensor data in a way that PyTorch
    doesn't know about; for example, a custom kernel that reads the Tensor's data_ptr
    and modifies the memory in place based on this pointer.

    Note that incrementing the version counter multiple times for a single in-place
    operation is not problematic.
    """
    torch._C._increment_version(tensor)


class saved_tensors_hooks:
    """Context-manager that sets a pair of pack / unpack hooks for saved tensors.

    Use this context-manager to define how intermediary results of an operation
    should be packed before saving, and unpacked on retrieval.

    In that context, the ``pack_hook`` function will be called every time an
    operation saves a tensor for backward (this includes intermediary results
    saved using
    :func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but
    also those recorded by a PyTorch-defined operation). The output of
    ``pack_hook`` is then stored in the computation graph instead of the
    original tensor.

    The ``unpack_hook`` is called when the saved tensor needs to be accessed,
    namely when executing :func:`torch.Tensor.backward()` or
    :func:`torch.autograd.grad()`. It takes as argument the *packed* object
    returned by ``pack_hook`` and should return a tensor which has the same
    content as the original tensor (passed as input to the corresponding
    ``pack_hook``).

    The hooks should have the following signatures:

        pack_hook(tensor: Tensor) -> Any

        unpack_hook(Any) -> Tensor

    where the return value of ``pack_hook`` is a valid input to ``unpack_hook``.

    In general, you want ``unpack_hook(pack_hook(t))`` to be equal to ``t`` in terms
    of value, size, dtype and device.

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def pack_hook(x):
        ...     print("Packing", x)
        ...     return x
        >>>
        >>> def unpack_hook(x):
        ...     print("Unpacking", x)
        ...     return x
        >>>
        >>> a = torch.ones(5, requires_grad=True)
        >>> b = torch.ones(5, requires_grad=True) * 2
        >>> with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
        ...     y = a * b
        Packing tensor([1., 1., 1., 1., 1.], requires_grad=True)
        Packing tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
        >>> y.sum().backward()
        Unpacking tensor([1., 1., 1., 1., 1.], requires_grad=True)
        Unpacking tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)

    .. warning::
        Performing an in-place operation on the input to either hook may lead
        to undefined behavior.

    .. warning::
        Only one pair of hooks is allowed at a time. When recursively nesting this
        context-manager, only the inner-most pair of hooks will be applied.
    """

    def __init__(
        self,
        pack_hook: Callable[[torch.Tensor], Any],
        unpack_hook: Callable[[Any], torch.Tensor],
    ):
        self.pack_hook = pack_hook
        self.unpack_hook = unpack_hook

    def __enter__(self):
        torch._C._autograd._push_saved_tensors_default_hooks(
            self.pack_hook, self.unpack_hook
        )

    def __exit__(self, *args: object):
        torch._C._autograd._pop_saved_tensors_default_hooks()


class save_on_cpu(saved_tensors_hooks):
    """Context manager under which tensors saved by the forward pass will be stored on CPU, then retrieved for backward.

    When performing operations within this context manager, intermediary
    results saved in the graph during the forward pass will be moved to CPU,
    then copied back to the original device when needed for the backward pass.
    If the graph was already on CPU, no tensor copy is performed.

    Use this context-manager to trade compute for GPU memory usage (e.g.
    when your model doesn't fit in GPU memory during training).

    Args:
        pin_memory (bool): If ``True`` tensors will be saved to CPU pinned memory
                           during packing and copied to GPU asynchronously during unpacking.
                           Defaults to ``False``.
                           Also see :ref:`cuda-memory-pinning`.
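
    A minimal sketch of the pinned-memory variant (``model`` and ``x`` are
    placeholders, not defined here, and a CUDA device is assumed)::

        >>> # xdoctest: +SKIP
        >>> with torch.autograd.graph.save_on_cpu(pin_memory=True):
        ...     y = model(x)  # saved activations go to pinned CPU memory
        >>> y.sum().backward()  # and are copied back to the GPU asynchronously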

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> a = torch.randn(5, requires_grad=True, device="cuda")
        >>> b = torch.randn(5, requires_grad=True, device="cuda")
        >>> c = torch.randn(5, requires_grad=True, device="cuda")
        >>>
        >>> def f(a, b, c):
        ...     prod_1 = a * b  # a and b are saved on GPU
        ...     with torch.autograd.graph.save_on_cpu():
        ...         prod_2 = prod_1 * c  # prod_1 and c are saved on CPU
        ...     y = prod_2 * a  # prod_2 and a are saved on GPU
        ...     return y
        >>>
        >>> y = f(a, b, c)
        >>> del a, b, c  # for illustration only
        >>> # the content of a, b, and prod_2 are still alive on GPU
        >>> # the content of prod_1 and c only live on CPU
        >>> y.sum().backward()  # all CPU tensors are moved back to GPU, for backward
        >>> # all intermediary tensors are released (deleted) after the call to backward

    """

    def __init__(self, pin_memory=False, device_type="cuda"):
        device_module = getattr(torch, device_type, torch.cuda)

        def pack_to_cpu(tensor):
            if not pin_memory:
                return (tensor.device, tensor.cpu())
            packed = torch.empty(
                tensor.size(),
                dtype=tensor.dtype,
                layout=tensor.layout,
                pin_memory=(device_module.is_available() and not tensor.is_sparse),
            )
            packed.copy_(tensor)
            return (tensor.device, packed)

        def unpack_from_cpu(packed):
            device, tensor = packed
            return tensor.to(device, non_blocking=pin_memory)

        super().__init__(pack_to_cpu, unpack_from_cpu)
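

# Illustrative sketch (not part of the public API): ``save_on_cpu`` above shows the
# general pattern for custom packing policies: subclass ``saved_tensors_hooks`` and
# hand a pack/unpack pair to ``super().__init__``. A hypothetical variant that
# simply detaches saved tensors could look like this (the class name is made up):
#
#     class _save_detached(saved_tensors_hooks):
#         def __init__(self):
#             def pack(tensor):
#                 return tensor.detach()
#
#             def unpack(tensor):
#                 return tensor
#
#             super().__init__(pack, unpack)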


@contextlib.contextmanager
def disable_saved_tensors_hooks(error_message):
    """Context-manager that disables the saved tensors default hooks feature.

    Useful if you are creating a feature that does not work with saved
    tensors default hooks.

    Args:
        error_message (str): When saved tensors default hooks are used while they
                             are disabled, a RuntimeError with this
                             error message gets raised.

    Example::

        >>> # xdoctest: +SKIP(failing)
        >>> message = "saved tensors default hooks are disabled"
        >>> with torch.autograd.graph.disable_saved_tensors_hooks(message):
        ...     # Raises RuntimeError: saved tensors default hooks are disabled
        ...     with torch.autograd.graph.save_on_cpu():
        ...         pass

    """
    try:
        maybe_prev_message = (
            torch._C._autograd._saved_tensors_hooks_get_disabled_error_message()
        )
        torch._C._autograd._saved_tensors_hooks_disable(error_message)
        yield
    finally:
        # See NOTE: [disabled_error_message invariant]
        if maybe_prev_message is None:
            torch._C._autograd._saved_tensors_hooks_enable()
        else:
            torch._C._autograd._saved_tensors_hooks_disable(maybe_prev_message)


class _MultiHandle(RemovableHandle):
    handles: Tuple[RemovableHandle, ...]

    def __init__(self, handles: Tuple[RemovableHandle, ...]):
        self.handles = handles

    def remove(self):
        for handle in self.handles:
            handle.remove()

    def __getstate__(self):
        return self.handles

    def __setstate__(self, state):
        self.handles = state


def register_multi_grad_hook(
    tensors: Sequence[torch.Tensor],
    fn: Union[
        Callable[[Sequence[Optional[torch.Tensor]]], None],
        Callable[[torch.Tensor], None],
    ],
    *,
    mode: str = "all",
):
    r"""Register a multi-grad backward hook.

    There are two supported modes: ``"all"`` and ``"any"``.

    Under the ``"all"`` mode, the hook will be called after gradients with respect to
    every tensor in :attr:`tensors` have been computed. If a tensor is in :attr:`tensors`
    but is not part of the graph, or if a tensor is not needed to compute the gradients
    for any ``inputs`` specified for the current ``.backward()`` or ``.grad()`` call,
    this tensor will be ignored and the hook will not wait for its gradient to be
    computed.

    After every non-ignored tensor's gradient has been computed, :attr:`fn` will be
    called with those gradients. ``None`` will be passed for tensors that did not
    have their gradients computed.

    Under the ``"any"`` mode, the hook will be called after the first gradient
    with respect to a tensor in :attr:`tensors` has been computed. The hook
    will be called with that gradient as its argument.

    The hook should not modify its arguments.

    This function returns a handle with a method ``handle.remove()`` that removes the hook.
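
    For the ``"any"`` mode, a minimal usage sketch (``a`` and ``b`` are assumed to be
    tensors that require grad; ``fn`` receives only the first gradient computed)::

        >>> # xdoctest: +SKIP
        >>> def fn(grad):
        ...     print("first gradient computed:", grad.shape)
        >>> handle = torch.autograd.graph.register_multi_grad_hook((a, b), fn, mode="any")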

    .. note::
        See :ref:`backward-hooks-execution` for more information on when this hook
        is executed, and how its execution is ordered relative to other hooks.

    Example::

        >>> import torch
        >>>
        >>> a = torch.rand(2, 3, requires_grad=True)
        >>> b = torch.rand(2, 3, requires_grad=True)
        >>> c = a * b
        >>> d = a * b
        >>>
        >>> def fn(grads):
        ...     print([g is not None for g in grads])
        ...
        >>> torch.autograd.graph.register_multi_grad_hook((a, b, c, d), fn)
        >>>
        >>> c.sum().backward(retain_graph=True)
        [True, True, True, False]
        >>> c.sum().backward(inputs=(a,), retain_graph=True)
        [True, False, True, False]
        >>>
    """
    supported_modes = ("all", "any")
    if mode not in supported_modes:
        raise ValueError(f"Expects mode to be one of {supported_modes} but got {mode}")

    if mode == "all":
        count: Dict[int, int] = dict()
        nb_calls = None
        buffer: Dict[int, List[Optional[torch.Tensor]]] = dict()

        grad_fns = list(map(_get_grad_fn_or_grad_acc, tensors))
        len_tensors = len(tensors)

        def get_inner_hook(idx):
            def inner_hook(grad: torch.Tensor):
                nonlocal count, nb_calls, buffer, fn
                id = torch._C._current_graph_task_id()
                assert (
                    id != -1
                ), "expected this hook to be called inside a backward call"
                count[id] = count.get(id, 0)
                buffer[id] = buffer.get(id, [None] * len_tensors)

                if count[id] == 0:
                    # On the first call, compute the actual nb_calls and buffer
                    nb_calls = sum(torch._C._will_engine_execute_node(g) for g in grad_fns)  # type: ignore[attr-defined]

                buffer[id][idx] = grad
                count[id] += 1

                if count[id] == nb_calls:
                    fn = cast(Callable[[Sequence[Optional[torch.Tensor]]], None], fn)
                    fn(buffer[id])
                    del count[id]
                    del buffer[id]

            return inner_hook

        handles: Tuple[RemovableHandle, ...] = tuple(
            t.register_hook(get_inner_hook(i)) for i, t in enumerate(tensors)
        )
    elif mode == "any":
        fn = cast(Callable[[torch.Tensor], None], fn)
        lock = threading.Lock()
        ran_hook: Dict[int, bool] = defaultdict(bool)

        @functools.wraps(fn)
        def wrapped_fn(grad: torch.Tensor):
            nonlocal ran_hook
            id = torch._C._current_graph_task_id()
            assert id != -1, "expected this hook to be called inside a backward call"
            with lock:
                prev, ran_hook[id] = ran_hook[id], True
                if prev:
                    return
            fn(grad)

        handles = tuple(
            tensor.register_hook(wrapped_fn)
            for tensor in tensors
            if tensor.requires_grad
        )

    return _MultiHandle(handles)  # type: ignore[possibly-undefined]


# NOTE [Allow mutation on tensors saved for backward]
#
# 1. Tensor gets saved for backward
#    - remember the python object id and the version of the tensor
#    - remember aliasing information (data_ptr of base + version)
#    - save the original so we control its lifetime
# 2. Any time a tensor gets in-placed
#    - for each tensor aliased to it:
#      - check using its object id and version to see if it has been saved
#      - if it has been saved, clone it
#      - delete the reference to the original
# 3. during backward
#    - if the clone exists, the tensor must've been modified in-place
_allow_mutation_on_saved_tensors_enabled = False


def _get_tid(t) -> Tuple[int, int, int]:
    # FIXME: This is almost definitely a bug.
    if isinstance(
        t,
        (
            torch._subclasses.fake_tensor.FakeTensor,
            torch._subclasses.functional_tensor.FunctionalTensor,
        ),
    ):
        data_ptr = 0
    else:
        data_ptr = t.data_ptr()
    return (id(t), data_ptr, t._version)


def _get_sid(t) -> Tuple[int, int]:
    # FIXME: This is almost definitely a bug.
    if isinstance(
        t,
        (
            torch._subclasses.fake_tensor.FakeTensor,
            torch._subclasses.functional_tensor.FunctionalTensor,
        ),
    ):
        data_ptr = 0
    else:
        data_ptr = t.data_ptr()
    return (data_ptr, t._version)


class _Handle:
    pass


class _swap_with_cloned(saved_tensors_hooks):
    def __init__(self, ctx):
        def pack_hook(t):
            tid = _get_tid(t)
            sid = _get_sid(t)
            # Tensors saved for backward have an entry in _tid_to_weakhandle
            handle: Optional[_Handle] = None

            # Save aliasing information
            ctx.sid_to_tid[sid].add(tid)

            # NB: The same tensor (of the same version) can be saved multiple times
            if tid not in ctx.tid_to_weakhandle:
                handle = _Handle()
                ctx.tid_to_weakhandle[tid] = handle
                ctx.original[handle] = t
            else:
                # Store an additional strong reference to the handle
                handle = ctx.tid_to_weakhandle[tid]
            return handle

        def unpack_hook(tup):
            handle = tup
            error_msg = (
                "Trying to backward outside of the 'allow_mutation_on_saved_tensors' context "
                "in which the graph was originally recorded."
            )
            assert _allow_mutation_on_saved_tensors_enabled, error_msg
            if handle in ctx.cloned:
                res = ctx.cloned[handle]
            else:
                assert handle in ctx.original, error_msg
                res = ctx.original[handle]
            return res

        super().__init__(pack_hook, unpack_hook)


class _CloneArgBeforeMutateMode(TorchDispatchMode):
    def __init__(self, ctx):
        self.ctx = ctx

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        for idx, arg in enumerate(func._schema.arguments):
            if arg.alias_info is not None and arg.alias_info.is_write:
                t = kwargs["out"] if arg.is_out else args[idx]
                tid = _get_tid(t)
                sid = _get_sid(t)
                ctx = self.ctx
                if sid in ctx.sid_to_tid:
                    for tid in ctx.sid_to_tid[sid]:
                        if tid not in ctx.tid_to_weakhandle:
                            # We know that if tid is in sid_to_tid, then it must also be in
                            # tid_to_weakhandle. However, it is possible for the tensor to be
                            # saved at one point, but cleared by backward before it is modified
                            # in-place. Consider the following example:
                            #
                            # >>> a = torch.randn(2, 3, requires_grad=True).clone()
                            # >>> out = (a**2).sum()
                            # >>> out.backward()
                            # >>> a.sin_()
                            continue
                        handle = ctx.tid_to_weakhandle[tid]
                        if handle in ctx.cloned:
                            # The same exact tensor has been cloned already
                            continue
                        ctx.cloned[handle] = ctx.original[handle].clone()
                        del ctx.original[handle]

        rs = func(*args, **kwargs)
        return rs


class _AllowMutationOnSavedContext:
    def __init__(self):
        self.cloned: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
        self.original: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
        self.tid_to_weakhandle: weakref.WeakValueDictionary = (
            weakref.WeakValueDictionary()
        )
        self.sid_to_tid: Dict[Tuple[int, int], Set[Tuple[int, int, int]]] = defaultdict(
            set
        )

    def clear(self):
        self.cloned.clear()
        self.original.clear()
        self.tid_to_weakhandle.clear()
        self.sid_to_tid.clear()


@contextlib.contextmanager
def allow_mutation_on_saved_tensors():
    """Context manager under which mutating tensors saved for backward is allowed.

    Under this context manager, tensors saved for backward are cloned on mutation,
    so the original version can still be used during backward. Normally, mutating a tensor
    saved for backward will result in an error raised when it's used during backward.

    To ensure the correct behavior, both the forward and backward should be run under
    the same context manager.

    Returns:
        An _AllowMutationOnSavedContext object storing the state managed by this
        context manager. This object can be useful for debugging purposes. The state
        managed by the context manager is automatically cleared upon exiting.

    Example::

        >>> import torch
        >>> with torch.autograd.graph.allow_mutation_on_saved_tensors():
        ...     # forward
        ...     a = torch.ones(2, 3, requires_grad=True)
        ...     b = a.clone()
        ...     out = (b**2).sum()
        ...     b.sin_()
        ...     # backward
        ...     out.sum().backward()
        ...
        tensor([[0.8415, 0.8415, 0.8415],
                [0.8415, 0.8415, 0.8415]], grad_fn=<SinBackward0>)
    """
    global _allow_mutation_on_saved_tensors_enabled

    ctx = _AllowMutationOnSavedContext()

    with _swap_with_cloned(ctx), _CloneArgBeforeMutateMode(ctx):
        try:
            if _allow_mutation_on_saved_tensors_enabled:
                raise RuntimeError(
                    "allow_mutation_on_saved_tensors contexts cannot be nested"
                )
            _allow_mutation_on_saved_tensors_enabled = True
            yield ctx
        finally:
            ctx.clear()
            _allow_mutation_on_saved_tensors_enabled = False


def _register_logging_hooks_on_whole_graph(t_outputs: List[torch.Tensor]):
    grad_fns = list(map(_get_grad_fn_or_grad_acc, t_outputs))

    def iter_graph(roots):
        if not roots:
            return
        seen = set()
        q: Deque = collections.deque()
        for node in roots:
            if node is not None:
                seen.add(node)
                q.append(node)

        while q:
            node = q.popleft()
            for fn, _idx in node.next_functions:
                if fn in seen or fn is None:
                    continue
                seen.add(fn)
                q.append(fn)

            yield node

    def fmt(t):
        # Avoid circular import
        from torch.testing._internal.common_utils import dtype_abbrs

        if t is None:
            return "None"
        return f"{dtype_abbrs[t.dtype]}[{', '.join(map(str, t.shape))}]"

    def prehook(grad_outputs):
        node = torch._C._current_autograd_node()
        grad_outputs_str = f"[{','.join(fmt(t) for t in grad_outputs)}]"
        log_str = f"Executing: {node} with grad_outputs: {grad_outputs_str}"
        log.debug(log_str)

    handles = []
    for node in iter_graph(grad_fns):
        handles.append(node.register_prehook(prehook))

    def unregister_hooks():
        for handle in handles:
            handle.remove()

    return unregister_hooks


def _engine_run_backward(t_outputs, *args, **kwargs):
    attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG
    if attach_logging_hooks:
        unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    try:
        return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
            t_outputs, *args, **kwargs
        )
    finally:
        if attach_logging_hooks:
            unregister_hooks()  # type: ignore[possibly-undefined]
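

# Illustrative note: ``_engine_run_backward`` attaches the logging prehooks above
# only when this module's logger is at DEBUG level or lower. For example
# (``loss`` is a placeholder tensor):
#
#     import logging
#     logging.getLogger("torch.autograd.graph").setLevel(logging.DEBUG)
#     loss.backward()  # each executed Node is logged together with the
#                      # dtypes/shapes of its grad_outputs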