xref: /aosp_15_r20/external/pytorch/test/test_autograd.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: autograd"]
2
3import collections
4import contextlib
5import functools
6import gc
7import io
8import math
9import operator
10import os
11import pickle
12import random
13import subprocess
14import sys
15import tempfile
16import threading
17import time
18import unittest
19import uuid
20import warnings
21import weakref
22from collections import OrderedDict
23from copy import deepcopy
24from functools import partial, reduce
25from itertools import product
26from operator import mul
27from typing import List, Tuple, TYPE_CHECKING
28
29import torch
30import torch.autograd._functions
31import torch.autograd.forward_ad as fwAD
32from torch import inf, nan, nn
33from torch.autograd import (
34    _calculate_shape,
35    detect_anomaly,
36    Function,
37    kineto_available,
38    Variable,
39)
40from torch.autograd.function import InplaceFunction, once_differentiable
41from torch.autograd.graph import GradientEdge
42from torch.autograd.profiler import emit_itt, emit_nvtx, profile, record_function
43from torch.autograd.profiler_util import (
44    _format_time,
45    EventList,
46    FunctionEvent,
47    FunctionEventAvg,
48)
49from torch.testing import make_tensor
50from torch.testing._internal.common_cuda import TEST_CUDA
51from torch.testing._internal.common_device_type import (
52    deviceCountAtLeast,
53    dtypes,
54    dtypesIfCUDA,
55    dtypesIfMPS,
56    instantiate_device_type_tests,
57    onlyCPU,
58    onlyCUDA,
59    skipMeta,
60)
61from torch.testing._internal.common_dtype import floating_types_and
62from torch.testing._internal.common_methods_invocations import mask_not_all_zeros
63from torch.testing._internal.common_utils import (
64    disable_gc,
65    gradcheck,
66    gradgradcheck,
67    instantiate_parametrized_tests,
68    IS_MACOS,
69    IS_WINDOWS,
70    parametrize,
71    run_tests,
72    set_warn_always_context,
73    skipIfMps,
74    skipIfNoLapack,
75    skipIfTorchDynamo,
76    slowTest,
77    TestCase,
78    xfailIfTorchDynamo,
79)
80from torch.utils._mode_utils import no_dispatch
81from torch.utils._python_dispatch import TorchDispatchMode
82from torch.utils.checkpoint import (
83    checkpoint,
84    checkpoint_sequential,
85    CheckpointPolicy,
86    create_selective_checkpoint_contexts,
87)
88from torch.utils.cpp_extension import load_inline
89from torch.utils.flop_counter import FlopCounterMode
90
91
92if TYPE_CHECKING:
93    from torch.utils.hooks import RemovableHandle
94
95
96def graph_desc(fn):
97    if fn is None:
98        return "None"
99    result = type(fn).__name__ + "("
100    next_functions = fn.next_functions
101    for next_fn, _ in next_functions:
102        result += graph_desc(next_fn)
103        result += ", "
104    if next_functions:
105        result = result[:-2]
106    return result + ")"
107
108
109class TestAutograd(TestCase):
110    def test_copy_slices_graph_task_updates(self):
111        def f1(x, y):
112            out = x.clone().view(-1)
113            out += y
114            return out
115
116        def f2(x, y):
117            out = x.clone().view(-1)
118            b = out * 2
119            out += y
120            return out + b
121
122        x = torch.rand(2, requires_grad=True)
123        y = torch.rand(2, requires_grad=True)
124
125        y_safe = torch._C._functions.DelayedError("Boom!", 1)(y)
126
127        for f in [f1, f2]:
128            # Ensure that the error Node works
129            out = f(x, y_safe)
130            with self.assertRaisesRegex(RuntimeError, "Boom!"):
131                out.sum().backward()
132
133            out = f(x, y_safe)
134            with self.assertRaisesRegex(RuntimeError, "Boom!"):
135                torch.autograd.grad(out.sum(), y)
136
137            # Ensure that if we don't ask for y, it doesn't crash
138            out = f(x, y_safe)
139            torch.autograd.grad(out.sum(), x)
140
141            out = f(x, y_safe)
142            torch.autograd.grad(out.sum(), y_safe)
143
144            out = f(x, y_safe)
145            torch.autograd.grad(out.sum(), (x, y_safe))
146
147        # Ensure that we don't run extra view Node
148        def f3(x, y):
149            out = x.clone().view(-1)
150
151            def hook(*args):
152                # This should never be called!
153                self.assertTrue(False)
154
155            out.register_hook(hook)
156
157            b = out + y
158            out += y
159            return out + b, b
160
161        out, b = f3(x, y_safe)
162        torch.autograd.grad(out.sum(), (b, y_safe))
163
164    def test_grad_mode_class_decoration(self):
165        # Decorating class is deprecated and should not be used
166        with self.assertWarnsRegex(FutureWarning, "Decorating classes is deprecated"):
167
168            @torch.no_grad()
169            class Foo:
170                def __init__(self) -> None:
171                    assert not torch.is_grad_enabled()
172
173                def foo(self):
174                    # Not applied to methods
175                    assert torch.is_grad_enabled()
176
177            # Show that we can actually construct the class
178            foo = Foo()
179            foo.foo()
180
181        # Decorating functions or methods is fine though
182        with warnings.catch_warnings(record=True) as w:
183
184            @torch.no_grad()
185            def foo():
186                assert not torch.is_grad_enabled()
187
188            foo()
189
190            class Foo2:
191                @torch.no_grad()
192                def __init__(self) -> None:
193                    assert not torch.is_grad_enabled()
194
195                @torch.no_grad()
196                def foo(self):
197                    assert not torch.is_grad_enabled()
198
199            foo2 = Foo2()
200            foo2.foo()
201
202        self.assertEqual(len(w), 0)
203
204    def test_tensor_grad_warnings(self):
205        dummy = torch.empty(1)
206
207        with warnings.catch_warnings(record=True) as w:
208            # Accessing .grad on leaf
209            dummy.requires_grad_()
210            foo = dummy.grad
211            self.assertEqual(len(w), 0)
212
213            # Accessing .grad on non-leaf
214            dummy = dummy.clone()
215            foo = dummy.grad
216            self.assertEqual(len(w), 1)
217
218            # Accessing .grad on non-leaf that retains gradients
219            dummy.retain_grad()
220            foo = dummy.grad
221            self.assertEqual(len(w), 1)
222
223    def _function_test(self, cls):
224        x = torch.randn(5, 5, requires_grad=True)
225        y = torch.randn(5, 5, requires_grad=True)
226        result = cls.apply(x, 2, y)
227        go = torch.ones((), requires_grad=True)
228        result.sum().backward(go, create_graph=True)
229
230        self.assertEqual(x.grad, y + torch.ones(5, 5))
231        self.assertEqual(y.grad, x + torch.ones(5, 5) * 2)
232        self.assertIsNotNone(x.grad.grad_fn)
233        self.assertIsNotNone(y.grad.grad_fn)
234
235        return x, y
236
237    def test_function(self):
238        class MyFunction(Function):
239            @staticmethod
240            def forward(ctx, tensor1, pyscalar, tensor2):
241                ctx.pyscalar = pyscalar
242                ctx.save_for_backward(tensor1, tensor2)
243                return tensor1 + pyscalar * tensor2 + tensor1 * tensor2
244
245            @staticmethod
246            def backward(ctx, grad_output):
247                var1, var2 = ctx.saved_tensors
248                # NOTE: self is the test case here
249                self.assertIsInstance(var1, torch.Tensor)
250                self.assertIsInstance(var2, torch.Tensor)
251                self.assertIsInstance(grad_output, torch.Tensor)
252                return (
253                    grad_output + grad_output * var2,
254                    None,
255                    grad_output * ctx.pyscalar + grad_output * var1,
256                )
257
258        x, y = self._function_test(MyFunction)
259
260        x_grad_desc = graph_desc(x.grad.grad_fn)
261        y_grad_desc = graph_desc(y.grad.grad_fn)
262        self.assertExpected(x_grad_desc, "x_grad_desc")
263        self.assertExpected(y_grad_desc, "y_grad_desc")
264
265    def test_once_differentiable(self):
266        class MyFunction(Function):
267            @staticmethod
268            def forward(ctx, tensor1, pyscalar, tensor2):
269                ctx.pyscalar = pyscalar
270                ctx.save_for_backward(tensor1, tensor2)
271                return tensor1 + pyscalar * tensor2 + tensor1 * tensor2
272
273            @staticmethod
274            @once_differentiable
275            def backward(ctx, grad_output):
276                self.assertFalse(torch.is_grad_enabled())
277                t1, t2 = ctx.saved_tensors
278                return (
279                    grad_output + grad_output * t2,
280                    None,
281                    grad_output * ctx.pyscalar + grad_output * t1,
282                )
283
284        x, y = self._function_test(MyFunction)
285        self.assertEqual(
286            graph_desc(x.grad.grad_fn),
287            "CopyBackwards(None, Error(AccumulateGrad(), None, AccumulateGrad()))",
288        )
289        self.assertEqual(
290            graph_desc(y.grad.grad_fn),
291            "CopyBackwards(None, Error(AccumulateGrad(), None, AccumulateGrad()))",
292        )
293
294    def test_function_returns_input(self):
295        class MyFunction(Function):
296            @staticmethod
297            def forward(ctx, x):
298                return x
299
300            @staticmethod
301            def backward(ctx, grad):
302                return grad * 2
303
304        for shape in [(1,), ()]:
305            v = torch.ones(shape, requires_grad=True)
306            MyFunction.apply(v).backward()
307            self.assertEqual(v.grad, torch.full(shape, 2.0))
308
309            with torch.no_grad():
310                v.grad.zero_()
311            MyFunction.apply(v.clone()).backward()
312            self.assertEqual(v.grad, torch.full(shape, 2.0))
313
314    def test_function_returns_undefined_tensor(self):
315        class MyFunction(Function):
316            @staticmethod
317            def forward(ctx, x):
318                return x * 2
319
320            @staticmethod
321            def backward(ctx, grad):
322                return None
323
324        # Test that undefined tensors returned from custom backward function
325        # are propagated as undefined and not tensor full of zeroes
326        x = torch.ones(1, requires_grad=True)
327
328        MyFunction.apply(x).backward()
329        self.assertIsNone(x.grad)
330
331        MyFunction.apply(x**2).backward()
332        self.assertIsNone(x.grad)
333
334        MyFunction.apply(x).sum().backward()
335        self.assertIsNone(x.grad)
336
337        self.assertIsNone(
338            torch.autograd.grad(MyFunction.apply(x), x, allow_unused=True)[0]
339        )
340
341    def test_materialize_grads(self):
342        class MyFunction(Function):
343            @staticmethod
344            def forward(ctx, x):
345                return x
346
347            @staticmethod
348            def backward(ctx, grad):
349                self.assertEqual(grad, torch.zeros(1))
350                return grad
351
352        x = torch.ones(1, requires_grad=True)
353        torch._C._functions.UndefinedGrad()(MyFunction.apply(x)).backward()
354
355    def test_dont_materialize_grads(self):
356        class MyFunction(Function):
357            @staticmethod
358            def forward(ctx, x):
359                ctx.set_materialize_grads(False)
360                return x
361
362            @staticmethod
363            def backward(ctx, grad):
364                self.assertIsNone(grad)
365                return grad
366
367        x = torch.ones(1, requires_grad=True)
368        torch._C._functions.UndefinedGrad()(MyFunction.apply(x)).backward()
369
370    @skipIfTorchDynamo("compile tested in test/dynamo/test_autograd_function.py")
371    def test_set_materialize_non_diff_grads(self):
372        class Func(torch.autograd.Function):
373            @staticmethod
374            def forward(ctx, x):
375                out0 = x.clone()
376                out1 = x.clone()
377                ctx.mark_non_differentiable(out1)
378                ctx._materialize_non_diff_grads = False
379                return out0, out1
380
381            @staticmethod
382            def backward(ctx, g0, g1):
383                self.assertIsNone(g1)
384                return g0
385
386        a = torch.tensor(1.0, requires_grad=True)
387        out = Func.apply(a)[0]
388        out.backward()
389
390    def test_legacy_function_deprecation_exception(self):
391        # Trigger exception
392        class MyFunction(Function):
393            def forward(self, x):
394                return x
395
396            def backward(self, grad_output):
397                return grad_output
398
399        # Check exception occurs
400        with self.assertRaisesRegex(
401            RuntimeError,
402            "Legacy autograd function with non-static forward method is deprecated",
403        ):
404            MyFunction()(torch.randn(3, 4))
405
406    class SimulateBackwardError(Function):
407        @staticmethod
408        def forward(ctx, input):
409            return input.clone()
410
411        @staticmethod
412        @once_differentiable
413        def backward(ctx, input):
414            raise Exception("Simulate error on backward pass")  # noqa: TRY002
415
416    def test_custom_function_exception(self):
417        t1 = torch.rand((3, 3), requires_grad=True)
418        t2 = torch.rand((3, 3), requires_grad=True)
419
420        tmp = (t1 + t2) * (t1 + t2)
421        t3 = TestAutograd.SimulateBackwardError.apply(tmp)
422        with self.assertRaisesRegex(Exception, "Simulate error on backward pass"):
423            t3.sum().backward()
424
425    def test_custom_function_non_tensor_inputs_outputs(self):
426        class MyFunction(Function):
427            @staticmethod
428            def forward(ctx, t1, t2, scale, t3):
429                t4 = t1 + t2 * t3
430                t5 = t1 * t2 + t3
431                t4 *= scale
432                t5 *= scale
433
434                # Save scale
435                ctx.scale = scale
436                ctx.save_for_backward(t1, t2, t3)
437                return scale, t4, None, True, t5, "bar", t1
438
439            @staticmethod
440            @once_differentiable
441            def backward(ctx, *grads):
442                # Verify grads
443                self.assertEqual(7, len(grads))
444                self.assertIsNone(grads[0])
445                self.assertIsNone(grads[2])
446                self.assertIsNone(grads[3])
447                self.assertIsNone(grads[5])
448
449                scale = ctx.scale
450                var1, var2, var3 = ctx.saved_tensors
451                return (
452                    grads[1] * scale + grads[4] * var2 * scale + grads[6],
453                    grads[1] * var3 * scale + grads[4] * var1 * scale,
454                    None,
455                    grads[1] * var2 * scale + grads[4] * scale,
456                )
457
458        t1 = torch.rand(10, dtype=torch.double, requires_grad=True)
459        t2 = torch.rand(10, dtype=torch.double, requires_grad=True)
460        t3 = torch.rand(10, dtype=torch.double)
461        scale = random.randint(0, 10)
462        res = MyFunction.apply(t1, t2, scale, t3)
463        self.assertEqual(scale, res[0])
464        self.assertEqual((t1 + t2 * t3) * scale, res[1])
465        self.assertEqual(None, res[2])
466        self.assertEqual(True, res[3])
467        self.assertEqual((t1 * t2 + t3) * scale, res[4])
468        self.assertEqual("bar", res[5])
469        self.assertEqual(t1, res[6])
470
471        # Validate running backward.
472        torch.autograd.backward([res[1].sum(), res[4].sum(), res[6].sum()])
473        self.assertIsNotNone(t1.grad)
474        self.assertIsNotNone(t2.grad)
475        self.assertIsNone(t3.grad)
476
477        # Test gradcheck
478        def foo(t1, t2, t3):
479            res = MyFunction.apply(t1, t2, scale, t3)
480            return res[1], res[4], res[6]
481
482        gradcheck(foo, (t1, t2, t3))
483
484    def test_custom_function_no_tensors(self):
485        class MyFunction(Function):
486            @staticmethod
487            def forward(ctx, t1, t2, scale, t3):
488                t4 = t1 + t2 * t3
489                t5 = t1 * t2 + t3
490                t4 *= scale
491                t5 *= scale
492                return scale, t4, None, True, t5, "bar", t1
493
494            @staticmethod
495            @once_differentiable
496            def backward(ctx, *args):
497                return (args[0], args[1], None, args[2])
498
499        t1 = random.random()
500        t2 = random.random()
501        t3 = random.random()
502        scale = random.randint(0, 10)
503        res = MyFunction.apply(t1, t2, scale, t3)
504        self.assertEqual(scale, res[0])
505        self.assertEqual((t1 + t2 * t3) * scale, res[1])
506        self.assertEqual(None, res[2])
507        self.assertEqual(True, res[3])
508        self.assertEqual((t1 * t2 + t3) * scale, res[4])
509        self.assertEqual("bar", res[5])
510        self.assertEqual(t1, res[6])
511
512    def test_invalid_gradients(self):
513        class MyFunction(Function):
514            @staticmethod
515            def forward(ctx, x):
516                return x * 2
517
518            @staticmethod
519            def backward(ctx, grad_output):
520                return torch.randn(10, dtype=torch.float)
521
522        with self.assertRaisesRegex(RuntimeError, "expected shape"):
523            input = torch.randn(5, 5, dtype=torch.float, requires_grad=True)
524            MyFunction.apply(input).sum().backward()
525
526    def test_unrelated_inputs(self):
527        # test to ensure grad(grad)check runs successfully even if there is an
528        # unrelated (but differentiable) inputs
529
530        def my_function(x, y):
531            return x * x
532
533        x = torch.rand(10, dtype=torch.double, requires_grad=True)
534        y = torch.rand(10, dtype=torch.double, requires_grad=True)
535
536        gradcheck(my_function, (x, y))
537        gradgradcheck(my_function, (x, y))
538
539    def test_not_implemented_grad(self):
540        a = torch.rand(2, requires_grad=True)
541        # if grad for nextafter ends up being implemented, this should be changed
542        y = torch.nextafter(a, a).sum()
543        with self.assertRaisesRegex(
544            NotImplementedError, "the derivative for .* is not implemented"
545        ):
546            y.backward()
547
548    def test_not_implemented_fwad(self):
549        x = torch.randn(3)
550        v = torch.rand(3)
551
552        with fwAD.dual_level():
553            dual_x = fwAD.make_dual(x, v)
554
555            err_msg = r"Trying to use forward AD with .* that does not support it"
556            hint_msg = "Running forward AD for an OP that does not implement it should raise a NotImplementedError"
557
558            with self.assertRaisesRegex(NotImplementedError, err_msg, msg=hint_msg):
559                # if forward AD ends up being implemented for torch.igamma, choose a different op
560                torch.igamma(dual_x, dual_x)
561
562    def test_saved_tensor_hooks_extra_exit_during_bw_no_crash(self):
563        # This usage of saved tensor is not supported, but should not crash
564        def unpack(x):
565            ctx_1.__exit__()
566            return x
567
568        ctx_1 = torch.autograd.graph.saved_tensors_hooks(lambda x: x, unpack)
569        ctx_2 = torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x)
570
571        for i in range(10):
572            with ctx_2:
573                ctx_1.__enter__()
574                x = torch.randn(3, 3, requires_grad=True)
575                x.sin().sum().backward()
576
577        # Clean up
578        for i in range(10):
579            ctx_1.__exit__()
580
581        # Validate there are no more hooks on the stack
582        a = torch.tensor(1.0, requires_grad=True)
583        y = a.exp()
584        y.grad_fn._raw_saved_result.register_hooks(lambda x: x, lambda x: x)
585
586    def test_saved_tensor_hooks_extra_enter_during_bw_no_leak(self):
587        # This usage of saved tensor is not supported, but should not leak
588        def scope():
589            def unpack(x):
590                weak_ctx_1().__enter__()
591                return x
592
593            ctx_1 = torch.autograd.graph.saved_tensors_hooks(lambda x: x, unpack)
594            weak_ctx_1 = weakref.ref(ctx_1)
595
596            x = torch.randn(3, 3, requires_grad=True)
597            with ctx_1:
598                x.sin().sum().backward()
599            return weakref.ref(unpack)
600
601        with disable_gc():
602            unpack_hook_ref = scope()
603            self.assertIsNone(unpack_hook_ref())
604
605    def test_will_engine_execute_node(self):
606        counter = [0]
607
608        class MyFunction(Function):
609            @staticmethod
610            def forward(ctx, x):
611                return x * 2
612
613            @staticmethod
614            def backward(ctx, gO):
615                return gO * 2
616
617        def get_grad_fn(t):
618            if t.requires_grad and t.grad_fn is None:
619                return t.clone().grad_fn.next_functions[0][0]
620            else:
621                return t.grad_fn
622
623        a = torch.randn(2, 3, 4, requires_grad=True)
624        a2 = torch.randn(2, 3, 4, requires_grad=True)
625        b = a * a2
626        b2 = b.cos()
627        c = MyFunction.apply(b)
628
629        should_execute = list(map(get_grad_fn, (a, b, c)))
630        should_not_execute = list(map(get_grad_fn, (a2, b2)))
631
632        def fn(x):
633            counter[0] += 1
634
635            for g in should_execute:
636                self.assertTrue(torch._C._will_engine_execute_node(g))
637
638            for g in should_not_execute:
639                self.assertFalse(torch._C._will_engine_execute_node(g))
640
641        b.register_hook(fn)
642        c.register_hook(fn)
643
644        # .backward(inputs=) is OK
645        out = c.sum()
646        torch.autograd.backward(out, inputs=(a, b), retain_graph=True)
647        self.assertEqual(counter[0], 2)
648
649        # .backward() is OK
650        should_execute = list(map(get_grad_fn, (a, a2, b, c)))
651        should_not_execute = list(map(get_grad_fn, (b2,)))
652        torch.autograd.backward(out, retain_graph=True)
653
654        # .grad is NOT OK when leaf is passed (this is the current state, subject to change)
655        with self.assertRaisesRegex(
656            RuntimeError, "are currently running autograd.grad()"
657        ):
658            torch.autograd.grad(out, (a,))
659
660        # .grad is OK when non-leaf is passed
661        a = torch.randn(1, 2, 3, requires_grad=True) * 2
662        b = a * 2
663
664        def fn(x):
665            # Check a non-leaf
666            counter[0] += 1
667            self.assertTrue(torch._C._will_engine_execute_node(b.grad_fn))
668
669        b.register_hook(fn)
670        counter[0] = 0
671        torch.autograd.grad(b.sum(), (a,))
672        self.assertEqual(counter[0], 1)
673
674        # Verify other errors are raised
675        with self.assertRaisesRegex(RuntimeError, "during the backward pass"):
676            torch._C._will_engine_execute_node(out.grad_fn)
677
678        with self.assertRaisesRegex(RuntimeError, "expects an grad_fn"):
679            torch._C._will_engine_execute_node(out)
680
681    def test_custom_function_vmap_defaults(self):
682        class MySquare(Function):
683            @staticmethod
684            def forward(x):
685                return x**2
686
687            @staticmethod
688            def setup_context(ctx, inputs, output):
689                (x,) = inputs
690                ctx.save_for_backward(x)
691
692            @staticmethod
693            def backward(ctx, gO):
694                (x,) = ctx.saved_tensors
695                return gO * 2 * x
696
697        self.assertFalse(MySquare.generate_vmap_rule)
698        self.assertTrue(hasattr(MySquare, "vmap"))
699
700    def test_custom_function_setup_context_simple(self):
701        class MySquare(Function):
702            @staticmethod
703            def forward(x):
704                return x**2
705
706            @staticmethod
707            def setup_context(ctx, inputs, output):
708                (x,) = inputs
709                ctx.save_for_backward(x)
710
711            @staticmethod
712            def backward(ctx, gO):
713                (x,) = ctx.saved_tensors
714                return gO * 2 * x
715
716        x = torch.randn([], requires_grad=True)
717        y = MySquare.apply(x)
718        (gx,) = torch.autograd.grad(y, x)
719        self.assertEqual(gx, 2 * x)
720
721    def test_custom_function_setup_context_multi_output(self):
722        # Multiple outputs with some non-Tensor outputs.
723        class MySquare(Function):
724            @staticmethod
725            def forward(x):
726                two_x = x.item() * 2
727                return x**2, two_x
728
729            @staticmethod
730            def setup_context(ctx, inputs, output):
731                (x,) = inputs
732                _, two_x = output
733                ctx.two_x = two_x
734
735            @staticmethod
736            @once_differentiable
737            def backward(ctx, gO, _):
738                return gO * ctx.two_x
739
740        x = torch.randn([], requires_grad=True)
741        y, _ = MySquare.apply(x)
742        (gx,) = torch.autograd.grad(y, x)
743        self.assertEqual(gx, 2 * x)
744
745    def test_custom_function_setup_context_multi_input(self):
746        class MyReshape(Function):
747            @staticmethod
748            def forward(x, shape, scale_forward, scale_backward):
749                return x.reshape(shape) * scale_forward
750
751            @staticmethod
752            def setup_context(ctx, inputs, output):
753                x, shape, scale_forward, scale_backward = inputs
754                ctx.scale_backward = scale_backward
755                ctx.x_shape = x.shape
756
757            @staticmethod
758            def backward(ctx, gO):
759                return gO.reshape(ctx.x_shape) * ctx.scale_backward, None, None, None
760
761        class MyReshapeRef(Function):
762            @staticmethod
763            def forward(ctx, x, shape, scale_forward, scale_backward):
764                ctx.scale_backward = scale_backward
765                ctx.x_shape = x.shape
766                return x.reshape(shape) * scale_forward
767
768            @staticmethod
769            def backward(ctx, gO):
770                return gO.reshape(ctx.x_shape) * ctx.scale_backward, None, None, None
771
772        def test(x, shape, scale_forward, scale_backward):
773            y = MyReshape.apply(x, shape, scale_forward, scale_backward).sum()
774            (gx,) = torch.autograd.grad(y, x)
775
776            y_expected = MyReshapeRef.apply(
777                x, shape, scale_forward, scale_backward
778            ).sum()
779            (gx_expected,) = torch.autograd.grad(y_expected, x)
780
781            self.assertEqual(y_expected, y)
782            self.assertEqual(gx_expected, gx)
783
784        test(torch.randn(24, requires_grad=True), (3, 8), 7, 11)
785        test(torch.randn(2, 3, 4, requires_grad=True), (6, 4), -1, 2)
786
787    def test_multiple_insert_removal_caching(self):
788        torch._C._set_cached_tensors_enabled(True)
789        try:
790            x = torch.rand([4])
791
792            torch._C._add_cached_tensor(x)
793            self.assertTrue(torch._C._is_cached_tensor(x))
794
795            torch._C._add_cached_tensor(x)
796            torch._C._remove_cached_tensor(x)
797
798            self.assertFalse(torch._C._is_cached_tensor(x))
799        finally:
800            torch._C._set_cached_tensors_enabled(False)
801
802    def test_accumulate_grad(self):
803        grad_output = torch.ones(5, 5)
804
805        def compute_grad(create_graph):
806            x = torch.randn(5, 5, requires_grad=True)
807            y = x + 2
808            y.backward(grad_output, retain_graph=True)
809            x_grad = x.grad
810            x_grad_clone = x.grad.clone()
811            y.backward(grad_output, create_graph=create_graph)
812            return x_grad, x_grad_clone
813
814        # Accumulate in-place when create_graph is False
815        x_grad, x_grad_clone = compute_grad(create_graph=False)
816        self.assertEqual(x_grad, x_grad_clone * 2)
817
818        # Accumulate out-of-place when create_graph is False
819        x_grad, x_grad_clone = compute_grad(create_graph=True)
820        self.assertEqual(x_grad, x_grad_clone)
821
822    def test_accumulate_grad_tensor_reference(self):
823        def _test_grad_tensor(
824            params_grad_tensor,
825            backward_grad_tensor,
826            should_preserve_reference,
827            create_graph,
828        ):
829            params = torch.tensor([1.5, 1.5]).requires_grad_()
830            params.grad = params_grad_tensor
831            grad_saved = params.grad
832            params.backward(backward_grad_tensor, create_graph=create_graph)
833            self.assertEqual(
834                id(grad_saved) == id(params.grad), should_preserve_reference
835            )
836
837        for create_graph in (False, True):
838            # Accumulate dense gradient to sparse gradient will change the `params.grad` reference
839            _test_grad_tensor(
840                torch.sparse_coo_tensor(
841                    torch.tensor([[1, 1]]).long(), torch.tensor([1.0, 1.0])
842                ),
843                torch.tensor([1.5, 1.5]),
844                False,  # never accumulates in-place
845                create_graph,
846            )
847
848            # Accumulate dense gradient to dense gradient will preserve the `params.grad` reference,
849            # but only if create_graph=False.
850            _test_grad_tensor(
851                torch.tensor([1.5, 1.5]),
852                torch.tensor([1.5, 1.5]),
853                not create_graph,
854                create_graph,
855            )
856
857            # Accumulate sparse gradient to sparse gradient will preserve the `params.grad` reference,
858            # but only if create_graph=False.
859            _test_grad_tensor(
860                torch.sparse_coo_tensor(
861                    torch.tensor([[1, 1]]).long(), torch.tensor([1.0, 1.0])
862                ),
863                torch.sparse_coo_tensor(
864                    torch.tensor([[1, 1]]).long(), torch.tensor([1.0, 1.0])
865                ),
866                not create_graph,
867                create_graph,
868            )
869
870    def test_accumulate_grad_with_zero_numel_grad(self):
871        a = torch.rand(4, 0, requires_grad=True)
872        b = torch.rand(4, 1, requires_grad=True)
873        c = a + b
874        assert c.shape == (4, 0)
875        c.sum().backward()
876
877        self.assertEqual(b.grad, torch.zeros(4, 1))
878        self.assertEqual(a.grad, torch.zeros(4, 0))
879
880    def test_hessian_vector(self):
881        x = torch.randn(2, 2, requires_grad=True)
882        y = torch.randn(2, 2, requires_grad=True)
883
884        z = x**2 + y * x + y**2
885        z.backward(torch.ones(2, 2), create_graph=True)
886
887        with torch.no_grad():
888            x_grad = 2 * x + y
889            y_grad = x + 2 * y
890        self.assertEqual(x.grad, x_grad)
891        self.assertEqual(y.grad, y_grad)
892
893        grad_sum = 2 * x.grad + y.grad
894        grad_sum.backward(torch.ones(2, 2))
895        x_hv = torch.ones(2, 2) * 5
896        y_hv = torch.ones(2, 2) * 4
897        self.assertEqual(x.grad, x_grad + x_hv)
898        self.assertEqual(y.grad, y_grad + y_hv)
899
900    def test_grad(self):
901        x = torch.randn(2, 2, requires_grad=True)
902        y = torch.randn(2, 2, requires_grad=True)
903        z = x**2 + y * x + y**2
904        z.backward(torch.ones(2, 2), create_graph=True)
905
906        x_grad = 2 * x + y
907        y_grad = x + 2 * y
908        self.assertEqual(x.grad, x_grad)
909        self.assertEqual(y.grad, y_grad)
910
911        grad_sum = 2 * x.grad + y.grad
912        x_hv = torch.autograd.grad(
913            outputs=[grad_sum],
914            grad_outputs=[torch.ones(2, 2)],
915            inputs=[x],
916            create_graph=True,
917        )
918        expected_x_hv = torch.ones(2, 2) * 5
919        expected_y_hv = torch.ones(2, 2) * 4
920
921        self.assertEqual(x_hv[0], expected_x_hv)
922        self.assertEqual(x.grad, x_grad)
923        self.assertEqual(y.grad, y_grad)
924
925        # Test that grad_outputs and outputs have the same shape
926        grad_out = torch.ones(2)
927        try:
928            torch.autograd.grad(
929                outputs=[grad_sum],
930                grad_outputs=[grad_out],
931                inputs=[x],
932                create_graph=True,
933            )
934            self.assertFail()
935        except RuntimeError as error:
936            self.assertEqual(
937                str(error),
938                "Mismatch in shape: grad_output[0] has a shape of "
939                + str(grad_out.shape)
940                + " and output[0] has a shape of "
941                + str(grad_sum.shape)
942                + ".",
943            )
944
945    def test_grad_to_node(self):
946        def check_matches(out, inp):
947            ref = torch.autograd.grad(out.sum(), inp)
948
949            edge = torch.autograd.graph.get_gradient_edge(inp)
950            new = torch.autograd.grad(out.sum(), edge)
951            self.assertEqual(ref, new)
952
953        # We need to ensure that our main types of Node work (regular cpp Nodes,
954        # AccumulateGrad Nodes and custom Function)
955        x = torch.rand(2, requires_grad=True)
956        out = x.clone()
957        check_matches(out, x)
958
959        x = x.clone()
960        out = x.clone()
961        check_matches(out, x)
962
963        x = torch.autograd._functions.Resize.apply(x, (2,))
964        out = x.clone()
965        check_matches(out, x)
966
967        x = torch.var_mean(x)[1]
968        out = x.clone()
969        check_matches(out, x)
970
971    def test_grad_to_node_set(self):
972        x = torch.rand(2, requires_grad=True)
973        x_edge = torch.autograd.graph.get_gradient_edge(x)
974        out = x.clone()
975
976        with torch.no_grad():
977            x.set_(torch.rand_like(x))
978
979        with self.assertRaisesRegex(RuntimeError, "to not have been used in the graph"):
980            torch.autograd.grad(out.sum(), x)
981
982        # Works
983        torch.autograd.grad(out.sum(), x_edge)
984
985    def test_grad_to_node_inplace(self):
986        x = torch.rand(2, requires_grad=True).clone()
987        x_edge = torch.autograd.graph.get_gradient_edge(x)
988        x *= 2
989
990        g_old, g_new = torch.autograd.grad(x.sum(), (x_edge, x))
991        self.assertEqual(g_old, 2 * torch.ones_like(x))
992        self.assertEqual(g_new, torch.ones_like(x))
993
994    def test_grad_to_node_multi(self):
995        x = torch.rand(2, requires_grad=True).clone()
996        y = torch.rand(2, requires_grad=True).clone()
997
998        out = x + y
999
1000        ref = torch.autograd.grad(out.sum(), (x, y))
1001
1002        inp_edges = (
1003            GradientEdge(x.grad_fn, x.output_nr),
1004            GradientEdge(y.grad_fn, y.output_nr),
1005        )
1006        new = torch.autograd.grad(out.sum(), inp_edges)
1007
1008        self.assertEqual(ref, new)
1009
1010    def test_grad_to_node_materialize(self):
1011        x = torch.rand(2, requires_grad=True).clone()
1012        edge_x = GradientEdge(x.grad_fn, x.output_nr)
1013        y = torch.rand(2, requires_grad=True).clone()
1014        edge_y = GradientEdge(y.grad_fn, y.output_nr)
1015
1016        out = x.clone()
1017
1018        # Works
1019        torch.autograd.grad(
1020            out.sum(), (edge_x, y), allow_unused=True, materialize_grads=True
1021        )
1022        torch.autograd.grad(
1023            out.sum(), (x, y), allow_unused=True, materialize_grads=True
1024        )
1025        torch.autograd.grad(out.sum(), (x, edge_y), allow_unused=True)
1026
1027        with self.assertRaisesRegex(
1028            RuntimeError,
1029            "materialize_grads cannot be used when the given input is a GradientEdge",
1030        ):
1031            torch.autograd.grad(
1032                out.sum(), (x, edge_y), allow_unused=True, materialize_grads=True
1033            )
1034
1035    def test_backward_to_node(self):
1036        x = torch.rand(2, requires_grad=True).clone()
1037        edge_x = GradientEdge(x.grad_fn, x.output_nr)
1038        y = torch.rand(2, requires_grad=True).clone()
1039        edge_y = GradientEdge(y.grad_fn, y.output_nr)
1040
1041        out = x.clone()
1042
1043        # All should work in this case
1044        torch.autograd.backward(out.sum(), inputs=(edge_x, y))
1045        torch.autograd.backward(out.sum(), inputs=(x, y))
1046        torch.autograd.backward(out.sum(), inputs=(x, edge_y))
1047        torch.autograd.backward(out.sum(), inputs=(edge_x, edge_y))
1048
1049    def test_grad_fn_input_metadata(self):
1050        x = torch.rand(2, requires_grad=True, dtype=torch.float32)
1051        y = torch.rand(2, requires_grad=True, dtype=torch.float32)
1052        z = x * y
1053        z_metadata = z.grad_fn._input_metadata[0]
1054        self.assertEqual(z_metadata.shape, (2,))
1055        self.assertEqual(z_metadata.dtype, torch.float32)
1056
1057        # Multiple outputs
1058        b = torch.rand(3, 3, requires_grad=True)
1059        var, _ = torch.var_mean(b, dim=0)
1060
1061        metadata_0 = var.grad_fn._input_metadata[0]
1062        metadata_1 = var.grad_fn._input_metadata[1]
1063        self.assertEqual(metadata_0.shape, (3,))
1064        self.assertEqual(metadata_1.shape, (3,))
1065
1066        # Preserves symints
1067        nt = torch.nested.nested_tensor(
1068            [torch.randn(3, 2), torch.randn(2, 2)],
1069            layout=torch.jagged,
1070            requires_grad=True,
1071        )
1072        nt_metadata = nt.clone().grad_fn._input_metadata[0]
1073
1074        self.assertIsInstance(nt_metadata.shape[1], torch.SymInt)
1075        self.assertEqual(nt_metadata.shape, nt.shape)
1076        self.assertTrue(nt_metadata.is_nested_tensor)
1077        self.assertFalse(nt_metadata.is_cpp_nested_tensor)
1078        self.assertEqual(nt_metadata.dtype, nt.dtype)
1079
1080        class Test(torch.autograd.Function):
1081            @staticmethod
1082            def forward(ctx, x):
1083                return x
1084
1085            @staticmethod
1086            def backward(ctx, grad_output):
1087                return grad_output
1088
1089        x = torch.randn(3, 3, requires_grad=True)
1090        x = Test.apply(x)
1091        metadata = x.grad_fn._input_metadata[0]
1092        self.assertEqual(metadata.shape, (3, 3))
1093
1094    def test_gradient_edge_output(self):
1095        x = torch.tensor([1.0, 2.0], requires_grad=True)
1096
1097        def fn(x, reduce=True):
1098            tmp = x.sin().cos()
1099            if reduce:
1100                tmp = tmp.sum()
1101            out = tmp.exp().clone().sin().sum()
1102            tmp_edge = torch.autograd.graph.get_gradient_edge(tmp)
1103            return out, tmp_edge
1104
1105        # Compute fn backward in two steps
1106        out, tmp_edge = fn(x)
1107        (tmp_grad,) = torch.autograd.grad(out, (tmp_edge,))
1108
1109        (x_grad,) = torch.autograd.grad(tmp_edge, (x,), grad_outputs=(tmp_grad,))
1110
1111        # Compare with as if we did it in one go.
1112        out, _ = fn(x)
1113        (x_grad_ref,) = torch.autograd.grad(out, (x,))
1114        self.assertEqual(x_grad, x_grad_ref)
1115
1116        # Incorrect case: grad_outputs not passed/implicitly None and output is
1117        # not a scalar
1118        out, tmp_edge = fn(x, reduce=False)
1119        with self.assertRaisesRegex(
1120            RuntimeError, "grad can be implicitly created only for scalar output"
1121        ):
1122            torch.autograd.grad(tmp_edge, (x,))
1123
1124        # grad_outputs is None, and output is a scalar is fine
1125        out, tmp_edge = fn(x, reduce=True)
1126        torch.autograd.grad(tmp_edge, (x,))
1127
1128        # Incorrect case: grad_outputs wrong size
1129        out, tmp_edge = fn(x)
1130        (tmp_grad,) = torch.autograd.grad(out, (tmp_edge,))
1131        with self.assertRaisesRegex(RuntimeError, "Mismatch in shape"):
1132            torch.autograd.grad(
1133                tmp_edge, (x,), grad_outputs=torch.tensor([1.0, 2.0, 3.0, 4.0])
1134            )
1135
1136        # Incorrect case: wrong dtype
1137        out, tmp_edge = fn(x)
1138        (tmp_grad,) = torch.autograd.grad(out, (tmp_edge,))
1139        with self.assertRaisesRegex(RuntimeError, "required to have the same dtype"):
1140            torch.autograd.grad(
1141                tmp_edge,
1142                (x,),
1143                grad_outputs=torch.rand_like(tmp_grad, dtype=torch.complex64),
1144            )
1145
1146    def test_grad_nonleaf(self):
1147        x_init = torch.randn(2, 2, requires_grad=True)
1148        x = x_init
1149        y = torch.randn(2, 2, requires_grad=True)
1150        grad_output = torch.ones(2, 2)
1151
1152        def fn(x):
1153            return x**2 + y * x + y**2
1154
1155        for _ in range(5):
1156            (grad_x,) = torch.autograd.grad(
1157                fn(x), x, grad_outputs=grad_output, create_graph=True
1158            )
1159
1160            grad_x_expected = 2 * x + y
1161            self.assertIsNone(y.grad)
1162            self.assertIsNone(x.grad)
1163            self.assertEqual(grad_x, grad_x_expected)
1164
1165            x = x + 0.05 * grad_x
1166
1167        val_init = fn(x_init).sum()
1168        val_final = fn(x).sum()
1169        self.assertGreater(val_final, val_init)
1170
1171        x.backward(grad_output)
1172        self.assertIsNotNone(y.grad)
1173        self.assertIsNotNone(x_init.grad)
1174
1175    def test_grad_nonleaf_many_outputs(self):
1176        # This checks an edge case for function callbacks
1177        # We want to capture two grads of a function, but can only
1178        # register a single callback.
1179        x = torch.randn(4, 2, requires_grad=True)
1180        a, b = x.chunk(2)
1181
1182        def hook(*grads):
1183            hook_called[0] = True
1184
1185        hook_called = [False]
1186        x.register_hook(hook)
1187
1188        go = torch.randn(2, 2)
1189        grad_a, grad_b = torch.autograd.grad(
1190            (a + 2 * b), [a, b], grad_outputs=go, create_graph=True
1191        )
1192
1193        self.assertEqual(grad_a, go)
1194        self.assertEqual(grad_b, go * 2)
1195        self.assertFalse(hook_called[0])
1196        self.assertIsNone(x.grad)
1197
1198    def test_grad_nonleaf_register_hook(self):
1199        # This checks an edge case for register_hook.
1200        # We want to capture grad of a nonleaf tensor,
1201        # but avoid segfault during backward of other nonleaf tensors
1202        x = torch.randn(5, requires_grad=True)
1203        x_list = x.unbind()
1204
1205        x0 = x_list[0]
1206        hook_results = [None]
1207
1208        def hook(grad):
1209            hook_results[0] = grad
1210
1211        x0.register_hook(hook)
1212
1213        x_list[0].backward()
1214        self.assertEqual(hook_results[0], torch.tensor(1.0))
1215        expected_grad = torch.tensor([1.0, 0, 0, 0, 0])
1216        self.assertEqual(x.grad, expected_grad)
1217        self.assertIsNone(x_list[0].grad)
1218
1219        for i in range(1, 5, 1):
1220            x_list[i].backward()
1221            self.assertEqual(hook_results[0], None)
1222            expected_grad[i] = 1.0
1223            self.assertEqual(x.grad, expected_grad)
1224            self.assertIsNone(x_list[i].grad)
1225
1226    def test_grad_materialize_grads(self):
1227        x = torch.tensor(0.5, requires_grad=True)
1228        a = torch.tensor(1.0, requires_grad=True)
1229        y = x * a
1230        dydx = torch.autograd.grad(y, x, create_graph=True)
1231        d2ydx2_none = torch.autograd.grad(dydx, x, create_graph=True, allow_unused=True)
1232        d2ydx2 = torch.autograd.grad(
1233            dydx, x, create_graph=True, allow_unused=True, materialize_grads=True
1234        )
1235        # `allow_unused` set to True implicitly
1236        d3ydx3 = torch.autograd.grad(d2ydx2, x, materialize_grads=True)
1237        self.assertIsNone(d2ydx2_none[0])
1238        self.assertEqual(d2ydx2[0].item(), 0)
1239        self.assertEqual(d3ydx3[0].item(), 0)
1240        with self.assertRaisesRegex(
1241            ValueError, "Expected allow_unused to be True or not passed when"
1242        ):
1243            torch.autograd.grad(y, x, allow_unused=False, materialize_grads=True)
1244
1245    def test_post_accumulate_grad_hook_on_non_leaf(self):
1246        def hook(tensor):
1247            tensor.sub_(1.0)
1248
1249        leaf = torch.rand(3, requires_grad=True)
1250        non_leaf = 2.0 * leaf
1251
1252        with self.assertRaisesRegex(
1253            RuntimeError,
1254            "post accumulate grad hooks cannot be registered on non-leaf tensors",
1255        ):
1256            non_leaf.register_post_accumulate_grad_hook(hook)
1257
1258    def test_post_accumulate_grad_hook_multiple_hooks(self):
1259        def hook1(tensor):
1260            tensor.sub_(tensor.grad)
1261
1262        def hook2(tensor):
1263            tensor.mul_(4.0)
1264
1265        tensor = torch.rand(3, requires_grad=True)
1266        tensor_ref = tensor.clone().detach()
1267        tensor.register_post_accumulate_grad_hook(hook1)
1268        tensor.register_post_accumulate_grad_hook(hook2)
1269        sum = tensor.sum()
1270        sum.backward()
1271        # both hooks should be called, in order
1272        self.assertEqual(4.0 * (tensor_ref - 1.0), tensor)
1273
1274    def test_post_accumulate_grad_hook_multiple_tensors(self):
1275        def hook(tensor):
1276            tensor.sub_(tensor.grad)
1277
1278        tensor1 = torch.rand(3, requires_grad=True)
1279        tensor1_ref = tensor1.clone().detach()
1280        tensor2 = torch.rand(5, requires_grad=True)
1281        tensor2_ref = tensor2.clone().detach()
1282        tensor1.register_post_accumulate_grad_hook(hook)
1283        tensor2.register_post_accumulate_grad_hook(hook)
1284        tensor1.sum().backward()
1285        tensor2.sum().backward()
1286        # both tensors should have been modified
1287        self.assertEqual(tensor1_ref - 1.0, tensor1)
1288        self.assertEqual(tensor2_ref - 1.0, tensor2)
1289
1290    def test_post_accumulate_grad_hook_returns_not_None(self):
1291        def bad_hook(tensor):
1292            return tensor.grad
1293
1294        tensor = torch.rand(2, 3, requires_grad=True)
1295        tensor.register_post_accumulate_grad_hook(bad_hook)
1296        # should error!
1297        with self.assertRaisesRegex(RuntimeError, "hooks should return None."):
1298            tensor.sum().backward()
1299
1300    def test_post_accumulate_grad_hook_e2e(self):
1301        def setup_optim_in_bwd(model):
1302            optims = {}
1303            handles = []
1304
1305            def optim_step_hook(param):
1306                optims[param].step()
1307                optims[param].zero_grad()
1308
1309            for p in model.parameters():
1310                optims[p] = torch.optim.Adam([p])
1311                handles.append(p.register_post_accumulate_grad_hook(optim_step_hook))
1312
1313            return handles
1314
1315        model = torch.nn.Linear(3, 2)
1316        input = torch.rand(2, 3)
1317        handles = setup_optim_in_bwd(model)
1318
1319        # make a copy for reference
1320        model_copy = deepcopy(model)
1321        optim_copy = torch.optim.Adam(model_copy.parameters())
1322
1323        iters = 5
1324
1325        for _ in range(iters):
1326            loss = model(input).sum()
1327            loss.backward()
1328
1329            loss_copy = model_copy(input).sum()
1330            loss_copy.backward()
1331            optim_copy.step()
1332            optim_copy.zero_grad()
1333
1334        params_copy = []  # freeze a copy of the params to compare later
1335        for p_reference, p in zip(model_copy.parameters(), model.parameters()):
1336            self.assertEqual(p_reference, p)
1337            params_copy.append(p_reference.clone().detach())
1338
1339        # After removing the handle, the model should no longer update.
1340        for h in handles:
1341            h.remove()
1342
1343        for _ in range(iters):
1344            loss = model(input).sum()
1345            loss.backward()
1346
1347            loss_copy = model_copy(input).sum()
1348            loss_copy.backward()
1349            optim_copy.step()
1350            optim_copy.zero_grad()
1351
1352        for p_static, p_reference, p in zip(
1353            params_copy, model_copy.parameters(), model.parameters()
1354        ):
1355            self.assertEqual(p_static, p)
1356            self.assertNotEqual(p_reference, p)
1357
1358    def test_post_accumulate_grad_hook_gets_cleaned_up(self):
1359        def fun_stuff_with_hook():
1360            thing_to_put_in_hook = torch.rand(3)
1361
1362            def hook(tensor):
1363                tensor.sub_(tensor.grad)
1364                tensor.add_(thing_to_put_in_hook)
1365
1366            tensor = torch.rand(3, requires_grad=True)
1367            tensor.register_post_accumulate_grad_hook(hook)
1368            tensor.sum().backward()
1369            ref = weakref.ref(thing_to_put_in_hook)
1370            gc.collect()
1371            return tensor, ref
1372
1373        with disable_gc():
1374            tensor, ref = fun_stuff_with_hook()
1375            self.assertIsNotNone(
1376                ref()
1377            )  # thing_to_put_in_hook should be kept alive by tensor
1378
1379            del tensor
1380            gc.collect()
1381            self.assertIsNone(ref())  # thing_to_put_in_hook should be cleaned
1382
1383    def test_post_accumulate_grad_hook_ordering(self):
1384        tensor = torch.rand(3, requires_grad=True)
1385
1386        def pre_hook(grad):
1387            return grad.sub(2.0)
1388
1389        def acc_grad_node_pre_hook(grad_out):
1390            return (grad_out[0].div(5.0),)
1391
1392        def post_acc_grad_hook(tensor):
1393            tensor.grad.add_(0.5)
1394
1395        def acc_grad_node_post_hook(grad_in, grad_out):
1396            tensor.grad = grad_out[0].mul(10)
1397
1398        acc_grad = tensor.view_as(tensor).grad_fn.next_functions[0][0]
1399        tensor.register_hook(pre_hook)
1400        acc_grad.register_prehook(acc_grad_node_pre_hook)
1401        tensor.register_post_accumulate_grad_hook(post_acc_grad_hook)
1402        acc_grad.register_hook(acc_grad_node_post_hook)
1403        tensor.sum().backward()
1404
1405        # the hooks should run in the order of:
1406        #   1. tensor prehook
1407        #   2. acc_grad prehook
1408        #   3. tensor post acc_grad hook
1409        #   4. acc_grad posthook
1410        # so that would be ((1 - 2) / 5 + 0.5) * 10 = 3
1411        self.assertEqual(torch.tensor([3.0, 3.0, 3.0]), tensor.grad)
1412
1413    def test_hook_with_no_name(self):
1414        # Create a hook that do not have a __name__ attribute
1415        class MyHookClass:
1416            def __call__(self, grad):
1417                return grad.clone()
1418
1419        x = torch.randn(5, requires_grad=True).clone()
1420        x.register_hook(MyHookClass())
1421        x.sum().backward()
1422        # Should run fine
1423
1424    def test_prehook_ordering(self):
1425        # Hooks registered to tensor are ordered before those
1426        # that are registered to grad_fn
1427        log = []
1428
1429        def hook1(g):
1430            log.append(1)
1431            return g * 3
1432
1433        def hook2(gs):
1434            log.append(2)
1435            return tuple(g * 2 for g in gs)
1436
1437        a = torch.tensor(1.0, requires_grad=True)
1438        b = a.clone()
1439
1440        b.grad_fn.register_prehook(hook2)
1441        b.register_hook(hook1)
1442        b.grad_fn.register_prehook(hook2)
1443
1444        acc = b.grad_fn.next_functions[0][0]
1445        a.register_hook(hook1)
1446        acc.register_prehook(hook2)
1447        a.register_hook(hook1)
1448
1449        b.sum().backward(retain_graph=True)
1450        self.assertEqual(log, [1, 2, 2, 1, 1, 2])
1451
1452        # grad also runs hooks on accumulate grad nodes, even though
1453        # the accumulate grad nodes are not actually executed
1454        log = []
1455        torch.autograd.grad(b.sum(), inputs=(a,), retain_graph=True)
1456        self.assertEqual(log, [1, 2, 2, 1, 1])
1457
1458        log = []
1459        b.sum().backward(inputs=(b,))
1460        self.assertEqual(log, [1, 2, 2])
1461        # retains_grad hooks would not observe modifications by all pre hooks
1462        # because they are executed after
1463        self.assertEqual(b.grad.item(), 3)
1464
1465    def test_retains_grad_can_always_observe_tensor_prehook(self):
1466        def tensor_prehook(g):
1467            return g * 2
1468
1469        a = torch.tensor(1.0, requires_grad=True)
1470        b = a.clone()
1471        b.register_hook(tensor_prehook)
1472        b.retain_grad()
1473        b.register_hook(tensor_prehook)
1474
1475        b.clone().backward()
1476        self.assertEqual(b.grad.item(), 4)
1477
1478        a = torch.tensor(1.0, requires_grad=True)
1479        b = a.clone()
1480        b.retain_grad()
1481        b.register_hook(tensor_prehook)
1482
1483        b.clone().backward()
1484        self.assertEqual(b.grad.item(), 2)
1485
1486    def test_accumulate_grad_posthooks_can_observe_tensor_prehook(self):
1487        # Post hooks on accumulate should be able to observe changes to
1488        # grad made by tensor prehooks
1489        a = torch.tensor(1.0, requires_grad=True)
1490
1491        def tensor_prehook(g):
1492            return g * 2
1493
1494        def posthook(gO, gI):
1495            self.assertTrue(torch.allclose(gI[0], a * 2))
1496            self.assertEqual(len(gO), 0)
1497
1498        def prehook(gI):
1499            self.assertTrue(torch.allclose(gI[0], a * 2))
1500            self.assertEqual(len(gI), 1)
1501
1502        b = a.clone()
1503        acc = b.grad_fn.next_functions[0][0]
1504        acc.register_hook(posthook)
1505        acc.register_prehook(prehook)
1506        a.register_hook(tensor_prehook)
1507
1508        b.backward()
1509
1510    def test_accumulate_grad_posthooks_should_not_execute(self):
1511        def tensor_prehook(g):
1512            raise RuntimeError
1513
1514        def posthook(gO, gI):
1515            raise RuntimeError
1516
1517        a = torch.tensor(1.0, requires_grad=True)
1518        a.register_hook(tensor_prehook)
1519        b = torch.tensor(1.0, requires_grad=True)
1520        c = a.clone()
1521        acc = c.grad_fn.next_functions[0][0]
1522        acc.register_hook(posthook)
1523
1524        out = a + b + c
1525        out.sum().backward(inputs=[b])
1526
1527    def test_hook_edge_case_when_called_with_grad(self):
1528        # grad executes the tensor hooks of the next node but not
1529        # grad_fn pre hooks or the post hooks
1530        a = torch.tensor(1.0, requires_grad=True)
1531        b = a * 2
1532        c = b * 2
1533
1534        tensor_hook_count = [0]
1535        prehook_count = [0]
1536        posthook_count = [0]
1537
1538        def reset_counts():
1539            nonlocal tensor_hook_count, prehook_count, posthook_count
1540            tensor_hook_count = [0]
1541            prehook_count = [0]
1542            posthook_count = [0]
1543
1544        def tensor_prehook(g):
1545            tensor_hook_count[0] += 1
1546
1547        def prehook(g):
1548            prehook_count[0] += 1
1549
1550        def posthook(gI, gO):
1551            posthook_count[0] += 1
1552
1553        a.register_hook(tensor_prehook)
1554        b.register_hook(tensor_prehook)
1555        acc = b.grad_fn.next_functions[0][0]
1556        acc.register_hook(posthook)
1557        acc.register_prehook(prehook)
1558        b.grad_fn.register_hook(posthook)
1559        b.grad_fn.register_prehook(prehook)
1560
1561        torch.autograd.grad(c, inputs=(b), retain_graph=True)
1562        self.assertEqual(tensor_hook_count[0], 1)
1563        self.assertEqual(posthook_count[0], 0)
1564        self.assertEqual(prehook_count[0], 0)
1565        reset_counts()
1566
1567        torch.autograd.grad(c, inputs=(a, b), retain_graph=True)
1568        self.assertEqual(tensor_hook_count[0], 2)
1569        self.assertEqual(posthook_count[0], 1)
1570        self.assertEqual(prehook_count[0], 1)
1571        reset_counts()
1572
1573        c.backward(retain_graph=True)
1574        self.assertEqual(tensor_hook_count[0], 2)
1575        self.assertEqual(posthook_count[0], 2)
1576        self.assertEqual(prehook_count[0], 2)
1577        reset_counts()
1578
1579        c.backward(inputs=(a, b), retain_graph=True)
1580        self.assertEqual(tensor_hook_count[0], 2)
1581        self.assertEqual(posthook_count[0], 2)
1582        self.assertEqual(prehook_count[0], 2)
1583
1584    def test_sharded_grad(self):
1585        leaves = [torch.zeros(5, 5, requires_grad=True) for _ in range(10)]
1586        intermediates = [l * i + l * l for i, l in enumerate(leaves)]
1587        loss = sum(v * i for i, v in enumerate(intermediates)).sum()
1588
1589        # define a helper for dividing intermediates into groups
1590        def group(l, group_size):
1591            return (l[i : i + group_size] for i in range(0, len(l), group_size))
1592
1593        # Compute the d loss / d intermediates in chunks of shard_size
1594        shard_size = 2
1595        d_intermediates = [
1596            d_i
1597            for intermediates_batch in group(intermediates, shard_size)
1598            for d_i in torch.autograd.grad(loss, intermediates_batch)
1599        ]
1600        # Compute rest of backward pass
1601        torch.autograd.backward(intermediates, d_intermediates)
1602
1603        for i, l in enumerate(leaves):
1604            self.assertEqual(l.grad, i * i * (1 + l))
1605
1606    def test_backward_badcalls(self):
1607        x = torch.ones(1)
1608        with self.assertRaisesRegex(RuntimeError, "does not require grad"):
1609            x.backward()
1610
1611    def test_grad_badcalls(self):
1612        x = torch.ones(1)
1613        y = x**2
1614        with self.assertRaisesRegex(RuntimeError, "does not require grad"):
1615            torch.autograd.grad(x, y)
1616        with self.assertRaisesRegex(RuntimeError, "does not require grad"):
1617            torch.autograd.grad(y, x)
1618
1619        x = torch.ones(1, requires_grad=True)
1620        y = x**2
1621        torch.autograd.grad(y, x)  # this should succeed now
1622
1623    def test_grad_empty_inputs(self):
1624        x = torch.tensor([1.0], requires_grad=True)
1625        with self.assertRaisesRegex(ValueError, "grad requires non-empty inputs."):
1626            torch.autograd.grad(2 * x, [], grad_outputs=torch.tensor([1.0]))
1627
1628    def test_grad_fn_badcalls(self):
1629        error_regex = "expected .* arguments, got .* instead"
1630        x = torch.ones(1, requires_grad=True)
1631        y = x**2
1632        with self.assertRaisesRegex(TypeError, error_regex):
1633            y.grad_fn(x.detach(), x.detach())  # too many
1634        with self.assertRaisesRegex(TypeError, error_regex):
1635            y.grad_fn()  # too few
1636
1637        y.grad_fn(x.detach())  # this should succeed
1638
1639    def test_grad_unreachable(self):
1640        x = torch.ones(1, requires_grad=True)
1641        y = torch.ones(1, requires_grad=True)
1642        # Make sure x and y have grad accumulators allocated
1643        z = x * 2
1644        w = y * 2
1645
1646        grad_x, grad_y = torch.autograd.grad(x * 2, [x, y], allow_unused=True)
1647        self.assertEqual(grad_x, x * 2)
1648        self.assertIsNone(grad_y)
1649
1650        # This is slightly different than the case above, because z doesn't even
1651        # have a grad accumulator allocated.
1652        z = torch.ones(1, requires_grad=True)
1653        grad_x, grad_z = torch.autograd.grad(x * 2, [x, z], allow_unused=True)
1654        self.assertEqual(grad_x, x * 2)
1655        self.assertIsNone(grad_z)
1656
1657        # allow_unused=False, but grads contains None inside, should throw
1658        with self.assertRaisesRegex(RuntimeError, "Set allow_unused=True"):
1659            grad_x, grad_y = torch.autograd.grad(x * 2, [x, y], allow_unused=False)
1660
1661    def test_grad_unreachable_discovery(self):
1662        # Test that certain nodes are not erroneously executed when an input
1663        # is unreachable. See #39784
1664        class MyFunc(torch.autograd.Function):
1665            @staticmethod
1666            def forward(ctx, x):
1667                return x
1668
1669            @staticmethod
1670            def backward(ctx, x):
1671                self.fail("This node should not be executed!")
1672
1673        x = MyFunc.apply(torch.randn(1, requires_grad=True) * 2)
1674        y = torch.randn(1, requires_grad=True)
1675        (gY,) = torch.autograd.grad(x, (y,), allow_unused=True)
1676        self.assertIsNone(gY)
1677
1678        x = MyFunc.apply(torch.randn(1, requires_grad=True) * 2)
1679        y = torch.randn(1, requires_grad=True)
1680        z = torch.randn(1, requires_grad=True)
1681        (gY, gZ) = torch.autograd.grad(x + z, (y, z), allow_unused=True)
1682        self.assertIsNone(gY)
1683        self.assertIsNotNone(gZ)
1684
1685        x = MyFunc.apply(torch.randn(1, requires_grad=True) * 2)
1686        y = torch.randn(1, requires_grad=True)
1687        torch.autograd.backward(x, inputs=(y,))  # allow_unused is implicitly True!
1688        self.assertIsNone(y.grad)
1689
1690    def test_grad_batched_grad(self):
1691        x = torch.randn(2, 2, requires_grad=True)
1692
1693        out = x.clone()  # Size([2, 2])
1694        batched_grad = (
1695            torch.arange(3).expand(2, 2, 3).transpose(0, 2)
1696        )  # Size([3, 2, 2])
1697        (grad,) = torch.autograd.grad(out, (x,), (batched_grad,), is_grads_batched=True)
1698        self.assertEqual(
1699            grad, torch.arange(3).expand(2, 2, 3).transpose(0, 2).to(dtype=grad.dtype)
1700        )
1701
1702        # Detect shape mismatch
1703        grad_out = torch.ones(2, 2)
1704        with self.assertRaisesRegex(
1705            RuntimeError, "If `is_grads_batched=True`, we interpret the first"
1706        ):
1707            torch.autograd.grad(
1708                outputs=out,
1709                grad_outputs=(grad_out,),
1710                inputs=(x,),
1711                is_grads_batched=True,
1712            )
1713
1714        # Scalar outputs
1715        out = x.sum()  # Size([])
1716        batched_grad = torch.arange(3)  # Size([3])
1717        (grad,) = torch.autograd.grad(out, (x,), (batched_grad,), is_grads_batched=True)
1718        self.assertEqual(
1719            grad, torch.arange(3).expand(2, 2, 3).transpose(0, 2).to(dtype=grad.dtype)
1720        )
1721
1722        # We consider scalar and sized-1 to be a mismatch. This is consistent with current non-batched behavior.
1723        grad_out = torch.ones(2).unsqueeze(1)
1724        with self.assertRaisesRegex(
1725            RuntimeError, "If `is_grads_batched=True`, we interpret the first"
1726        ):
1727            torch.autograd.grad(
1728                outputs=out,
1729                grad_outputs=(grad_out,),
1730                inputs=(x,),
1731                is_grads_batched=True,
1732            )
1733
1734    def test_hooks(self):
1735        x = torch.ones(5, 5, requires_grad=True)
1736        y = torch.ones(5, 5) * 4
1737        y.requires_grad_(True)
1738
1739        counter = [0]
1740
1741        def bw_hook(inc, grad):
1742            self.assertIsInstance(grad, torch.Tensor)
1743            counter[0] += inc
1744
1745        z = x**2 + x * 2 + x * y + y
1746        x.register_hook(lambda *args: bw_hook(0, *args))
1747        test = z.register_hook(lambda *args: bw_hook(1, *args))
1748        z.backward(torch.ones(5, 5), retain_graph=True)
1749        self.assertEqual(counter[0], 1)
1750
1751        test2 = z.register_hook(lambda *args: bw_hook(2, *args))
1752        z.backward(torch.ones(5, 5), retain_graph=True)
1753        self.assertEqual(counter[0], 4)
1754
1755        test2.remove()
1756        z.backward(torch.ones(5, 5), retain_graph=True)
1757        self.assertEqual(counter[0], 5)
1758
1759        def bw_hook_modify(grad):
1760            return grad.mul(2)
1761
1762        test.remove()
1763        z.register_hook(bw_hook_modify)
1764        with torch.no_grad():
1765            y.grad.zero_()
1766        z.backward(torch.ones(5, 5), retain_graph=True)
1767        self.assertEqual(y.grad, (x + 1) * 2)
1768
1769        y.register_hook(bw_hook_modify)
1770        with torch.no_grad():
1771            y.grad.zero_()
1772        z.backward(torch.ones(5, 5))
1773        self.assertEqual(y.grad, (x + 1) * 4)
1774
1775    def _get_mul2(self, use_custom_function):
1776        if use_custom_function:
1777
1778            class Mul2(Function):
1779                @staticmethod
1780                def forward(ctx, x):
1781                    return x * 2
1782
1783                @staticmethod
1784                def backward(ctx, gO):
1785                    return gO * 2
1786
1787            return Mul2.apply
1788        else:
1789            return lambda x: x * 2
1790
1791    def test_grad_fn_prehooks(self):
1792        for use_custom_function in (True, False):
1793            mul2 = self._get_mul2(use_custom_function)
1794
1795            a = torch.tensor([1.0], requires_grad=True)
1796            b = mul2(a)
1797
1798            post_counter = [0]
1799            pre_counter = [0]
1800
1801            def posthook(grad_input, grad_output):
1802                self.assertEqual(pre_counter[0], 3)
1803                self.assertTrue(torch.allclose(grad_output[0], torch.ones(1) * 8))
1804                self.assertTrue(torch.allclose(grad_input[0], torch.ones(1) * 16))
1805                post_counter[0] += 1
1806                return grad_input
1807
1808            def prehook(grad_output):
1809                pre_counter[0] += 1
1810                return (grad_output[0] * 2,)
1811
1812            # register posthook x 2
1813            b.grad_fn.register_hook(posthook)
1814            b.grad_fn.register_hook(posthook)
1815            # register prehook x 3
1816            b.grad_fn.register_prehook(prehook)
1817            b.grad_fn.register_prehook(lambda x: None)
1818            b.grad_fn.register_prehook(prehook)
1819            b.grad_fn.register_prehook(prehook)
1820            b.grad_fn.register_prehook(lambda x: x)
1821            b.grad_fn.register_prehook(lambda x: None)
1822
1823            b.sum().backward()
1824
1825            self.assertEqual(post_counter[0], 2)
1826            self.assertEqual(pre_counter[0], 3)
1827
1828            # Return None
1829            a = torch.rand(3, 3, requires_grad=True)
1830            b = mul2(a)
1831
1832            def prehook(grad_output):
1833                pre_counter[0] += 1
1834                return None
1835
1836            b.grad_fn.register_prehook(prehook)
1837            b.sum().backward()
1838            self.assertEqual(pre_counter[0], 4)
1839            self.assertTrue(torch.allclose(a.grad, torch.ones(3, 3) * 2))
1840
1841    def test_grad_fn_prehooks_multiple_outputs(self):
1842        # Compute gradients without hooks
1843        b = torch.rand(3, 3, requires_grad=True)
1844        var, mean = torch.var_mean(b, dim=0)
1845        (var + mean).sum().backward()
1846
1847        # Compute gradients with hooks
1848        a = b.detach().requires_grad_()
1849        counter = [0]
1850
1851        def prehook(grad_output):
1852            gvar, gmean = grad_output
1853            counter[0] += 1
1854            return (gvar * 2, gmean * 2)
1855
1856        var, mean = torch.var_mean(a, dim=0)
1857        mean.grad_fn.register_prehook(prehook)
1858        (var + mean).sum().backward()
1859
1860        self.assertEqual(counter[0], 1)
1861        # Compare
1862        self.assertTrue(torch.allclose(a.grad, b.grad * 2))
1863
1864        # Test with custom Function
1865        class DoubleMul2(Function):
1866            @staticmethod
1867            def forward(ctx, x, a, y):
1868                ctx.a = a
1869                return a * x * 2, a, a * y * 2
1870
1871            @staticmethod
1872            def backward(ctx, g1, _a, g2):
1873                return ctx.a * g1 * 2, None, ctx.a * g2 * 2
1874
1875        counter = [0]
1876
1877        def prehook(grad_output):
1878            g1, ga, g2 = grad_output
1879            self.assertIsNone(ga)
1880            counter[0] += 1
1881            return (g1 * 2, None, g2 * 2)
1882
1883        a = torch.randn(3, 3, requires_grad=True)
1884        b = torch.randn(3, 3, requires_grad=True)
1885        k = 3
1886        c, _, d = DoubleMul2.apply(a, k, b)
1887        c.grad_fn.register_prehook(prehook)
1888        (c + d).sum().backward()
1889
1890        self.assertEqual(counter[0], 1)
1891        self.assertTrue(torch.allclose(a.grad, torch.ones(1) * 4 * k))
1892        self.assertTrue(torch.allclose(b.grad, torch.ones(1) * 4 * k))
1893
1894    def test_grad_fn_prehooks_remove_hooks(self):
1895        for use_custom_function in (True, False):
1896            mul2 = self._get_mul2(use_custom_function)
1897
1898            # Simply remove hooks
1899
1900            a = torch.rand(3, 3, requires_grad=True)
1901            b = mul2(a)
1902            counter = [0]
1903
1904            def prehook(grad_output):
1905                counter[0] += 1
1906                return None
1907
1908            handle = b.grad_fn.register_prehook(prehook)
1909            b.grad_fn.register_prehook(prehook)
1910            handle.remove()
1911            b.sum().backward()
1912            self.assertTrue(torch.allclose(a.grad, torch.ones(3, 3) * 2))
1913            self.assertEqual(counter[0], 1)
1914
1915            # Remove hooks during backward
1916            a = torch.rand(3, 3, requires_grad=True)
1917            b = mul2(a)
1918            counter = [0]
1919
1920            def prehook1(grad_output):
1921                handle2.remove()
1922                # Remove hook that is already removed is OK
1923                handle3.remove()
1924                return None
1925
1926            def prehook2(grad_output):
1927                counter[0] += 1
1928                return None
1929
1930            # Hooks that registered first run first
1931            b.grad_fn.register_prehook(prehook1)
1932            handle2 = b.grad_fn.register_prehook(prehook2)
1933            handle3 = b.grad_fn.register_prehook(prehook2)
1934            handle3.remove()
1935            b.sum().backward()
1936            self.assertTrue(torch.allclose(a.grad, torch.ones(3, 3) * 2))
1937            self.assertEqual(counter[0], 1)
1938
1939    def test_node_post_hook_registered_during_unpack_hook(self):
1940        """
1941        Test that post hooks registered during one of the node's
1942        unpack hooks are properly restricted and will run properly.
1943        """
1944        test_case = self
1945
1946        class RegisterPostNodeHook(torch.autograd.graph.saved_tensors_hooks):
1947            def __init__(self) -> None:
1948                def pack_tensor(tensor: torch.Tensor) -> torch.Tensor:
1949                    return tensor
1950
1951                def unpack_tensor(tensor: torch.Tensor) -> torch.Tensor:
1952                    node = torch._C._current_autograd_node()
1953
1954                    def hook(outputs, inputs):
1955                        # Assert that inputs passed in are None
1956                        test_case.assertTrue(all(i is None for i in inputs))
1957                        halved_outputs = tuple(
1958                            o / 2.0 if o is not None else None for o in outputs
1959                        )
1960                        return halved_outputs
1961
1962                    node.register_hook(hook)
1963                    return tensor
1964
1965                super().__init__(pack_tensor, unpack_tensor)
1966
1967        a = torch.rand(3, 3, requires_grad=True)
1968
1969        def model():
1970            var, mean = torch.var_mean(a, dim=0)
1971            loss = (var + mean).sum()
1972            loss.backward()
1973
1974        model()
1975        ref_grad = a.grad.clone()
1976
1977        with RegisterPostNodeHook():
1978            model()
1979
1980        # Verify that the post hook got called and the grad propagation worked
1981        self.assertEqual(ref_grad / 2.0 + ref_grad, a.grad)
1982
1983    def test_hooks_cpp(self):
1984        # Tests hooks for autograd function implemented in C++
1985        bn = torch.nn.BatchNorm1d(5, affine=False)
1986        bn.double()
1987        bn.eval()
1988
1989        counter = [0]
1990
1991        def bw_hook(grad):
1992            counter[0] += 1
1993            return grad * 2
1994
1995        x = torch.ones(5, 5, dtype=torch.double, requires_grad=True)
1996        z = bn(x)
1997        z.register_hook(bw_hook)
1998        z.sum().backward()
1999
2000        self.assertEqual(counter[0], 1, msg="bw_hook not called")
2001        self.assertEqual(
2002            x.grad, torch.ones(5, 5, dtype=torch.double) * 2, atol=1e-5, rtol=0
2003        )
2004
2005    def test_hook_none(self):
2006        # WARNING: this is a test for autograd internals.
2007        # You should never have to use such things in your code.
2008        class NoneGradientFunction(Function):
2009            @staticmethod
2010            def forward(ctx, x, y):
2011                assert ctx.needs_input_grad[0]
2012                assert not ctx.needs_input_grad[1]
2013                return x, y
2014
2015            @staticmethod
2016            def backward(ctx, grad_x, grad_y):
2017                return grad_x, None
2018
2019        was_called = [False]
2020
2021        def hook(grad):
2022            self.assertIsNotNone(grad)
2023            was_called[0] = True
2024
2025        x = torch.randn(5, 5, requires_grad=True)
2026        y = torch.randn(5, 5)
2027        rx, ry = NoneGradientFunction.apply(x, y)
2028        rx.register_hook(hook)
2029        ry.register_hook(hook)
2030        sum(rx, ry).sum().backward()
2031        self.assertTrue(was_called[0])
2032
2033    def test_retain_grad(self):
2034        input = torch.rand(1, 3, requires_grad=True)
2035        h1 = input * 3
2036        out = (h1 * h1).sum()
2037
2038        # It should be possible to call retain_grad() multiple times
2039        h1.retain_grad()
2040        h1.retain_grad()
2041
2042        # Gradient should be accumulated
2043        out.backward(retain_graph=True)
2044        self.assertEqual(h1 * 2, h1.grad)
2045        out.backward(retain_graph=True)
2046        self.assertEqual(h1 * 4, h1.grad)
2047
2048        with torch.no_grad():
2049            input.grad.zero_()
2050        # It should be a no-op for leaves
2051        input.retain_grad()
2052        input.retain_grad()
2053        out.backward()
2054        self.assertEqual(input * 18, input.grad)
2055
2056    # NB: See test/cpp/api/autograd.cpp for more tests on the interaction between
2057    #     retains_grad and hooks in cpp
2058    def test_retain_grad_inplace(self):
2059        a = torch.tensor([1.0], requires_grad=True).clone()
2060        a.retain_grad()
2061        a.mul_(2)
2062        a.sum().backward()
2063        self.assertEqual(a.grad, torch.tensor([1.0]))
2064
2065        a = torch.tensor([1.0], requires_grad=True).clone()
2066        a.retain_grad()
2067        # Inplace multiple times is OK
2068        a.mul_(2)
2069        a.mul_(2)
2070        a.sum().backward()
2071        self.assertEqual(a.grad, torch.tensor([1.0]))
2072
2073        # When in-place over view is done, the retains_grad hooks should be
2074        # moved from base's original grad_fn to the copyslices node.
2075        x = torch.tensor([1.0], requires_grad=True).clone()
2076        x.retain_grad()
2077        x_view = x[:]
2078        x_view *= 2
2079        x *= 2
2080        x.sum().backward()
2081        # The grad is 1, not 4, because we are computing grad wrt the latest
2082        # version of x.
2083        self.assertEqual(a.grad, torch.tensor([1.0]))
2084
2085        # If the base did not originally require grad, there should be no hook
2086        # to move. Make sure this case runs without error.
2087        x = torch.zeros(4)
2088        y = x.view(2, 2)
2089        y.add_(torch.randn(2, 2, requires_grad=True))
2090
2091    def test_retains_grad_inplace_multiple_outputs(self):
2092        class DoubleMul(Function):
2093            @staticmethod
2094            def forward(ctx, x):
2095                return x * 2, x * 3
2096
2097            @staticmethod
2098            def backward(ctx, g1, g2):
2099                return g1 * 2 + g2 * 3
2100
2101        var_mean = partial(torch.var_mean, dim=0)
2102
2103        for fn in (DoubleMul.apply, var_mean):
2104            b = torch.rand(3, 3, requires_grad=True)
2105            var, mean = fn(b)
2106            var.retain_grad()
2107            mean.retain_grad()
2108            # node has two retains_grad hooks
2109            var.mul_(2)
2110            # the retain_grad hook multi-output node refers should now be a nullptr
2111            (var + mean).sum().backward()
2112            gvar = var.grad
2113            gmean = mean.grad
2114
2115            a = b.detach().requires_grad_(True)
2116            var, mean = fn(a)
2117            var.mul_(2)
2118            out = (var + mean).sum()
2119            gvar_expected, gmean_expected = torch.autograd.grad(out, inputs=(var, mean))
2120            self.assertTrue(torch.allclose(gvar, gvar_expected))
2121            self.assertTrue(torch.allclose(gmean, gmean_expected))
2122
2123    def test_retain_grad_inplace_over_view(self):
2124        base = torch.tensor([1.0], requires_grad=True).clone()
2125        view = base[:]
2126        view2 = base[:]
2127        view.retain_grad()
2128        view2.retain_grad()
2129        view.mul_(2)
2130        (view + view2).sum().backward()
2131
2132        # The old grad_fn, slice, wouldn't be part of the graph during backward
2133        # so if the retains grad were not properly updated to the new grad_fn,
2134        # the grad would still be None
2135        self.assertEqual(view.grad, view2.grad)
2136        self.assertEqual(view.grad, torch.tensor([1.0]))
2137
2138    def test_tensor_hooks_inplace(self):
2139        # Check that the second hook gets registered to the new version of tensor
2140        count1 = [0]
2141        count2 = [0]
2142
2143        def fn1(grad):
2144            count1[0] += 1
2145            # x2 from mul, x2 from fn2
2146            self.assertEqual(grad, torch.tensor([4.0]))
2147            return grad * 2
2148
2149        def fn2(grad):
2150            count2[0] += 1
2151            self.assertEqual(grad, torch.tensor([1.0]))
2152            return grad * 2
2153
2154        a = torch.tensor([1.0], requires_grad=True)
2155        b = a.clone()
2156        b.register_hook(fn1)
2157        b.mul_(2)
2158        b.register_hook(fn2)
2159        b.sum().backward()
2160        self.assertEqual(count1[0], 1)
2161        self.assertEqual(count2[0], 1)
2162        self.assertEqual(a.grad, torch.tensor([8.0]))
2163
2164        count3 = [0]
2165
2166        def fn3(grad):
2167            count3[0] += 1
2168            self.assertEqual(grad, torch.tensor([4.0]))
2169            return grad * 2
2170
2171        a = torch.tensor([1.0], requires_grad=True)
2172        b = a.clone()
2173        b.register_hook(fn3)
2174        # Inplace multiple times is OK
2175        b.mul_(2)
2176        b.mul_(2)
2177        b.sum().backward()
2178        self.assertEqual(count1[0], 1)
2179        self.assertEqual(a.grad, torch.tensor([8.0]))
2180
2181    def test_tensor_hooks_inplace_multiple_outputs(self):
2182        class DoubleMul(Function):
2183            @staticmethod
2184            def forward(ctx, x):
2185                return x * 2, x * 3
2186
2187            @staticmethod
2188            def backward(ctx, g1, g2):
2189                return g1 * 2 + g2 * 3
2190
2191        var_mean = partial(torch.var_mean, dim=0)
2192
2193        for fn in (DoubleMul.apply, var_mean):
2194            counts = [0, 0, 0]
2195
2196            def fn0(grad):
2197                counts[0] += 1
2198                self.assertEqual(grad, torch.ones_like(out1) * 2)
2199
2200            def fn1(grad):
2201                counts[1] += 1
2202                self.assertEqual(grad, torch.ones_like(out1) * 3)
2203
2204            def fn2(grad):
2205                counts[2] += 1
2206                self.assertEqual(grad, torch.ones_like(out1))
2207
2208            b = torch.rand(3, 3, requires_grad=True)
2209            out1, out2 = fn(b)
2210            out1.register_hook(fn0)
2211            out2.register_hook(fn1)
2212            # node refers to two hook dicts
2213            # out1 no longer no longer points to its old hook dict
2214            out1.mul_(2)
2215            # fn2 is registered to out1's new hook dict
2216            out1.register_hook(fn2)
2217            (out1 + out2 * 3).sum().backward()
2218            self.assertEqual(counts, [1, 1, 1])
2219
2220    def test_tensor_hooks_inplace_over_view(self):
2221        # There might be a better UX here, but this is the way it is now
2222        count = [0]
2223
2224        def fn0(grad):
2225            self.fail()
2226
2227        def fn1(grad):
2228            self.fail()
2229
2230        def fn2(grad):
2231            count[0] += 1
2232            self.assertEqual(grad, torch.tensor([1.0]))
2233
2234        base = torch.tensor([1.0], requires_grad=True).clone()
2235        view = base[:]
2236        view2 = base[:]
2237        view.register_hook(fn0)
2238        view2.register_hook(fn1)
2239        view.mul_(2)
2240        # We need to explicitly trigger an update to view to update its grad_fn
2241        view2.grad_fn
2242        view2.register_hook(fn2)
2243        (view + view2).sum().backward()
2244        # The hooks originally registered to view are not fired, one must explicitly
2245        # trigger an update to the view's grad_fn, and then register a new hook
2246        self.assertEqual(count[0], 1)
2247
2248    def test_retain_grad_cycle(self):
2249        x = torch.ones(5, 5, requires_grad=True)
2250
2251        def run_test():
2252            y = x * 2
2253            y.retain_grad()
2254
2255            return y / 2, torch._C._WeakTensorRef(y)
2256
2257        z, ref = run_test()
2258        self.assertTrue(ref.expired())
2259        z.sum().backward()
2260
2261    def test_backward(self):
2262        v = torch.randn(5, 5, requires_grad=True)
2263        x = torch.randn(5, 5, requires_grad=True)
2264        y = (torch.rand(5, 5) + 0.1).requires_grad_(True)
2265        z = torch.randn(5, 5, requires_grad=True)
2266        grad_output = torch.randn(5, 5)
2267
2268        v.backward(grad_output)
2269        self.assertEqual(v.grad, grad_output)
2270
2271        a = x + (y * z) + 4 * z**2 * x / y
2272        a.backward(grad_output)
2273        x_grad = 4 * z.pow(2) / y + 1
2274        y_grad = z - 4 * x * z.pow(2) / y.pow(2)
2275        z_grad = 8 * x * z / y + y
2276        self.assertEqual(x.grad, x_grad * grad_output)
2277        self.assertEqual(y.grad, y_grad * grad_output)
2278        self.assertEqual(z.grad, z_grad * grad_output)
2279
2280    def test_to_sparse_backward(self):
2281        to_attr_names = (
2282            "to_dense",
2283            "to_sparse",
2284            "to_sparse_csr",
2285            "to_sparse_csc",
2286            "to_sparse_bsr",
2287            "to_sparse_bsc",
2288        )
2289        to_params = ((), (), (), (), (2,), (2,))
2290        to_attr_names_params = dict(zip(to_attr_names, to_params))
2291
2292        def check_inversion_possible(
2293            t, layout1, layout1_params, layout2, layout2_params
2294        ):
2295            l = (layout1, layout2)
2296            p = (layout1_params, layout2_params)
2297            for l1, l2, p1, p2 in ((*l, *p), (*l[::-1], *p[::-1])):
2298                try:
2299                    to_l1 = getattr(t, l1)(*p1)
2300                    to_l2 = getattr(to_l1, l2)(*p2)
2301                except RuntimeError:
2302                    return False
2303
2304            return True
2305
2306        self_strided = torch.rand(4, 4, dtype=torch.double) + 1
2307        grad_strided = torch.rand(4, 4, dtype=torch.double) + 1
2308
2309        for from_to_attr in to_attr_names:
2310            from_params = to_attr_names_params[from_to_attr]
2311            self_from = getattr(self_strided, from_to_attr)(
2312                *from_params
2313            ).requires_grad_(True)
2314
2315            for to_to_attr in to_attr_names[1:]:
2316                to_params = to_attr_names_params[to_to_attr]
2317
2318                if check_inversion_possible(
2319                    self_strided, from_to_attr, from_params, to_to_attr, to_params
2320                ):
2321                    self_to = getattr(self_from, to_to_attr)(*to_params)
2322                    grad_to = getattr(grad_strided, to_to_attr)(*to_params)
2323
2324                    # No gradcheck support for BSR/BSC, so the grads are checked explicitly
2325                    grad_res = torch.autograd.grad(self_to, self_from, grad_to)[0]
2326
2327                    self.assertEqual(grad_res.layout, self_from.layout)
2328                    self.assertEqual(grad_res.to_dense(), grad_strided)
2329
2330    def test_sparse_mm_backward(self):
2331        size = (3, 3)
2332
2333        mm_test_cases = product(*(([False, True],) * 4))
2334
2335        for a_req_grad, a_is_sparse, b_req_grad, b_is_sparse in mm_test_cases:
2336            # We should only be testing cases with sparse inputs, and at least one
2337            # input needs to require grad so we can call a backward pass
2338            if not ((a_is_sparse or b_is_sparse) and (a_req_grad or b_req_grad)):
2339                continue
2340            a = torch.randn(size)
2341            if a_is_sparse:
2342                # detaching as `a` needs to be a leaf
2343                a = a.to_sparse().detach()
2344            b = torch.randn(size)
2345            if b_is_sparse:
2346                # detaching as `b` needs to be a leaf
2347                b = b.to_sparse().detach()
2348
2349            a = a.requires_grad_(a_req_grad)
2350            b = b.requires_grad_(b_req_grad)
2351
2352            r = a.mm(b)
2353            s = r.sum().backward()
2354            a_grad = None if a.grad is None else a.grad.clone().detach()
2355            b_grad = None if b.grad is None else b.grad.clone().detach()
2356
2357            # Redo with only dense tensors
2358            a = (
2359                (a.to_dense() if a.is_sparse else a)
2360                .clone()
2361                .detach()
2362                .requires_grad_(a_req_grad)
2363            )
2364            b = (
2365                (b.to_dense() if b.is_sparse else b)
2366                .clone()
2367                .detach()
2368                .requires_grad_(b_req_grad)
2369            )
2370
2371            r = a.mm(b)
2372            r.sum().backward()
2373
2374            self.assertEqual(a_grad, a.grad)
2375            self.assertEqual(b_grad, b.grad)
2376
2377    def test_multi_backward(self):
2378        x = torch.randn(5, 5, requires_grad=True)
2379        y = torch.randn(5, 5, requires_grad=True)
2380
2381        q = torch.randn(5, 5, requires_grad=True)
2382
2383        a = torch.randn(5, 5, requires_grad=True)
2384        b = torch.randn(5, 5, requires_grad=True)
2385
2386        q2 = q * 2
2387        z = x + y + q2
2388        c = a * b + q2
2389        grad_z = torch.randn(5, 5)
2390        grad_c = torch.randn(5, 5)
2391        torch.autograd.backward([z, c], [grad_z, grad_c])
2392
2393        self.assertEqual(x.grad, grad_z)
2394        self.assertEqual(y.grad, grad_z)
2395        self.assertEqual(a.grad, grad_c * b)
2396        self.assertEqual(b.grad, grad_c * a)
2397        self.assertEqual(q.grad, (grad_c + grad_z) * 2)
2398
2399    def test_multi_backward_no_grad(self):
2400        x = torch.randn(5, 5, requires_grad=True)
2401        y = torch.randn(5, 5, requires_grad=False)
2402
2403        z = x + y
2404        q = y * 2
2405
2406        # NB: we currently raise an exception if any arguments to backwards
2407        # have requires_grad=False and don't have a grad_fn. We may want to
2408        # relax that check to a warning.
2409        def call_backwards():
2410            torch.autograd.backward([z, q], [torch.ones(5, 5), torch.ones(5, 5)])
2411
2412        self.assertRaises(RuntimeError, call_backwards)
2413
2414    def test_backward_with_inputs(self):
2415        x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
2416        y = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
2417
2418        def fn():
2419            return x**2 + y * x + y**2
2420
2421        gradient = torch.ones(2, 2)
2422        x_grad_expected = 2 * x + y
2423        y_grad_expected = x + 2 * y
2424
2425        @torch.no_grad()
2426        def reset_grad():
2427            x.grad.zero_()
2428            y.grad.zero_()
2429
2430        torch.autograd.backward(fn(), gradient, inputs=[x, y])
2431        self.assertEqual(x.grad, x_grad_expected)
2432        self.assertEqual(y.grad, y_grad_expected)
2433
2434        reset_grad()
2435        torch.autograd.backward(fn(), gradient, inputs=[x])
2436        self.assertEqual(x.grad, x_grad_expected)
2437        self.assertEqual(y.grad, torch.zeros(2, 2), exact_dtype=False)
2438
2439        reset_grad()
2440        torch.autograd.backward(fn(), gradient, inputs=[y])
2441        self.assertEqual(y.grad, y_grad_expected)
2442        self.assertEqual(x.grad, torch.zeros(2, 2), exact_dtype=False)
2443
2444        reset_grad()
2445        torch.autograd.backward(fn(), gradient, inputs=y)
2446        self.assertEqual(y.grad, y_grad_expected)
2447        self.assertEqual(x.grad, torch.zeros(2, 2), exact_dtype=False)
2448
2449        reset_grad()
2450        self.assertRaisesRegex(
2451            RuntimeError,
2452            "cannot be empty",
2453            lambda: torch.autograd.backward(fn(), gradient, inputs=[]),
2454        )
2455
2456    def test_backward_with_nonleaf_inputs(self):
2457        x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
2458        x_nonleaf = x * 1
2459        y = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
2460        z = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
2461
2462        out = x_nonleaf**2 + y * x_nonleaf + y**2
2463
2464        out.backward(
2465            torch.ones(2, 2, dtype=torch.double),
2466            create_graph=True,
2467            inputs=[x, y, x_nonleaf],
2468        )
2469        x_grad_expected = 2 * x + y
2470        y_grad_expected = x + 2 * y
2471        x_non_leaf_expected = 2 * x_nonleaf + y
2472
2473        self.assertEqual(y.grad, y_grad_expected)
2474        self.assertEqual(x.grad, x_grad_expected)
2475        self.assertEqual(x_nonleaf.grad, x_non_leaf_expected)
2476
2477        # backward doesn't have an allow_unused flag, so the behavior of backward
2478        # when variable is not part of the graph is as if allow_used were true
2479        # x.grad will simply be None.
2480        out.backward(
2481            torch.ones(2, 2, dtype=torch.double), create_graph=True, inputs=[z]
2482        )
2483        self.assertIsNone(z.grad)
2484
2485    def test_dependent_backward(self):
2486        x = torch.randn(10, requires_grad=True)
2487        y = x**2
2488        z = y**3
2489
2490        go_y = torch.randn(10)
2491        go_z = torch.randn(10)
2492        torch.autograd.backward([y, z], [go_y, go_z])
2493
2494        xd = x
2495        self.assertEqual(x.grad, 2 * xd * go_y + 6 * xd.pow(5) * go_z)
2496
2497    def test_save_output_nr(self):
2498        x = torch.randn(10, requires_grad=True)
2499
2500        class MultiOutputFn(Function):
2501            @staticmethod
2502            def forward(ctx, x):
2503                return x[:5], x[5:]
2504
2505            @staticmethod
2506            def backward(ctx, *grad):
2507                return torch.cat(grad)
2508
2509        a, b = MultiOutputFn.apply(x)
2510        self.assertEqual(b.output_nr, 1)
2511
2512        class TestFn(Function):
2513            @staticmethod
2514            def forward(ctx, b):
2515                ctx.save_for_backward(b)
2516                return b * 2
2517
2518            @staticmethod
2519            def backward(ctx, grad_b):
2520                (b,) = ctx.saved_tensors
2521                self.assertEqual(b.output_nr, 1)
2522
2523        TestFn.apply(b).sum().backward()
2524
2525    def test_first_grad_fn_access_in_no_grad_mode(self):
2526        a = torch.tensor([1 + 1j], requires_grad=True).clone()
2527        v = a.real
2528        a.add_(1)
2529        with torch.autograd.grad_mode.no_grad():
2530            v.grad_fn
2531
2532    @skipIfTorchDynamo("too slow")
2533    def test_free_deep_graph(self):
2534        def scope():
2535            depth = 150000
2536            x = torch.randn(1, requires_grad=True)
2537            y = x.clone()
2538
2539            # build a "chain" computation graph
2540            for _ in range(depth):
2541                y = y + y * 0.000001
2542
2543            # graph deletion occurs when the above locals go out of scope.
2544            # In this case `del y` will trigger it but it's easier to leave
2545            # it to Python to delete the locals.
2546
2547        # Should not stack overflow
2548        scope()
2549
2550    @skipIfTorchDynamo("too slow")
2551    def test_free_deep_graph_complicated(self):
2552        def scope():
2553            depth = 100000
2554            randchoice = torch.randint(2, [depth, 2])
2555            x = torch.randn(1, requires_grad=True)
2556            y = x.clone()
2557
2558            # Hold the two previous values
2559            prev_values = [None, None]
2560
2561            # Build a "chain with skip connections" graph
2562            for _ in range(depth):
2563                prev_tensors = [
2564                    tensor for tensor in prev_values[:-1] if tensor is not None
2565                ]
2566                prev_values.append(y)
2567                prev_values.pop(0)
2568
2569                # Definitely pick one tensor to add
2570                y += y * 0.000001
2571
2572                # Possibly add other tensors
2573                nprev = len(prev_tensors)
2574                if nprev == 2:
2575                    y += randchoice[depth].mul(torch.cat(prev_tensors)).sum()
2576
2577            # graph deletion occurs when the above locals go out of scope.
2578
2579        # Should not stack overflow
2580        scope()
2581
2582    @skipIfTorchDynamo("too slow")
2583    def test_free_deep_graph_pyfunction(self):
2584        class MyOp(Function):
2585            @staticmethod
2586            def forward(ctx, tensor1, tensor2):
2587                return tensor1 + tensor2
2588
2589            @staticmethod
2590            def backward(ctx, grad_output):
2591                return grad_output, grad_output
2592
2593        def scope():
2594            depth = 150000
2595            x = torch.randn(1, requires_grad=True)
2596            y = x.clone()
2597
2598            # build deeply nested computation graph
2599            for _ in range(depth):
2600                y = MyOp.apply(y, y)
2601
2602            # graph deletion occurs when the above locals go out of scope.
2603
2604        # Should not stack overflow
2605        scope()
2606
2607    def test_no_unnecessary_save(self):
2608        # If we kept x in the derivative Function of x * 2 we would
2609        # get an error in the backward that would complain that we've
2610        # modified x, which was needed for gradient computation.
2611        # Since we should elide unnecessary saves, this test should pass.
2612        mu = torch.ones(1, requires_grad=True)
2613        x = torch.empty(1)
2614        loss = 0
2615        for i in range(3):
2616            x.detach_()
2617            x.copy_(mu + i)
2618            ft = torch.tensor([float(i)])
2619            multiplied = x * ft
2620            s = multiplied.sum()
2621            loss += s
2622        loss.backward()
2623
2624    def test_no_grad(self):
2625        x = torch.ones(5, 5, requires_grad=True)
2626        y = torch.ones(5, 5) * 4
2627        with torch.no_grad():
2628            w = x + y
2629
2630        def adder(x, y):
2631            return x + y
2632
2633        adders = [torch.no_grad()(adder), torch.no_grad(adder)]
2634
2635        for adder in adders:
2636            z = adder(x, y)
2637
2638            self.assertFalse(w.requires_grad)
2639            self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5)))
2640            self.assertIsNone(w.grad_fn)
2641            self.assertFalse(z.requires_grad)
2642            self.assertRaises(RuntimeError, lambda: z.backward(torch.ones(5, 5)))
2643            self.assertIsNone(z.grad_fn)
2644
2645        # test nested decorator and with-statement on no_grad
2646        with torch.no_grad():
2647            self.assertFalse(torch.is_grad_enabled())
2648            w = adder(x, y)
2649            self.assertFalse(torch.is_grad_enabled())
2650
2651    def test_enable_grad_decorator_no_paren(self):
2652        x = torch.ones(1, requires_grad=True)
2653
2654        @torch.enable_grad
2655        def doubler(x):
2656            return x * 2
2657
2658        with torch.no_grad():
2659            z = doubler(x)
2660        self.assertTrue(z.requires_grad)
2661
2662    def test_set_grad_generator_functions(self):
2663        @torch.no_grad()
2664        def gen_no_grad():
2665            for i in range(10):
2666                self.assertEqual(torch.is_grad_enabled(), False)
2667                yield i
2668
2669        with torch.enable_grad():
2670            for _ in gen_no_grad():
2671                self.assertEqual(torch.is_grad_enabled(), True)
2672
2673        @torch.enable_grad()
2674        def gen_enable_grad():
2675            for i in range(10):
2676                self.assertEqual(torch.is_grad_enabled(), True)
2677                yield i
2678
2679        with torch.no_grad():
2680            for _ in gen_enable_grad():
2681                self.assertEqual(torch.is_grad_enabled(), False)
2682
2683    def test_set_grad_generator_functions_recursive(self):
2684        # enable_grad_decorator_recursive and no_grad_decorator_recursive call each other
2685        # recursively, to ensure that the decorators preserve the caller's setting
2686        @torch.enable_grad()
2687        def enable_grad_decorator_recursive(depth):
2688            self.assertTrue(torch.is_grad_enabled())
2689            if depth > 0:
2690                no_grad_decorator_recursive(depth - 1)
2691                self.assertTrue(torch.is_grad_enabled())
2692
2693        @torch.no_grad()
2694        def no_grad_decorator_recursive(depth):
2695            self.assertFalse(torch.is_grad_enabled())
2696            if depth > 0:
2697                enable_grad_decorator_recursive(depth - 1)
2698                self.assertFalse(torch.is_grad_enabled())
2699
2700        # enable_grad_context_manager_recursive and no_grad_context_manager_recursive call
2701        # each other recursively, to ensure that the decorators preserve the caller's setting
2702        def enable_grad_context_manager_recursive(depth):
2703            with torch.enable_grad():
2704                self.assertTrue(torch.is_grad_enabled())
2705                if depth > 0:
2706                    no_grad_context_manager_recursive(depth - 1)
2707                    self.assertTrue(torch.is_grad_enabled())
2708
2709        def no_grad_context_manager_recursive(depth):
2710            with torch.no_grad():
2711                self.assertFalse(torch.is_grad_enabled())
2712                if depth > 0:
2713                    enable_grad_context_manager_recursive(depth - 1)
2714                    self.assertFalse(torch.is_grad_enabled())
2715
2716        with torch.enable_grad():
2717            self.assertTrue(torch.is_grad_enabled())
2718            enable_grad_decorator_recursive(10)
2719            self.assertTrue(torch.is_grad_enabled())
2720            enable_grad_context_manager_recursive(10)
2721            self.assertTrue(torch.is_grad_enabled())
2722
2723        with torch.no_grad():
2724            self.assertFalse(torch.is_grad_enabled())
2725            enable_grad_decorator_recursive(10)
2726            self.assertFalse(torch.is_grad_enabled())
2727            enable_grad_context_manager_recursive(10)
2728            self.assertFalse(torch.is_grad_enabled())
2729
2730    def test_set_grad_coroutines(self):
2731        @torch.no_grad()
2732        def coro_no_grad(n=10):
2733            self.assertFalse(torch.is_grad_enabled())
2734            for i in range(n):
2735                self.assertFalse(torch.is_grad_enabled())
2736                r = yield i
2737                self.assertFalse(torch.is_grad_enabled())
2738                self.assertEqual(i, r)
2739            self.assertFalse(torch.is_grad_enabled())
2740
2741        @torch.enable_grad()
2742        def coro_enable_grad(n=10):
2743            self.assertTrue(torch.is_grad_enabled())
2744            for i in range(n):
2745                self.assertTrue(torch.is_grad_enabled())
2746                r = yield i
2747                self.assertTrue(torch.is_grad_enabled())
2748                self.assertEqual(i, r)
2749            self.assertTrue(torch.is_grad_enabled())
2750
2751        with torch.enable_grad():
2752            self.assertTrue(torch.is_grad_enabled())
2753            coro, r = coro_no_grad(), None
2754            try:
2755                while True:
2756                    self.assertTrue(torch.is_grad_enabled())
2757                    r = coro.send(r)
2758                    self.assertTrue(torch.is_grad_enabled())
2759
2760            except StopIteration:
2761                pass
2762
2763        with torch.no_grad():
2764            self.assertFalse(torch.is_grad_enabled())
2765            coro, r = coro_enable_grad(), None
2766            try:
2767                while True:
2768                    self.assertFalse(torch.is_grad_enabled())
2769                    r = coro.send(r)
2770                    self.assertFalse(torch.is_grad_enabled())
2771
2772            except StopIteration:
2773                pass
2774
2775    def test_set_grad_coroutines_benign_exceptions(self):
2776        class RecoverableException(Exception):
2777            pass
2778
2779        @torch.no_grad()
2780        def coro_no_grad(n=10):
2781            has_raised = False
2782            for i in range(n):
2783                try:
2784                    self.assertFalse(torch.is_grad_enabled())
2785                    yield (-i if has_raised else i)
2786
2787                except RecoverableException:
2788                    self.assertFalse(torch.is_grad_enabled())
2789                    has_raised = True
2790
2791        @torch.enable_grad()
2792        def coro_enable_grad(n=10):
2793            has_raised = False
2794            for i in range(n):
2795                try:
2796                    self.assertTrue(torch.is_grad_enabled())
2797                    yield (-i if has_raised else i)
2798
2799                except RecoverableException:
2800                    self.assertTrue(torch.is_grad_enabled())
2801                    has_raised = True
2802
2803        with torch.enable_grad():
2804            coro = coro_no_grad()
2805            assert 0 == next(coro)
2806            try:
2807                while True:
2808                    r = coro.throw(RecoverableException)
2809                    self.assertLess(r, 0)
2810
2811            except StopIteration:
2812                pass
2813
2814        with torch.no_grad():
2815            coro = coro_enable_grad()
2816            assert 0 == next(coro)
2817            try:
2818                while True:
2819                    r = coro.throw(RecoverableException)
2820                    self.assertLess(r, 0)
2821
2822            except StopIteration:
2823                pass
2824
2825    def test_set_grad_coroutines_critical_exceptions(self):
2826        class UnrecoverableException(Exception):
2827            pass
2828
2829        class SecondaryException(Exception):
2830            pass
2831
2832        @torch.no_grad()
2833        def coro_no_grad(n=10):
2834            has_raised = False
2835            for i in range(n):
2836                try:
2837                    self.assertFalse(torch.is_grad_enabled())
2838                    yield (-i if has_raised else i)
2839
2840                except UnrecoverableException:
2841                    self.assertFalse(torch.is_grad_enabled())
2842                    raise SecondaryException from None
2843
2844        @torch.enable_grad()
2845        def coro_enable_grad(n=10):
2846            has_raised = False
2847            for i in range(n):
2848                try:
2849                    self.assertTrue(torch.is_grad_enabled())
2850                    yield (-i if has_raised else i)
2851
2852                except UnrecoverableException:
2853                    self.assertTrue(torch.is_grad_enabled())
2854                    raise SecondaryException from None
2855
2856        with torch.enable_grad():
2857            coro = coro_no_grad()
2858            assert 0 == next(coro)
2859            with self.assertRaises(SecondaryException):
2860                coro.throw(UnrecoverableException)
2861
2862        with torch.no_grad():
2863            coro = coro_enable_grad()
2864            assert 0 == next(coro)
2865            with self.assertRaises(SecondaryException):
2866                coro.throw(UnrecoverableException)
2867
2868    def test_set_grad_coroutines_exit(self):
2869        @torch.no_grad()
2870        def coro_no_grad(state):
2871            for i in range(10):
2872                try:
2873                    self.assertFalse(torch.is_grad_enabled())
2874                    yield i
2875
2876                except GeneratorExit:
2877                    self.assertFalse(torch.is_grad_enabled())
2878                    state.add("GeneratorExit")
2879                    raise
2880
2881        @torch.enable_grad()
2882        def coro_enable_grad(state):
2883            for i in range(10):
2884                try:
2885                    self.assertTrue(torch.is_grad_enabled())
2886                    yield i
2887
2888                except GeneratorExit:
2889                    self.assertTrue(torch.is_grad_enabled())
2890                    state.add("GeneratorExit")
2891                    raise
2892
2893        state = set()
2894        with torch.enable_grad():
2895            coro = coro_no_grad(state)
2896            for i in range(5):
2897                next(coro)
2898
2899            coro.close()
2900        self.assertTrue("GeneratorExit" in state)
2901
2902        state = set()
2903        with torch.no_grad():
2904            coro = coro_enable_grad(state)
2905            for i in range(5):
2906                next(coro)
2907
2908            coro.close()
2909        self.assertTrue("GeneratorExit" in state)
2910
2911    def test_no_grad_python_function(self):
2912        """Python Functions should respect grad mode."""
2913        x = torch.ones(5, 5, requires_grad=True)
2914
2915        class MyOp(Function):
2916            @staticmethod
2917            def forward(self, x):
2918                return x + 1
2919
2920            @staticmethod
2921            def backward(self, dy):
2922                return dy
2923
2924        with torch.no_grad():
2925            y = MyOp.apply(x)
2926        self.assertFalse(y.requires_grad)
2927
2928    def test_indexing(self):
2929        x = torch.arange(1.0, 17).view(4, 4)
2930        y = Variable(x, requires_grad=True)
2931
2932        def compare(x, y, idx, indexed_tensor, indexed_var):
2933            indexed_var_t = indexed_var.data
2934            if not isinstance(indexed_tensor, torch.Tensor):
2935                indexed_var_t = indexed_var_t[0]
2936            self.assertEqual(indexed_tensor, indexed_var_t)
2937
2938            indexed_var.sum().backward()
2939            expected_grad = torch.empty(x.size()).fill_(0)
2940            expected_grad[idx] = 1
2941            self.assertEqual(y.grad, expected_grad)
2942
2943        def check_index(x, y, idx):
2944            if y.grad is not None:
2945                with torch.no_grad():
2946                    y.grad.zero_()
2947            indexed_tensor = x[idx]
2948            indexed_var = y[idx]
2949            compare(x, y, idx, indexed_tensor, indexed_var)
2950
2951        check_index(x, y, 1)
2952        check_index(x, y, (1, 1))
2953        check_index(x, y, slice(1, None))
2954        check_index(x, y, slice(None, 2))
2955        check_index(x, y, (slice(None, 2), 2))
2956        check_index(x, y, (slice(1, 2), 2))
2957        check_index(x, y, (1, slice(2, None)))
2958        check_index(x, y, (slice(None, None), slice(2, None)))
2959        check_index(x, y, torch.LongTensor([0, 2]))
2960        check_index(x, y, torch.rand(4, 4).bernoulli().bool())
2961        check_index(x, y, (Ellipsis, slice(2, None)))
2962        check_index(x, y, ([0], [0]))
2963        check_index(x, y, ([1, 2, 3], [0]))
2964        check_index(x, y, ([1, 2], [2, 1]))
2965        check_index(x, y, ([[1, 2], [3, 0]], [[0, 1], [2, 3]]))
2966        check_index(x, y, ([slice(None), [2, 3]]))
2967        check_index(x, y, ([[2, 3], slice(None)]))
2968
2969        # advanced indexing, with less dim, or ellipsis
2970        check_index(x, y, ([0]))
2971        check_index(x, y, ([0],))
2972
2973        x = torch.arange(1.0, 49).view(4, 3, 4)
2974        y = Variable(x, requires_grad=True)
2975
2976        check_index(x, y, (slice(None), [0], [0]))
2977        check_index(x, y, ([0], [0], slice(None)))
2978        check_index(x, y, (slice(None), [0, 1, 2], [0]))
2979        check_index(x, y, ([0, 1, 2], [0], slice(None)))
2980        check_index(x, y, (slice(None), [1, 2], [2, 1]))
2981        check_index(x, y, ([1, 2], [2, 1], slice(None)))
2982        check_index(x, y, (slice(None), [[1, 2], [2, 0]], [[0, 1], [2, 3]]))
2983        check_index(x, y, ([[1, 2], [3, 0]], [[0, 1], [2, 2]], slice(None)))
2984        check_index(x, y, (slice(None), slice(None), [2, 1]))
2985        check_index(x, y, (slice(None), [2, 1], slice(None)))
2986        check_index(x, y, ([2, 1], slice(None), slice(None)))
2987
2988        # advanced indexing, with less dim, or ellipsis
2989        check_index(x, y, ([0],))
2990        check_index(x, y, ([0], slice(None)))
2991        check_index(x, y, ([0], Ellipsis))
2992        check_index(x, y, ([1, 2], [0, 1]))
2993        check_index(x, y, ([1, 2], [0, 1], Ellipsis))
2994        check_index(x, y, (Ellipsis, [1, 2], [0, 1]))
2995
2996        # advanced indexing, with a tensor wrapped in a variable
2997        z = torch.LongTensor([0, 1])
2998        zv = Variable(z, requires_grad=False)
2999        seq = [z, Ellipsis]
3000        seqv = [zv, Ellipsis]
3001
3002        if y.grad is not None:
3003            with torch.no_grad():
3004                y.grad.zero_()
3005        indexed_tensor = x[seq]
3006        indexed_var = y[seqv]
3007        compare(x, y, seq, indexed_tensor, indexed_var)
3008
3009    def test_indexing_duplicates(self):
3010        x = torch.arange(1.0, 17).view(4, 4)
3011        y = Variable(x, requires_grad=True)
3012
3013        idx = torch.LongTensor([1, 1, 3, 2, 1, 2])
3014        y[idx].sum().backward()
3015        expected_grad = torch.zeros(4, 4)
3016        for i in idx:
3017            expected_grad[i] += 1
3018        self.assertEqual(y.grad, expected_grad)
3019
3020        # with advanced indexing
3021        x = torch.arange(1.0, 17).view(4, 4)
3022        y = Variable(x, requires_grad=True)
3023
3024        idx = [[1, 1, 3, 2, 1, 2], [0]]
3025        y[idx].sum().backward()
3026        expected_grad = torch.zeros(4, 4)
3027        for i in idx[0]:
3028            for j in idx[1]:
3029                expected_grad[i][j] += 1
3030
3031        self.assertEqual(y.grad, expected_grad)
3032
3033        x = torch.arange(1.0, 17).view(4, 4)
3034        y = Variable(x, requires_grad=True)
3035        idx = [[[1, 2], [0, 0]], [[0, 1], [1, 1]]]
3036        y[idx].sum().backward()
3037        expected_grad = torch.tensor(
3038            [
3039                [0.0, 2.0, 0.0, 0.0],
3040                [1.0, 0.0, 0.0, 0.0],
3041                [0.0, 1.0, 0.0, 0.0],
3042                [0.0, 0.0, 0.0, 0.0],
3043            ]
3044        )
3045        self.assertEqual(y.grad, expected_grad)
3046
3047        x = torch.arange(1.0, 65).view(4, 4, 4)
3048        y = Variable(x, requires_grad=True)
3049
3050        idx = [[1, 1, 1], slice(None), slice(None)]
3051        y[idx].sum().backward()
3052        expected_grad = torch.empty(4, 4, 4).zero_()
3053        expected_grad[1].fill_(3)
3054        self.assertEqual(y.grad, expected_grad)
3055
3056    def test_index_backward_does_not_save_tensor(self):
3057        # Example from https://github.com/pytorch/pytorch/issues/24853.
3058        # if `index(tensor, indices)` saves `tensor` for backwards, then it will
3059        # trigger a version check on `tensor` during the backward pass, which
3060        # will cause the following code to error because `tensor` gets modified
3061        # by the indexing line.
3062        a = torch.tensor([1.0, 0, 0])
3063        b = torch.zeros(3, requires_grad=True)
3064        tensor = b + 0
3065        tensor[a != 0] = tensor[a != 0]
3066        tensor.backward(torch.zeros_like(tensor))
3067
3068    def test_volatile_deprecated(self):
3069        v = torch.autograd.torch.randn(3, 3)
3070        with warnings.catch_warnings(record=True) as w:
3071            self.assertFalse(v.volatile)
3072        self.assertIn("volatile", str(w[0].message))
3073
3074    def test_saved_variables_deprecated(self):
3075        class MyFunction(Function):
3076            @staticmethod
3077            def forward(ctx, tensor1, tensor2):
3078                ctx.save_for_backward(tensor1, tensor2)
3079                return tensor1 + tensor2
3080
3081            @staticmethod
3082            def backward(ctx, grad_output):
3083                var1, var2 = ctx.saved_variables
3084                return (grad_output, grad_output)
3085
3086        with warnings.catch_warnings(record=True) as warns:
3087            warnings.simplefilter("always")
3088            x = torch.randn((3, 3), requires_grad=True)
3089            y = torch.randn((3, 3), requires_grad=True)
3090            MyFunction.apply(x, y).sum().backward()
3091
3092            has_deprecated = (
3093                "deprecated" in str(warn) and "saved_variables" in str(warn)
3094                for warn in warns
3095            )
3096            has_deprecated = reduce(lambda x, y: x or y, has_deprecated)
3097            self.assertTrue(has_deprecated)
3098
3099    def test_requires_grad(self):
3100        x = torch.randn(5, 5)
3101        y = torch.randn(5, 5)
3102        z = torch.randn(5, 5, requires_grad=True)
3103        a = x + y
3104        self.assertFalse(a.requires_grad)
3105        b = a + z
3106        self.assertTrue(b.requires_grad)
3107
3108        def error():
3109            raise RuntimeError
3110
3111        # Make sure backward isn't called on these
3112        a._backward_hooks = OrderedDict()
3113        x._backward_hooks = OrderedDict()
3114        y._backward_hooks = OrderedDict()
3115        a._backward_hooks["test"] = error
3116        x._backward_hooks["test"] = error
3117        y._backward_hooks["test"] = error
3118        b.backward(torch.ones(5, 5))
3119
3120    def test_requires_grad_(self):
3121        x = torch.randn(5, 5)
3122        y = torch.randn(5, 5, requires_grad=True)
3123        self.assertIs(x, x.requires_grad_())
3124        self.assertTrue(x.requires_grad)
3125        self.assertIs(y, y.requires_grad_())
3126        self.assertTrue(y.requires_grad)
3127        self.assertIs(x, x.requires_grad_(True))
3128        self.assertTrue(x.requires_grad)
3129        self.assertIs(y, y.requires_grad_(True))
3130        self.assertTrue(y.requires_grad)
3131        z = x * y
3132        self.assertRaises(RuntimeError, lambda: z.requires_grad_(False))
3133        self.assertIs(z, z.requires_grad_())
3134        self.assertTrue(z.requires_grad)
3135        self.assertIs(z, z.requires_grad_(True))
3136        self.assertTrue(z.requires_grad)
3137
3138        self.assertIs(x, x.requires_grad_(False))
3139        self.assertFalse(x.requires_grad)
3140        self.assertIs(y, y.requires_grad_(False))
3141        self.assertFalse(y.requires_grad)
3142
3143    def test_requires_grad_inplace(self):
3144        a = torch.randn(5, 5)
3145        b = torch.randn(5, 5, requires_grad=True)
3146        a += b
3147        self.assertTrue(a.requires_grad)
3148
3149        # non-leaf
3150        a = torch.randn(5, 5) + 0
3151        b = torch.randn(5, 5, requires_grad=True)
3152        a += b
3153        self.assertTrue(a.requires_grad)
3154
3155    def test_no_requires_grad_inplace(self):
3156        # basic case, should be able to modify inplace while requires_grad is False
3157        a = torch.randn(2, 3)
3158        a.add_(5)
3159        a.requires_grad = True
3160        a.sum().backward()
3161        self.assertEqual(a.grad, torch.ones(2, 3))
3162
3163        # same but with a view
3164        a = torch.randn(2, 3)
3165        b = a[:]
3166        b.add_(5)
3167        a.requires_grad = True
3168        a.sum().backward()
3169        self.assertEqual(a.grad, torch.ones(2, 3))
3170
3171        # should fail if requires_grad = True when we modify inplace
3172        a = torch.randn(2, 3)
3173        b = a[:]
3174        a.requires_grad = True
3175        with self.assertRaises(RuntimeError):
3176            a.add_(5)
3177        with self.assertRaises(RuntimeError):
3178            b.add_(5)
3179
3180    def test_attribute_deletion(self):
3181        x = torch.randn((5, 5), requires_grad=True)
3182        del x.grad
3183        self.assertIsNone(x.grad)
3184        with self.assertRaises(RuntimeError):
3185            del x.data
3186        with self.assertRaises(TypeError):
3187            x.data = None
3188        with self.assertRaises(RuntimeError):
3189            del x.requires_grad
3190        with self.assertRaises(RuntimeError):
3191            del x._grad_fn
3192        with self.assertRaises(RuntimeError):
3193            del x._backward_hooks
3194
3195    def test_duplicate_backward_root(self):
3196        a = torch.randn(5, 5, requires_grad=True)
3197        b = torch.randn(5, 5, requires_grad=True)
3198
3199        x = a * b
3200        grad_output = torch.randn_like(x)
3201        torch.autograd.backward([x, x], [grad_output, grad_output])
3202
3203        self.assertEqual(a.grad, b * grad_output * 2)
3204        self.assertEqual(b.grad, a * grad_output * 2)
3205
3206    def test_backward_no_grad(self):
3207        a = torch.randn(5, 5, requires_grad=True)
3208        b = a + 2
3209        with self.assertRaises(RuntimeError):
3210            torch.autograd.backward([b], [None])
3211
3212    def test_backward_twice_with_saved_values(self):
3213        b = torch.randn(3, requires_grad=True, dtype=torch.double)
3214        c = torch.zeros(3, dtype=torch.double)
3215        c[[1, 2]] = b[[1, 1]]
3216        c.backward(torch.tensor([1, 1, 1], dtype=torch.double))
3217        self.assertRaisesRegex(
3218            RuntimeError,
3219            "Specify retain_graph=True",
3220            lambda: c.backward(torch.tensor([1, 1, 1], dtype=torch.double)),
3221        )
3222
3223    def test_backward_twice_retained_graph_with_saved_values(self):
3224        b = torch.randn(3, requires_grad=True, dtype=torch.double)
3225        c = torch.zeros(3, dtype=torch.double)
3226        c[[1, 2]] = b[[1, 1]]
3227        c.backward(torch.tensor([1, 1, 1], dtype=torch.double), retain_graph=True)
3228        c.backward(torch.tensor([1, 1, 1], dtype=torch.double))
3229
3230    def test_backward_twice_without_saved_values(self):
3231        b = torch.randn(3, requires_grad=True, dtype=torch.double)
3232        c = b + 1
3233        c.backward(torch.tensor([1, 1, 1], dtype=torch.double))
3234        c.backward(torch.tensor([1, 1, 1], dtype=torch.double))
3235
3236    def test_backward_twice_retained_graph_without_saved_values(self):
3237        b = torch.randn(3, requires_grad=True, dtype=torch.double)
3238        c = torch.zeros(3, dtype=torch.double)
3239        c[[1, 2]] = b[[1, 1]]
3240        c.backward(torch.tensor([1, 1, 1], dtype=torch.double), retain_graph=True)
3241        c.backward(torch.tensor([1, 1, 1], dtype=torch.double))
3242
3243    def test_backward_create_graph_warns(self):
3244        with set_warn_always_context(True):
3245            b = torch.randn(3, requires_grad=True, dtype=torch.double)
3246            c = b * b
3247            with warnings.catch_warnings(record=True) as ws:
3248                c.backward(torch.ones_like(c), create_graph=True)
3249            b.grad = None
3250            self.assertTrue(
3251                any(
3252                    "Using backward() with create_graph=True" in str(w.message)
3253                    for w in ws
3254                )
3255            )
3256
3257            # Should not warn for grad
3258            with warnings.catch_warnings(record=True) as ws:
3259                torch.autograd.grad(c, b, torch.ones_like(c), create_graph=True)
3260            self.assertFalse(
3261                any(
3262                    "Using backward() with create_graph=True" in str(w.message)
3263                    for w in ws
3264                )
3265            )
3266
3267    def test_next_functions(self):
3268        x = torch.randn(5, 5, requires_grad=True)
3269        y = torch.randn(5, 5, requires_grad=True)
3270
3271        a = x + y
3272        self.assertIsNotNone(a.grad_fn)
3273        next_functions = a.grad_fn.next_functions
3274        self.assertEqual(len(next_functions), 2)
3275        self.assertIsInstance(next_functions[0][0], torch._C._functions.AccumulateGrad)
3276        self.assertEqual(next_functions[0][1], 0)
3277        self.assertIsInstance(next_functions[1][0], torch._C._functions.AccumulateGrad)
3278        self.assertEqual(next_functions[1][1], 0)
3279
3280        b = a + 5
3281        next_functions = b.grad_fn.next_functions
3282        self.assertEqual(len(next_functions), 2)
3283        self.assertIs(next_functions[0][0], a.grad_fn)
3284        self.assertIs(next_functions[1][0], None)
3285
3286    def test_inplace(self):
3287        x = torch.ones(5, 5, requires_grad=True)
3288        y = Variable(torch.ones(5, 5) * 4, requires_grad=True)
3289
3290        z = x * y
3291        q = z + y
3292        w = z * y
3293        z.add_(2)
3294        # Add doesn't need it's inputs to do backward, so it shouldn't raise
3295        q.backward(torch.ones(5, 5), retain_graph=True)
3296        # Mul saves both inputs in forward, so it should raise
3297        self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5)))
3298
3299        z = x * y
3300        q = z * y
3301        r = z + y
3302        w = z.add_(y)
3303        # w is a the last expression, so this should succeed
3304        w.backward(torch.ones(5, 5), retain_graph=True)
3305        # r doesn't use the modified value in backward, so it should succeed
3306        r.backward(torch.ones(5, 5), retain_graph=True)
3307        # q uses dirty z, so it should raise
3308        self.assertRaises(RuntimeError, lambda: q.backward(torch.ones(5, 5)))
3309
3310        with torch.no_grad():
3311            x.grad.zero_()
3312        m = x / 2
3313        z = m + y / 8
3314        q = z * y
3315        r = z + y
3316        prev_version = z._version
3317        w = z.exp_()
3318        self.assertNotEqual(z._version, prev_version)
3319        r.backward(torch.ones(5, 5), retain_graph=True)
3320        self.assertEqual(x.grad, torch.ones(5, 5) / 2)
3321        w.backward(torch.ones(5, 5), retain_graph=True)
3322        self.assertEqual(x.grad, torch.empty(5, 5).fill_((1 + math.e) / 2))
3323        self.assertRaises(RuntimeError, lambda: q.backward(torch.ones(5, 5)))
3324
3325        leaf = torch.ones(5, 5, requires_grad=True)
3326        x = leaf.clone()
3327        x.add_(10)
3328        self.assertEqual(x, torch.ones(5, 5) * 11)
3329        # x should be still usable
3330        y = x + 2
3331        y.backward(torch.ones(5, 5))
3332        self.assertEqual(leaf.grad, torch.ones(5, 5))
3333        z = x * y
3334        x.add_(2)
3335        self.assertRaises(RuntimeError, lambda: z.backward(torch.ones(5, 5)))
3336
3337    def test_mark_non_differentiable(self):
3338        class MyFunction(Function):
3339            @staticmethod
3340            def forward(ctx, input):
3341                output = input > 0
3342                ctx.mark_non_differentiable(output)
3343                return output
3344
3345            @staticmethod
3346            def backward(ctx, grad_output):
3347                return (grad_output * 0).to(torch.double)
3348
3349        x = torch.randn(5, 5, requires_grad=True)
3350        mask = MyFunction.apply(x)
3351        self.assertFalse(mask.requires_grad)
3352        y = x.masked_fill(mask, 0)
3353        y.sum().backward()
3354
3355    @skipIfTorchDynamo("compile tested in test/dynamo/test_autograd_function.py")
3356    def test_mark_non_differentiable_mixed(self):
3357        class MyFunction(Function):
3358            @staticmethod
3359            def forward(ctx, input):
3360                a = input + 1
3361                b = input + 2
3362                ctx.mark_non_differentiable(a)
3363                return a, b
3364
3365            @staticmethod
3366            def backward(ctx, grad_a, grad_b):
3367                self.assertTrue((grad_a == 0).all())
3368                self.assertTrue((grad_b == 1).all())
3369                return grad_b
3370
3371        x = torch.randn(5, 5, requires_grad=True)
3372        a, b = MyFunction.apply(x)
3373        self.assertFalse(a.requires_grad)
3374        self.assertTrue(b.requires_grad)
3375        b.sum().backward()
3376        self.assertEqual(x.grad, torch.ones(5, 5))
3377
3378    def test_mark_non_differentiable_none(self):
3379        # This used to segfault because MyFunction would send back null
3380        # gradients to MulBackward, which is implemented in C++. C++
3381        # implemented functions expect incoming grad_outputs to be non-null.
3382        class MyFunction(Function):
3383            @staticmethod
3384            def forward(ctx, input):
3385                output = input.clone()
3386                ctx.mark_non_differentiable(output)
3387                return output
3388
3389            @staticmethod
3390            def backward(ctx, grad_output):
3391                return None
3392
3393        x = torch.randn(5, 5, requires_grad=True)
3394        r = MyFunction.apply(x * x)
3395        (r * x).sum().backward()
3396
3397    def test_return_duplicate(self):
3398        class DoubleDuplicate(Function):
3399            @staticmethod
3400            def forward(ctx, x):
3401                output = x * 2
3402                return output, output
3403
3404            @staticmethod
3405            def backward(ctx, grad1, grad2):
3406                return grad1 * 2 + grad2 * 2
3407
3408        def fn(x):
3409            a, b = DoubleDuplicate.apply(x)
3410            self.assertIs(a, b)
3411            return a + b
3412
3413        x = torch.randn(5, 5, dtype=torch.double, requires_grad=True)
3414        gradcheck(fn, [x])
3415        gradgradcheck(fn, [x])
3416
3417    def test_return_duplicate_inplace(self):
3418        class DoubleInplace(Function):
3419            @staticmethod
3420            def forward(ctx, x):
3421                x.mul_(2)
3422                ctx.mark_dirty(x)
3423                return x, x
3424
3425            @staticmethod
3426            def backward(ctx, grad1, grad2):
3427                return grad1 * 2 + grad2 * 2
3428
3429        def inplace_fn(x):
3430            a, b = DoubleInplace.apply(x.clone())
3431            self.assertIs(a, b)
3432            return a + b
3433
3434        x = torch.randn(5, 5, dtype=torch.double, requires_grad=True)
3435        gradcheck(inplace_fn, [x])
3436        gradgradcheck(inplace_fn, [x])
3437
3438        # Can't modify leaf variables in-place
3439        self.assertRaises(RuntimeError, lambda: InplaceFunction.apply(x))
3440        # Functions which modify views in-place must return only one output
3441        self.assertRaises(RuntimeError, lambda: InplaceFunction.apply(x.clone()[0]))
3442
3443    def _test_setitem(self, size, index):
3444        x = torch.ones(*size, requires_grad=True)
3445        y = x + 2
3446        y_version = y._version
3447        y[index] = 2
3448        self.assertNotEqual(y._version, y_version)
3449        y.backward(torch.ones(*size))
3450        expected_grad = torch.ones(*size)
3451        expected_grad[index] = 0
3452        self.assertEqual(x.grad, expected_grad)
3453
3454    def _test_setitem_tensor(self, size, index):
3455        x = torch.ones(*size, requires_grad=True)
3456        y = x + 2
3457        y_version = y._version
3458        value = x.new(x[index].size()).fill_(7)
3459        value.requires_grad = True
3460        y[index] = value
3461        self.assertNotEqual(y._version, y_version)
3462        y.backward(torch.ones(*size))
3463        expected_grad_input = torch.ones(*size)
3464        expected_grad_input[index] = 0
3465        self.assertEqual(x.grad, expected_grad_input)
3466        self.assertEqual(value.grad, torch.ones_like(value))
3467
3468        # case when x broadcasts to as y[1]
3469        x = torch.randn(4, requires_grad=True)
3470        y = torch.zeros(2, 3, 4)
3471        y[1] = x
3472        y.backward(torch.randn(2, 3, 4))
3473        self.assertEqual(x.size(), x.grad.size())
3474
3475    def test_setitem(self):
3476        self._test_setitem((5, 5), 1)
3477        self._test_setitem((5,), 1)
3478        self._test_setitem((1,), 0)
3479        self._test_setitem((10,), [[0, 4, 2]])
3480        self._test_setitem((5, 5), [[0, 4], [2, 2]])
3481        self._test_setitem((5, 5, 5), [slice(None), slice(None), [1, 3]])
3482        self._test_setitem((5, 5, 5), [slice(None), [1, 3], slice(None)])
3483        self._test_setitem((5, 5, 5), [[1, 3], slice(None), slice(None)])
3484        self._test_setitem((5, 5, 5), [slice(None), [2, 4], [1, 3]])
3485        self._test_setitem((5, 5, 5), [[1, 3], [2, 4], slice(None)])
3486        self._test_setitem_tensor((5, 5), 3)
3487        self._test_setitem_tensor((5, 5), [[0, 1], [1, 0]])
3488        self._test_setitem_tensor((5,), 3)
3489        self._test_setitem_tensor(
3490            (5,), Variable(torch.LongTensor([3]), requires_grad=False).sum()
3491        )
3492        self._test_setitem_tensor((5,), [[0, 1, 2, 3]])
3493        self._test_setitem_tensor((5, 5, 5), [slice(None), slice(None), [1, 3]])
3494        self._test_setitem_tensor((5, 5, 5), [slice(None), [1, 3], slice(None)])
3495        self._test_setitem_tensor((5, 5, 5), [[1, 3], slice(None), slice(None)])
3496        self._test_setitem_tensor((5, 5, 5), [slice(None), [2, 4], [1, 3]])
3497        self._test_setitem_tensor((5, 5, 5), [[1, 3], [2, 4], slice(None)])
3498        self._test_setitem_tensor(
3499            (5, 5, 5),
3500            [
3501                Variable(torch.LongTensor([1, 3]), requires_grad=False),
3502                [2, 4],
3503                slice(None),
3504            ],
3505        )
3506
3507    def test_setitem_mask(self):
3508        mask = torch.BoolTensor(5, 5).bernoulli_()
3509        self._test_setitem((5, 5), Variable(mask))
3510        self._test_setitem((5,), Variable(mask[0]))
3511        self._test_setitem((1,), Variable(mask[0, 0:1]))
3512        self._test_setitem_tensor((5, 5), Variable(mask))
3513        self._test_setitem_tensor((5,), Variable(mask[0]))
3514
3515    def test_select_sum(self):
3516        # both select and sum return Scalars in ATen; ensure they work together.
3517        x = torch.randn(10, dtype=torch.double, requires_grad=True)
3518
3519        def func(x):
3520            return x.select(0, 1).sum()
3521
3522        gradcheck(func, [x])
3523        gradgradcheck(func, [x])
3524
3525    def test_diagonal_expanded_v(self):
3526        value = torch.rand([])
3527        v_expanded = torch.tensor(value).expand(10)
3528        a = torch.rand(10, 10, dtype=torch.double, requires_grad=True)
3529        (result,) = torch.autograd.grad(a.diagonal(), a, v_expanded)
3530        self.assertEqual(result, torch.eye(10, dtype=torch.double) * value)
3531
3532    def test_select_expanded_v(self):
3533        v_expanded = torch.rand(10).expand(10, 10)
3534        a = torch.rand(10, 10, 10, requires_grad=True)
3535        (result,) = torch.autograd.grad(a[0], a, v_expanded)
3536        expected = torch.zeros(10, 10, 10)
3537        expected[0] = v_expanded
3538        self.assertEqual(result, expected)
3539
3540    def test_slice_expanded_v(self):
3541        v_expanded = torch.rand(10, 1).expand(2, 10, 10)
3542        a = torch.rand(10, 10, 10, requires_grad=True)
3543        (result,) = torch.autograd.grad(a[3:5], a, v_expanded)
3544        expected = torch.zeros(10, 10, 10)
3545        expected[3:5] = v_expanded
3546        self.assertEqual(result, expected)
3547
3548    def test_unused_output(self):
3549        x = torch.randn(10, 10, requires_grad=True)
3550        outputs = x.chunk(5)
3551        o = outputs[2]
3552        o = o * 4 + 2
3553        o.sum().backward()
3554        expected_grad = torch.zeros(10, 10)
3555        expected_grad[4:6] = 4
3556        self.assertEqual(x.grad, expected_grad)
3557
3558        with torch.no_grad():
3559            x.grad.zero_()
3560        grad_output = torch.randn(2, 10)
3561        outputs = x.chunk(5)
3562        outputs[0].backward(grad_output)
3563        expected_grad = torch.zeros(10, 10)
3564        expected_grad[:2] = grad_output
3565        self.assertEqual(x.grad, expected_grad)
3566
3567    # TODO: opinfo this or move to the sparse test suite
3568    def _test_sparse_gather(self, size_x, size_ind, dim):
3569        x = torch.randn(size_x, requires_grad=True)
3570        if len(size_ind) > 0 and len(size_x) > 0:
3571            ind = torch.randint(x.size(dim), size_ind)
3572        else:
3573            ind = torch.zeros(size_ind, dtype=torch.int64)
3574        out = torch.gather(x, dim, ind, sparse_grad=False)
3575        grad = torch.rand_like(out)
3576        out.backward(grad)
3577        grad_dense = x.grad.clone()
3578        x.grad = None
3579        out = torch.gather(x, dim, ind, sparse_grad=True)
3580        out.backward(grad)
3581        self.assertEqual(grad_dense, x.grad.to_dense())
3582
3583    def test_sparse_gather_dim0(self):
3584        self._test_sparse_gather((10, 10), (5, 10), 0)
3585
3586    def test_sparse_gather_dim1(self):
3587        self._test_sparse_gather((10, 10, 5), (10, 5, 5), 1)
3588
3589    def test_sparse_gather_dim_neg(self):
3590        self._test_sparse_gather((10, 10, 5), (10, 10, 2), -1)
3591
3592    def test_sparse_gather_ind_scalar(self):
3593        self._test_sparse_gather((10,), (), 0)
3594
3595    def test_sparse_gather_x_scalar(self):
3596        self._test_sparse_gather((), (2,), 0)
3597
3598    def test_sparse_gather_both_scalar(self):
3599        self._test_sparse_gather((), (), 0)
3600
3601    def test_gc_in_destructor(self):
3602        """
3603        Previously, if a Function destructor triggered a garbage collection,
3604        the Variable's tp_dealloc handler would get called twice leading to a
3605        segfault.
3606        """
3607
3608        class CollectOnDelete(Function):
3609            def forward(self, x):
3610                return x
3611
3612            def backward(self, grad_output):
3613                return grad_output
3614
3615            def __del__(self):
3616                gc.collect()
3617
3618        for _ in range(10):
3619            CollectOnDelete().forward(torch.randn(1, requires_grad=True)).backward()
3620
3621    def test_naughty_autograd_function_attribute_access(self):
3622        class Id(Function):
3623            @staticmethod
3624            def forward(ctx, x):
3625                return x
3626
3627            @staticmethod
3628            def backward(ctx, grad_x):
3629                return grad_x
3630
3631        with self.assertWarnsRegex(DeprecationWarning, "should not be instantiated"):
3632            f = Id()
3633
3634        # After raising warning, should still return an instance
3635        self.assertIsInstance(f, Id)
3636        x = torch.zeros(1, requires_grad=True)
3637        with self.assertRaisesRegex(
3638            RuntimeError, "non-static forward method is deprecated"
3639        ):
3640            f(x)
3641        t = Id.apply(x)
3642        self.assertEqual(t.grad_fn.name(), "IdBackward")
3643
3644        # THPFunction is the base class of both grad_fn and autograd functions,
3645        # which means that a lot of accessors on them may segfault. Test that we
3646        # properly error in this case.
3647        t = torch.ones(1, requires_grad=True)
3648        t._backward_hooks = {}
3649        with self.assertRaisesRegex(
3650            RuntimeError, "Attribute '_register_hook_dict' is invalid"
3651        ):
3652            f._register_hook_dict(t)
3653        with self.assertRaisesRegex(
3654            RuntimeError, "Attribute 'register_hook' is invalid"
3655        ):
3656            f.register_hook(lambda x, y: None)
3657        with self.assertRaisesRegex(
3658            RuntimeError, "Attribute 'next_functions' is invalid"
3659        ):
3660            f.next_functions
3661        with self.assertRaisesRegex(RuntimeError, "Attribute 'name' is invalid"):
3662            f.name()
3663        with self.assertRaisesRegex(
3664            RuntimeError, "underlying PyNode has already been deallocated"
3665        ):
3666            f.metadata
3667
3668    @unittest.expectedFailure
3669    def test_naughty_anomaly_access(self):
3670        class MyFunction(Function):
3671            @staticmethod
3672            def forward(ctx, x):
3673                return x
3674
3675            @staticmethod
3676            def backward(ctx, g):
3677                return g
3678
3679        x = torch.zeros(1, requires_grad=True)
3680        y = MyFunction.apply(x)
3681        y.backward()
3682        y.grad_fn.metadata
3683        g = y.grad_fn
3684        del y
3685        g.metadata  # this currently fails, but shouldn't
3686
3687    def test_naughty_autograd_function_stashing_ctx(self):
3688        saved_ctx = []
3689
3690        class Id(Function):
3691            @staticmethod
3692            def forward(ctx, x):
3693                ctx.save_for_backward(x)
3694                return x
3695
3696            @staticmethod
3697            def backward(ctx, grad_x):
3698                saved_ctx.append(ctx)
3699                return ctx.saved_tensors
3700
3701        p = torch.zeros(1, requires_grad=True)
3702        loss = Id.apply(p)
3703        loss.backward(retain_graph=True)
3704        del loss
3705        # At this point in time, it complains that the graph has been freed
3706        # (which indeed true, although a somewhat indirect way of stating the
3707        # problem).
3708        self.assertRaises(RuntimeError, lambda: saved_ctx[0].saved_tensors)
3709
3710    def test_custom_autograd_repeated_grad_grad(self):
3711        # This test failed the equality check in PR #22983; it's an interesting
3712        # and different test case worth enshrining.  mult1 is not testing
3713        # anything that interesting, but mult2 is the interesting case.
3714
3715        def mult1(x):
3716            return x.prod(dim=-1).prod(dim=-1)
3717
3718        class Mult(torch.autograd.Function):
3719            @staticmethod
3720            def forward(ctx, x):
3721                y = mult1(x)
3722                ctx.save_for_backward(x, y)
3723                return y
3724
3725            @staticmethod
3726            def backward(ctx, grad_output):
3727                x, y = ctx.saved_tensors
3728                return (grad_output * y)[:, None, None] / x
3729
3730        mult2 = Mult.apply
3731
3732        def check_gradgrad_repeated(x, y):
3733            (gy,) = torch.autograd.grad(y[0], x, create_graph=True)
3734            (ggy_1,) = torch.autograd.grad(gy[0, 0, 0], x, retain_graph=True)
3735            (gy,) = torch.autograd.grad(y[0], x, create_graph=True)
3736            (ggy_2,) = torch.autograd.grad(gy[0, 0, 0], x, retain_graph=True)
3737            self.assertEqual(ggy_1[0, 0, 1], ggy_2[0, 0, 1])
3738
3739        x = torch.ones(2, 4, 4).requires_grad_()
3740        check_gradgrad_repeated(x, mult1(x))
3741        check_gradgrad_repeated(x, mult2(x))
3742
3743    def test_custom_autograd_no_early_free(self):
3744        # This test failed complaining that buffers had already been freed
3745        # prior to #22983.  Also pretty interesting test case.
3746        class Double(torch.autograd.Function):
3747            @staticmethod
3748            def forward(ctx, x):
3749                y = x**2
3750                ctx.save_for_backward(x, y)
3751                return y
3752
3753            @staticmethod
3754            def backward(ctx, grad_output):
3755                x, _ = ctx.saved_tensors
3756                return grad_output * 2 * x
3757
3758        # this is equivalent, but uses the output of .forward() in .backward()
3759        class Double2(Double):
3760            @staticmethod
3761            def backward(ctx, grad_output):
3762                x, y = ctx.saved_tensors
3763                return grad_output * 2 * y / x
3764
3765        double = Double.apply
3766        double2 = Double2.apply
3767
3768        x = torch.tensor(2).double().requires_grad_()
3769
3770        self.assertTrue(gradcheck(double, x))
3771        self.assertTrue(gradgradcheck(double, x))
3772        self.assertTrue(gradcheck(double2, x))
3773        self.assertTrue(gradgradcheck(double2, x))
3774
3775        y = double(x)
3776        torch.autograd.grad(y, x, create_graph=True)
3777        torch.autograd.grad(y, x)
3778
3779        y = double2(x)
3780        torch.autograd.grad(y, x, create_graph=True)
3781        torch.autograd.grad(y, x)  # should not error!
3782
3783    def test_detach(self):
3784        x = torch.randn(10, 10, requires_grad=True)
3785        y = x + 2
3786        y = y.detach()
3787        z = y * 4 + 2
3788        self.assertFalse(y.requires_grad)
3789        self.assertFalse(z.requires_grad)
3790
3791        x = torch.randn(10, 10, requires_grad=True)
3792        y = x * 2
3793        y = y.detach()
3794        self.assertFalse(y.requires_grad)
3795        self.assertIsNone(y.grad_fn)
3796        z = x + y
3797        z.sum().backward()
3798        # This is an incorrect gradient, but we assume that's what the user
3799        # wanted. detach() is an advanced option.
3800        self.assertEqual(x.grad, torch.ones(10, 10))
3801
3802        # in-place detach
3803        x = torch.randn(10, 10, requires_grad=True)
3804        y = torch.randn(10, 10, requires_grad=True)
3805        a = x * 2
3806        (y + a).sum().backward(retain_graph=True)
3807        a.detach_()
3808        self.assertFalse(a.requires_grad)
3809        (y + a).sum().backward()  # this won't backprop to x
3810        self.assertEqual(x.grad, torch.ones(10, 10) * 2)
3811        self.assertEqual(y.grad, torch.ones(10, 10) * 2)
3812
3813        # in-place detach on a view raises an exception
3814        view = x.narrow(0, 1, 4)
3815        self.assertRaisesRegex(RuntimeError, "view", lambda: view.detach_())
3816
3817    def test_detach_base(self):
3818        "detaching base does not detach view"
3819        x = torch.randn(10, 10, requires_grad=True)
3820        view = x.narrow(0, 1, 4)
3821        x.detach_()
3822        self.assertFalse(x.requires_grad)
3823        self.assertTrue(view.requires_grad)
3824        self.assertIsNotNone(view.grad_fn)
3825        self.assertIs(view._base, x)
3826
3827    def test_detach_then_inplace_raises_in_autograd(self):
3828        x = torch.randn([], requires_grad=True)
3829        orig_x = x.detach().clone()
3830
3831        y = x**2  # saves x
3832        z = x.detach()
3833        z.zero_()
3834        with self.assertRaisesRegex(RuntimeError, "has been modified by an inplace"):
3835            y.backward()
3836
3837    def _test_type_conversion_backward(self, t):
3838        fvar = Variable(t(torch.randn(5, 5).float()), requires_grad=True)
3839        fvar.double().sum().backward()
3840        self.assertEqual(fvar.grad, torch.ones_like(fvar))
3841        self.assertEqual(type(fvar.grad), type(fvar))
3842        dvar = Variable(t(torch.randn(5, 5).double()), requires_grad=True)
3843        dvar.float().sum().backward()
3844        self.assertEqual(dvar.grad, torch.ones_like(dvar))
3845        self.assertEqual(type(dvar.grad), type(dvar))
3846
3847    def test_type_conversions(self):
3848        x = torch.randn(5, 5)
3849        self.assertIsInstance(x.float(), torch.FloatTensor)
3850        self.assertIsInstance(x.int(), torch.IntTensor)
3851        if torch.cuda.is_available():
3852            self.assertIsInstance(x.float().cuda(), torch.cuda.FloatTensor)
3853            self.assertIsInstance(x.int().cuda(), torch.cuda.IntTensor)
3854            self.assertIsInstance(x.int().cuda().cpu(), torch.IntTensor)
3855            if torch.cuda.device_count() >= 2:
3856                x2 = x.float().cuda(1)
3857                self.assertIsInstance(x2, torch.cuda.FloatTensor)
3858                self.assertIs(x2.get_device(), 1)
3859                x2 = x.float().cuda()
3860                self.assertIsInstance(x2, torch.cuda.FloatTensor)
3861                self.assertIs(x2.get_device(), 0)
3862                x2 = x2.cuda(1)
3863                self.assertIsInstance(x2, torch.cuda.FloatTensor)
3864                self.assertIs(x2.get_device(), 1)
3865                y = Variable(torch.randn(5).cuda(1), requires_grad=True)
3866                y.cpu().sum().backward()
3867                self.assertIs(y.grad.get_device(), 1)
3868                self.assertIs(y.long().get_device(), 1)
3869
3870        for t in [
3871            torch.DoubleTensor,
3872            torch.FloatTensor,
3873            torch.IntTensor,
3874            torch.ByteTensor,
3875        ]:
3876            for y_var in (True, False):
3877                y = torch.randint(5, (5, 5), dtype=t.dtype)
3878                y = Variable(y) if y_var else y
3879                self.assertIsInstance(x.type(t), t)
3880                self.assertIsInstance(x.type_as(y), t)
3881                # TODO: t.dtype should work
3882                t_dtype = t().dtype
3883                self.assertIsInstance(x.type(t_dtype), t)
3884                self.assertIs(t_dtype, x.type(t_dtype).dtype)
3885                self.assertEqual(y.data_ptr(), y.type(t).data_ptr())
3886                if torch.cuda.is_available():
3887                    for x_cuda in (True, False):
3888                        for y_cuda in (True, False):
3889                            x_c = x.cuda() if x_cuda else x
3890                            y_c = y.cuda() if y_cuda else y
3891                            _, y_type = y_c.type().rsplit(".", 1)
3892                            y_typestr = ("torch.cuda." if y_cuda else "torch.") + y_type
3893                            self.assertEqual(y_c.type(), x_c.type(y_typestr).type())
3894                            self.assertIs(y_c.dtype, x_c.type(y_c.dtype).dtype)
3895                            self.assertEqual(
3896                                y_c.data_ptr(),
3897                                y_c.cuda().data_ptr() if y_cuda else y_c.data_ptr(),
3898                            )
3899
3900        self._test_type_conversion_backward(lambda x: x)
3901        if torch.cuda.is_available():
3902            self._test_type_conversion_backward(lambda x: x.cuda())
3903            if torch.cuda.device_count() >= 2:
3904                # one of these has to be the non-default device
3905                self._test_type_conversion_backward(lambda x: x.cuda(0))
3906                self._test_type_conversion_backward(lambda x: x.cuda(1))
3907
3908    def test_isolated_node(self):
3909        x = torch.randn(5, 5, requires_grad=True)
3910        y = torch.randn(5, 5, requires_grad=True)
3911
3912        a = x + y
3913        b = torch.max(a, 1, True)[1].repeat(1, 5).double()
3914        o = (b + a).sum()
3915        o.backward()
3916
3917    def test_shape(self):
3918        x = torch.randn(3, 4)
3919        self.assertEqual(2, len(x.shape))
3920        self.assertEqual(x.shape[0], 3)
3921        self.assertEqual(x.shape[1], 4)
3922
3923    def test_numpy_requires_grad(self):
3924        x = torch.randn(2, 2, requires_grad=True)
3925        err_msg_outputs = r"Can't call numpy\(\) on Tensor that requires grad. Use tensor.detach\(\).numpy\(\) instead."
3926        with self.assertRaisesRegex(RuntimeError, err_msg_outputs):
3927            x.numpy()
3928
3929        with torch.no_grad():
3930            x.numpy()
3931
3932        x = torch.randn(2, 2)
3933        x.numpy()
3934
3935        with torch.no_grad():
3936            x.numpy()
3937
3938    def test_return_leaf(self):
3939        class Identity(Function):
3940            @staticmethod
3941            def forward(ctx, a, b):
3942                return a, a + b
3943
3944            @staticmethod
3945            def backward(ctx, grad_a, grad_b):
3946                return grad_a + grad_b, grad_b
3947
3948        hook_called = [False]
3949        x = torch.randn(5, 5, requires_grad=True)
3950        y = torch.randn(5, 5, requires_grad=True)
3951
3952        q, p = Identity.apply(x, y)
3953
3954        # Make sure hooks only receive grad from usage of q, not x.
3955        def hook(grad):
3956            hook_called[0] = True
3957            self.assertEqual(grad, torch.ones(5, 5))
3958
3959        q.register_hook(hook)
3960        (q + p + x).sum().backward()
3961        self.assertEqual(x.grad, torch.ones(5, 5) * 3)
3962        self.assertEqual(y.grad, torch.ones(5, 5))
3963        self.assertTrue(hook_called[0])
3964
3965    def test_return_leaf_inplace(self):
3966        class Inplace(InplaceFunction):
3967            @staticmethod
3968            def forward(ctx, a, b):
3969                ctx.mark_dirty(a)
3970                return a.add_(b), b + 2
3971
3972            @staticmethod
3973            def backward(ctx, grad_a, grad_b):
3974                return grad_a, grad_a + grad_b
3975
3976        x = torch.randn(5, 5)
3977        y = torch.randn(5, 5, requires_grad=True)
3978
3979        q, p = Inplace.apply(x, y)
3980        self.assertIs(q, x)
3981        self.assertIs(q.grad_fn.__class__, Inplace._backward_cls)
3982        self.assertTrue(q.requires_grad)
3983        q.sum().backward()
3984        self.assertEqual(y.grad, torch.ones(5, 5))
3985
3986    def test_leaf_assignment(self):
3987        x = torch.randn(5, 5)
3988        y = torch.randn(5, requires_grad=True)
3989        z = torch.randn(5, requires_grad=True)
3990
3991        x[0] = y
3992        x[1] = 2 * z
3993        self.assertTrue(x.requires_grad)
3994        self.assertIsNot(x.grad_fn, None)
3995        x.sum().backward()
3996        self.assertEqual(y.grad, torch.ones(5))
3997        self.assertEqual(z.grad, torch.ones(5) * 2)
3998
3999    def test_no_grad_assignment(self):
4000        x = torch.randn(5, 5, requires_grad=True)
4001        y = torch.randn(5)
4002        with torch.no_grad():
4003            x[0] = y
4004
4005        self.assertTrue(x.requires_grad)
4006        self.assertIsNone(x.grad_fn)
4007
4008    def test_no_grad_modifies_version(self):
4009        x = torch.randn(5, requires_grad=True)
4010        y = torch.randn(5, requires_grad=True)
4011        z = (x * y).sum()
4012        with torch.no_grad():
4013            x *= 2
4014        self.assertRaisesRegex(
4015            RuntimeError, "modified by an inplace operation", lambda: z.backward()
4016        )
4017
4018    def test_increment_version(self):
4019        a = torch.rand(5, requires_grad=True)
4020        v = a._version
4021        torch.autograd.graph.increment_version(a)
4022        self.assertEqual(a._version, v + 1)
4023
4024        a = torch.zeros(5, dtype=torch.int)
4025        v = a._version
4026        torch.autograd.graph.increment_version(a)
4027        self.assertEqual(a._version, v + 1)
4028
4029        with torch.inference_mode():
4030            a = torch.rand(5, requires_grad=True)
4031            # does not error
4032            torch.autograd.graph.increment_version(a)
4033
4034        # does not error
4035        torch.autograd.graph.increment_version(a)
4036
4037    def test_no_grad_input(self):
4038        class MyFunction(Function):
4039            @staticmethod
4040            def forward(self, x):
4041                return x
4042
4043            @staticmethod
4044            def backward(self, grad_output):
4045                return grad_output
4046
4047        x = torch.randn(5, requires_grad=True)
4048        with torch.no_grad():
4049            y = MyFunction.apply(x)
4050
4051        self.assertTrue(x.requires_grad)
4052        self.assertIsNone(y.grad_fn)
4053
4054    def test_backward_copy(self):
4055        # This tests checks backward engine for a very subtle bug that appreared
4056        # in one of the initial versions of autograd. Gradients tensors were
4057        # simply stored in lists while the function waited for all its gradients
4058        # to be computed. However, sometimes an output was used multiple times,
4059        # so the gradients needed to be summed. Engine used to keep a need_copy
4060        # set of tensors that will need a clone upon next addition and removed
4061        # them from the set as soon as the clone was performed. However, this
4062        # could lead to incorrect results if the same gradient tensor was
4063        # buffered in three places in the graph:
4064        # 1. When accumulating gradients in one of these places it was cloned
4065        #    and removed from need_copy set.
4066        # 2. When accumulating in second place, it wasn't in the need_copy set,
4067        #    so the gradients were simply accumulated in-place (which already
4068        #    modified the grad in 3rd place)
4069        # 3. When accumulating in the third place, it wasn't in the need_copy set
4070        #    as well, so the incoming gradient was summed in-place, yielding
4071        #    incorrect results in all functions, except the first one.
4072        x = torch.ones(5, 5, requires_grad=True)
4073        y = torch.ones(5, 5, requires_grad=True)
4074        # Simulate that we're in the middle of the graph
4075        a = x + 2
4076        b = y + 2
4077        c = x + 2
4078        # This op will just return grad_output two times in backward
4079        add1 = a + b
4080        add2 = add1 + c
4081        # Simulate a long branch, so grad_output will get buffered.
4082        for _ in range(4):
4083            a = a * 2
4084            b = b * 2
4085            c = c * 2
4086        branch = a + b + c
4087        out = add2 + branch
4088        # expected gradients are:
4089        # for x: 34 (16 from final a, 16 from final c, 2 from add2)
4090        # for y: 17 (16 from final b, 1 from add2)
4091        grad_output = torch.ones(5, 5)
4092        out.backward(grad_output)
4093        self.assertEqual(x.grad, torch.ones(5, 5) * 34)
4094        self.assertEqual(y.grad, torch.ones(5, 5) * 17)
4095
4096    def test_save_none_for_backward(self):
4097        test_case = self
4098
4099        class MyFn(Function):
4100            @staticmethod
4101            def forward(ctx, input):
4102                ctx.save_for_backward(None, input, None)
4103                return input * input
4104
4105            @staticmethod
4106            def backward(ctx, grad_output):
4107                n1, input, n2 = ctx.saved_tensors
4108                test_case.assertIsNone(n1)
4109                test_case.assertIsNone(n2)
4110                return 2 * input * grad_output
4111
4112        x = torch.randn(5, 5, requires_grad=True)
4113        y = MyFn.apply(x)
4114        y.sum().backward()
4115        self.assertEqual(x.grad, 2 * x)
4116
4117    def test_too_many_grads(self):
4118        class MyFn(Function):
4119            @staticmethod
4120            def forward(ctx, input):
4121                return input
4122
4123            @staticmethod
4124            def backward(ctx, grad_output):
4125                return grad_output, None, None
4126
4127        x = torch.randn(5, 5, requires_grad=True)
4128        y = MyFn.apply(x)
4129        y.sum().backward()
4130        self.assertEqual(x.grad, torch.ones_like(x))
4131
4132    def test_pickle(self):
4133        x = torch.randn(10, 10, requires_grad=True)
4134        y = torch.randn(10, 10, requires_grad=False)
4135
4136        def assert_strict_equal(var1, var2):
4137            self.assertEqual(var1, var2)
4138            self.assertEqual(var1.requires_grad, var2.requires_grad)
4139
4140        serialized = [pickle.dumps([x, y], protocol=p) for p in range(3)]
4141        for dump in serialized:
4142            xc, yc = pickle.loads(dump)
4143            assert_strict_equal(xc, x)
4144            assert_strict_equal(yc, y)
4145
4146    @skipIfTorchDynamo("compile tested in test/dynamo/test_autograd_function.py")
4147    def test_dep_nograd(self):
4148        class F1(Function):
4149            @staticmethod
4150            def forward(ctx, input):
4151                out = torch.randn(input.size())
4152                ctx.mark_non_differentiable(out)
4153                return input, out
4154
4155            @staticmethod
4156            def backward(ctx, grad_output, ignored):
4157                return grad_output
4158
4159        class F2(Function):
4160            @staticmethod
4161            def forward(ctx, input, ignored):
4162                return input
4163
4164            @staticmethod
4165            def backward(ctx, grad_output):
4166                return grad_output, None
4167
4168        x = torch.randn(5, requires_grad=True)
4169        a, b = F1.apply(x)
4170        b = b + 1  # separate F1 from F2 by another op
4171        self.assertTrue(a.requires_grad)
4172        self.assertFalse(b.requires_grad)
4173        c = F2.apply(a, b)
4174        c.backward(torch.ones(c.size()))
4175        self.assertEqual(x.grad, torch.ones(x.size()))
4176
4177    def test_set_grad_enabled(self):
4178        x = torch.tensor([1.0], requires_grad=True)
4179        with torch.set_grad_enabled(False):
4180            y = x * 2
4181        self.assertFalse(y.requires_grad)
4182        with torch.set_grad_enabled(True):
4183            y = x * 2
4184        self.assertTrue(y.requires_grad)
4185        with torch.set_grad_enabled(False):
4186            torch.set_grad_enabled(True)
4187            y = x * 2
4188        self.assertTrue(y.requires_grad)
4189
4190    def test_set_grad_enabled_wraps(self):
4191        for decorator in [True, False]:
4192            with torch.enable_grad():
4193                self.assertTrue(torch.is_grad_enabled())
4194
4195                if decorator:
4196                    # This should not mutate the global grad mode!
4197                    @torch.set_grad_enabled(False)
4198                    def inner_func(x):
4199                        return x.sin()
4200
4201                else:
4202
4203                    def inner_func(x):
4204                        return x.sin()
4205
4206                    # This is non-idiomatic usage!
4207                    # More idiomatic usage: torch.set_grad_enabled(False)(inner_func)
4208                    obj = torch.set_grad_enabled(False)
4209                    self.assertTrue(not torch.is_grad_enabled())
4210
4211                    # this will consume the set_grad_enabled global mutation!
4212                    inner_func = obj(inner_func)
4213                    self.assertTrue(torch.is_grad_enabled())
4214
4215                self.assertTrue(torch.is_grad_enabled())
4216
4217                x = torch.zeros(1, requires_grad=True)
4218                self.assertTrue(not inner_func(x).requires_grad)
4219
4220    def test_simple_reentrant(self):
4221        y_data = torch.randn(2, 2)
4222
4223        class Reenter(Function):
4224            @staticmethod
4225            def forward(ctx, x):
4226                with torch.enable_grad():
4227                    ctx.x = Variable(x, requires_grad=True)
4228                    ctx.y = Variable(y_data, requires_grad=True)
4229                    ctx.output_var = ctx.x * ctx.y
4230                return ctx.output_var.detach()
4231
4232            @staticmethod
4233            def backward(ctx, grad_output):
4234                with torch.enable_grad():
4235                    ctx.output_var.sum().backward()
4236                return ctx.x.grad * grad_output
4237
4238        # Reentrant starts on CPU thread, finishs on GPU thread
4239        x = torch.randn(2, 2, requires_grad=True)
4240        out = Reenter.apply(x)
4241        out.sum().backward()
4242        self.assertEqual(x.grad, y_data)
4243
4244    def test_reentrant_child_error(self):
4245        # Parent graph.
4246        a = torch.rand(3, 3, requires_grad=True)
4247        c = a * a
4248
4249        # Reentrant child graph.
4250        b = torch.rand(3, 3, requires_grad=True)
4251        e = b * b
4252        f = TestAutograd.SimulateBackwardError.apply(e)
4253        reentrant_root = f.sum()
4254
4255        class ReentrantFunc(Function):
4256            @staticmethod
4257            def forward(ctx, inp):
4258                return inp.clone()
4259
4260            @staticmethod
4261            def backward(ctx, grad):
4262                # Reentrant backward in child will throw an error.
4263                reentrant_root.backward()
4264                return grad
4265
4266        d = ReentrantFunc.apply(c)
4267        with self.assertRaisesRegex(Exception, "Simulate error"):
4268            d.sum().backward()
4269
4270    def test_var_mean_differentiable(self):
4271        dim = [2, 4]
4272        keepdim = False
4273        input1 = torch.randn(3, 4, 5, 6, 2, 3, requires_grad=True)
4274        input2 = deepcopy(input1)
4275        var1, mean1 = torch.var_mean(input1, dim=dim, keepdim=keepdim)
4276        var2 = input2.var(dim=dim, keepdim=keepdim)
4277        mean2 = input2.mean(dim=dim, keepdim=keepdim)
4278        grad = torch.randn(3, 4, 6, 3, requires_grad=True)
4279
4280        r1 = var1 * var1 * mean1 * mean1
4281        r2 = var2 * var2 * mean2 * mean2
4282        self.assertEqual(r1, r2, rtol=0.01, atol=0.0)
4283
4284        torch.autograd.backward(r1, grad)
4285        torch.autograd.backward(r2, grad)
4286        self.assertEqual(input1.grad, input2.grad, rtol=0.01, atol=0.0)
4287
4288    @skipIfNoLapack
4289    def test_lobpcg(self):
4290        def func(k, A, largest=True, B=None):
4291            X_shape = list(A.shape)
4292            X_shape[-1] = k
4293            X = torch.eye(A.size(-2), k, dtype=A.dtype, device=A.device)
4294            if A.dim() > 2:
4295                X = X.expand(X_shape)
4296
4297            D, U = torch.lobpcg(A=A, k=k, B=B, X=X, largest=largest)
4298
4299            # LOBPCG uses a random initial eigenspace approximation
4300            # if parameter `X` is not provided.
4301            # This may cause a non-deterministic behavior
4302            # when it comes to the sign of an eigenvector
4303            # (note if v is an eigenvector, so is -v),
4304            # hence we eliminate this non-determinism
4305            # by making sure that each column of U
4306            # gets multiplied by the sign of its max (in absolute value) element.
4307            # Also, gradcheck changes the content of the input by +/- eps (default to 1e-06)
4308            # to compute the numerical gradient which can also cause the signs to flip.
4309            _, idx = U.abs().max(-2, keepdim=True)
4310            sign = U.gather(-2, idx).sign()
4311            U = U * sign
4312            return D, U
4313
4314        # TODO: review if this can be ported to OpInfos or moved to test_linalg.py
4315        def run_symeig_test(k, sizes, largest=True):
4316            A = torch.rand(*sizes).double()
4317            A = (A @ A.mT) / 10
4318            A.requires_grad_(True)
4319
4320            gradcheck(lambda A: func(k, A, largest), A, check_batched_grad=False)
4321
4322            # Custom gradient vectors for better stability due to some
4323            # non-determinism in the lobpcg's forward.
4324            # Note it is not required if symeig is in forward instead (tested).
4325            D_grad = torch.rand(*A.shape[:-2], k) / 100
4326            U_grad = torch.rand(*A.shape[:-1], k) / 100
4327            gradgradcheck(
4328                lambda A: func(k, A, largest),
4329                A,
4330                [D_grad, U_grad],
4331                atol=1e-4,
4332                check_batched_grad=False,
4333            )
4334
4335            # check whether A.grad is symmetric
4336            A = A.detach().requires_grad_(True)
4337            D, U = func(k, A, largest)
4338            (D.sum() + U.sum()).backward()
4339            self.assertEqual(A.grad, A.grad.mT)
4340
4341        for largest in [True, False]:
4342            run_symeig_test(1, (6, 6), largest=largest)
4343            run_symeig_test(1, (2, 6, 6), largest=largest)
4344            run_symeig_test(1, (2, 2, 6, 6), largest=largest)
4345            run_symeig_test(2, (6, 6), largest=largest)
4346            run_symeig_test(2, (2, 6, 6), largest=largest)
4347            run_symeig_test(2, (2, 2, 6, 6), largest=largest)
4348            run_symeig_test(3, (9, 9), largest=largest)
4349            run_symeig_test(3, (2, 9, 9), largest=largest)
4350            run_symeig_test(3, (2, 2, 9, 9), largest=largest)
4351
4352    def test_variable_traverse(self):
4353        def get_out_and_unrefed_cycle():
4354            inp = torch.randn(10, requires_grad=True)
4355            tmp = inp.view(10, 1)
4356            out = tmp.view(10)
4357
4358            # Create a reference cycle that contains an
4359            # intermediary Variable in the graph
4360            my_list = []
4361            my_list.append(tmp)
4362            my_list.append(my_list)
4363
4364            return out
4365
4366        out = get_out_and_unrefed_cycle()
4367        gc.collect()
4368        # This will segfault if things have been erroneously released
4369        out.backward(torch.randn(out.size()))
4370
4371    # TODO: review porting these to OpInfo tests
4372    def test_pow_zero_tensor_gradient(self):
4373        def run_test(input_size, exponent):
4374            input = torch.zeros(*input_size, requires_grad=True)
4375            input.pow(exponent).sum().backward()
4376            self.assertEqual(input.grad.abs().sum(), 0)
4377
4378        run_test((10,), torch.zeros(10))
4379        run_test((10, 10), torch.zeros(10, 10))
4380        run_test((10,), 0)
4381
4382    def test_current_graph_task_id(self):
4383        id = [-1]
4384
4385        def hook(_):
4386            id[0] = torch._C._current_graph_task_id()
4387
4388        t = torch.tensor(1.0, requires_grad=True).clone()
4389        t.register_hook(hook)
4390
4391        t.backward(retain_graph=True)
4392        base = id[0]
4393        t.backward(retain_graph=True)
4394        self.assertEqual(id[0] - base, 1)
4395        t.backward(retain_graph=True)
4396        self.assertEqual(id[0] - base, 2)
4397
4398        self.assertEqual(torch._C._current_graph_task_id(), -1)
4399
4400    def test_current_graph_task_execution_order(self):
4401        predicted = [None]
4402
4403        def hook(_):
4404            predicted[0] = torch._C._current_graph_task_execution_order()
4405
4406        def names(nodes):
4407            return ", ".join([node.name().split(" ")[-1] for node in nodes]) + "\n"
4408
4409        def grad_fns(*tensors):
4410            # or grad accumulator
4411            out = []
4412            for t in tensors:
4413                if t.requires_grad and t.grad_fn is None:
4414                    out.append(t.clone().grad_fn.next_functions[0][0])
4415                else:
4416                    out.append(t.grad_fn)
4417            return out
4418
4419        actual = []
4420
4421        def register_logging_hooks(*tensors):
4422            # register hooks that log the order in which they are called
4423            def get_hook(i):
4424                def hook(t_):
4425                    actual.append(tensors[i])
4426
4427                return hook
4428
4429            for i, t in enumerate(tensors):
4430                t.register_hook(get_hook(i))
4431
4432        # Basic example: single path
4433        t = torch.tensor(1.0, requires_grad=True).clone().sin().exp()
4434        t.register_hook(hook)
4435        with torch.autograd.set_multithreading_enabled(False):
4436            t.backward()
4437        self.assertExpectedInline(
4438            names(predicted[0]),
4439            """\
4440ExpBackward0, SinBackward0, CloneBackward0, torch::autograd::AccumulateGrad
4441""",
4442        )
4443
4444        # We don't exactly follow sequence_nr order
4445        a = torch.tensor(1.0, requires_grad=True)
4446        b = torch.tensor(2.0, requires_grad=True)
4447        c = b.sin()
4448        d = a.cos()
4449        out = c * d
4450        register_logging_hooks(a, b, c, d, out)
4451        out.register_hook(hook)
4452        with torch.autograd.set_multithreading_enabled(False):
4453            out.backward()
4454        self.assertEqual(predicted[0], grad_fns(*actual))
4455        actual = []
4456
4457        # Accumulate grad node has more than one input
4458        a = torch.tensor(1.0, requires_grad=True)
4459        b = a.sin()
4460        c = a.cos()
4461        out = b * c
4462        register_logging_hooks(a, b, c, out)
4463        out.register_hook(hook)
4464        with torch.autograd.set_multithreading_enabled(False):
4465            out.backward()
4466        self.assertEqual(predicted[0], grad_fns(*actual))
4467        actual = []
4468
4469        # Multiple roots are also OK
4470        a = torch.tensor(1.0, requires_grad=True)
4471        b = a * 2
4472        out = b.sin()
4473        out2 = b.cos()
4474        out3 = b.cos()
4475        register_logging_hooks(a, b, out, out2, out3)
4476        out3.register_hook(hook)
4477        with torch.autograd.set_multithreading_enabled(False):
4478            torch.autograd.grad((out, out3, out2), inputs=(a,))
4479        self.assertExpectedInline(
4480            names(predicted[0]),
4481            """\
4482CosBackward0, CosBackward0, SinBackward0, MulBackward0, torch::autograd::AccumulateGrad
4483""",
4484        )
4485        # TODO: Uncomment after update to hooks behavior
4486        # self.assertEqual(predicted[0], grad_fns(*actual))
4487        actual = []
4488
4489        # Case where next node is nullptr
4490        a = torch.tensor(1.0, requires_grad=True)
4491        b = a * 2
4492        out = b.sin()
4493        register_logging_hooks(a, b, out)
4494        out.register_hook(hook)
4495        with torch.autograd.set_multithreading_enabled(False):
4496            out.backward()
4497        self.assertEqual(predicted[0], grad_fns(*actual))
4498        actual = []
4499
4500        # Case where two `inputs` on the same path
4501        a = torch.tensor(1.0, requires_grad=True)
4502        b = a * 2
4503        out = b.sin()
4504        register_logging_hooks(a, b, out)
4505        out.register_hook(hook)
4506        with torch.autograd.set_multithreading_enabled(False):
4507            torch.autograd.grad((out,), inputs=(a, b))
4508        self.assertEqual(
4509            names(predicted[0]),
4510            """\
4511SinBackward0, MulBackward0, torch::autograd::AccumulateGrad
4512""",
4513        )
4514        # TODO: Uncomment after update to hooks behavior
4515        # self.assertEqual(predicted[0], grad_fns(*actual))
4516        actual = []
4517
4518        # Case where `inputs` specifies a subgraph
4519        a = torch.tensor(1.0, requires_grad=True)
4520        b = torch.tensor(1.0, requires_grad=True)
4521        c = a * b
4522        out = c.sin()
4523        register_logging_hooks(a, b, c, out)
4524        out.register_hook(hook)
4525        with torch.autograd.set_multithreading_enabled(False):
4526            torch.autograd.grad((out,), inputs=(a,))
4527        self.assertEqual(
4528            names(predicted[0]),
4529            """\
4530SinBackward0, MulBackward0, torch::autograd::AccumulateGrad
4531""",
4532        )
4533        # TODO: Uncomment after update to hooks behavior
4534        # self.assertEqual(predicted[0], grad_fns(*actual))
4535        actual = []
4536
4537        # Errors when not called in a backward
4538        with self.assertRaisesRegex(
4539            RuntimeError, "should only be called during the backward pass"
4540        ):
4541            torch._C._current_graph_task_execution_order()
4542
4543        # Errors when context manager not enabled
4544        t = torch.tensor(1.0, requires_grad=True).clone().sin().exp()
4545        t.register_hook(hook)
4546        with self.assertRaisesRegex(
4547            RuntimeError,
4548            "expects the current backward to be executed with multithreading disabled",
4549        ):
4550            t.backward()
4551
4552    def test_view_replay_enabled(self):
4553        def f(x):
4554            out = x.clone().view(-1)
4555            # mutate the view, triggering autograd view-replay logic
4556            out.add_(1)
4557            return out
4558
4559        x = torch.ones(2, 2, requires_grad=True)
4560
4561        # Test as a context manager
4562        with torch.autograd._force_original_view_tracking(False):
4563            out = f(x)
4564            self.assertTrue("AsStridedBackward" in str(out.grad_fn))
4565            self.assertFalse(torch.autograd.is_view_replay_enabled())
4566        self.assertFalse(torch.autograd.is_view_replay_enabled())
4567
4568        with torch.autograd._force_original_view_tracking(True):
4569            out = f(x)
4570            self.assertTrue("ViewBackward" in str(out.grad_fn))
4571            self.assertTrue(torch.autograd.is_view_replay_enabled())
4572        out = f(x)
4573        self.assertTrue("AsStridedBackward" in str(out.grad_fn))
4574        self.assertFalse(torch.autograd.is_view_replay_enabled())
4575
4576        with torch.autograd._force_original_view_tracking(False):
4577            torch.autograd._force_original_view_tracking(True)
4578            out = f(x)
4579            self.assertTrue("ViewBackward" in str(out.grad_fn))
4580            self.assertTrue(torch.autograd.is_view_replay_enabled())
4581        self.assertFalse(torch.autograd.is_view_replay_enabled())
4582
4583        # Test as a function
4584        torch.autograd._force_original_view_tracking(False)
4585        out = f(x)
4586        self.assertTrue("AsStridedBackward" in str(out.grad_fn))
4587        self.assertFalse(torch.autograd.is_view_replay_enabled())
4588
4589        torch.autograd._force_original_view_tracking(True)
4590        out = f(x)
4591        self.assertTrue("ViewBackward" in str(out.grad_fn))
4592        self.assertTrue(torch.autograd.is_view_replay_enabled())
4593
4594    def test_unsafe_set_version_counter(self):
4595        x = torch.ones(2, requires_grad=True).clone()
4596        x.add_(1)
4597        x.add_(2)
4598        self.assertEqual(2, x._version)
4599        with torch.autograd._unsafe_preserve_version_counter(x):
4600            x.mul_(2)
4601            x.mul_(3)
4602        # version counter doesn't change inside of the context manager
4603        self.assertEqual(2, x._version)
4604
4605        torch._C._autograd._unsafe_set_version_counter(x, 0)
4606        self.assertEqual(0, x._version)
4607        with self.assertRaisesRegex(RuntimeError, "Cannot set"):
4608            torch._C._autograd._unsafe_set_version_counter(x, -1)
4609
4610    def test_current_node(self):
4611        pr = []
4612
4613        class MyMode(TorchDispatchMode):
4614            def __torch_dispatch__(self, func, types, args, kwargs=None):
4615                node = torch._C._current_autograd_node()
4616                # Don't use node.name() here as it is not consistent on windows
4617                node_name = node.__class__.__name__ if node else "None"
4618                pr.append(f"Running {func} from within {node_name}")
4619                return func(*args, **(kwargs or {}))
4620
4621        with MyMode():
4622            pr.append("FW")
4623            a = torch.rand(10, requires_grad=True)
4624            b = a.mul(2).div(3).sum()
4625            pr.append("BW")
4626            b.backward()
4627            pr.append("Done")
4628
4629        self.assertExpectedInline(
4630            "\n".join(pr),
4631            """\
4632FW
4633Running aten.rand.default from within None
4634Running aten.mul.Tensor from within None
4635Running aten.div.Tensor from within None
4636Running aten.sum.default from within None
4637BW
4638Running aten.ones_like.default from within None
4639Running aten.expand.default from within SumBackward0
4640Running aten.div.Tensor from within DivBackward0
4641Running aten.mul.Tensor from within MulBackward0
4642Running aten.detach.default from within AccumulateGrad
4643Running aten.detach.default from within AccumulateGrad
4644Done""",
4645        )
4646
4647    def test_profiler(self):
4648        x = torch.randn(10, 10)
4649
4650        with profile(use_kineto=kineto_available()) as p:
4651            self.assertTrue(torch.autograd._profiler_enabled())
4652            y = x * 2 + 4
4653
4654        self.assertFalse(torch.autograd._profiler_enabled())
4655
4656        names = ["aten::mul", "aten::add"]
4657        found_indices = set()
4658        for evt in p.function_events:
4659            if evt.name in names:
4660                found_indices.add(names.index(evt.name))
4661        self.assertEqual(len(found_indices), len(names))
4662
4663    def test_profiler_seq_nr(self):
4664        with profile(use_kineto=kineto_available()) as p:
4665            x = torch.randn(10, 10, requires_grad=True)
4666            y = torch.randn(10, 10, requires_grad=True)
4667            z = x + y
4668            s = z.sum(dim=None)
4669            s.backward()
4670        print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=-1))
4671        # expecting aten::add, aten::sum to have the sequence numbers,
4672        # expecting the corresponding backward nodes to have the same numbers
4673        # as the forward ops
4674        autograd_ops = {
4675            ("aten::add", "Add"): [],
4676            ("aten::sum", "Sum"): [],
4677        }
4678        accumulate_ops = []
4679        found_empty = False
4680        for e in p.function_events:
4681            for (fwd_name, bwd_name), ops in autograd_ops.items():
4682                if e.name == fwd_name or (bwd_name in e.name and "Backward" in e.name):
4683                    ops.append(e)
4684
4685            if "AccumulateGrad" in e.name:
4686                accumulate_ops.append(e)
4687
4688            # check that nested ops (e.g. empty) don't have
4689            # sequence number
4690            if e.name == "aten::empty":
4691                self.assertEqual(e.sequence_nr, -1)
4692                found_empty = True
4693
4694        for idx, ((fwd_name, bwd_name), ops) in enumerate(autograd_ops.items()):
4695            self.assertEqual(len(ops), 3)
4696            self.assertEqual(ops[0].name, fwd_name)
4697            self.assertEqual(
4698                ops[1].name,
4699                f"autograd::engine::evaluate_function: {bwd_name}Backward{idx}",
4700            )
4701            self.assertEqual(ops[2].name, f"{bwd_name}Backward{idx}")
4702            self.assertGreaterEqual(ops[0].sequence_nr, 0)
4703            self.assertEqual(ops[1].sequence_nr, ops[0].sequence_nr)
4704            self.assertEqual(ops[2].sequence_nr, ops[0].sequence_nr)
4705            self.assertEqual(ops[0].fwd_thread, 0)
4706            self.assertEqual(ops[1].fwd_thread, ops[0].thread)
4707            self.assertEqual(ops[2].fwd_thread, ops[0].thread)
4708        self.assertTrue(found_empty)
4709
4710    def test_profiler_unboxed_only(self):
4711        x = torch.rand(3, 4)
4712
4713        with torch.autograd.profiler.profile(use_kineto=kineto_available()) as prof:
4714            x.resize_([3, 2])
4715
4716    def test_profiler_propagation(self):
4717        def foo(x):
4718            with record_function("in_foo") as rf:
4719                return x * 2
4720
4721        x = torch.rand(3, 4)
4722        traced_foo = torch.jit.trace(foo, x)
4723
4724        def bar(x):
4725            with record_function("in_bar") as rf:
4726                # we expect that profiler will be able
4727                # propagate across fork
4728                fut = torch.jit._fork(traced_foo, x)
4729                y = torch.jit._wait(fut)
4730                # note: continuation (and rf's end) can
4731                # be executed in a different thread
4732                with record_function("in_bar_after_wait") as rf2:
4733                    y = y * 2
4734                return y
4735
4736        traced_bar = torch.jit.trace(bar, x)
4737
4738        with profile(use_kineto=kineto_available()) as p:
4739            traced_bar(x)
4740
4741        found_foo = False
4742        found_bar = False
4743        found_bar_after_wait = False
4744        for info in p.function_events:
4745            if info.name == "in_foo":
4746                self.assertFalse(found_foo)
4747                found_foo = True
4748            elif info.name == "in_bar":
4749                self.assertFalse(found_bar)
4750                found_bar = True
4751            elif info.name == "in_bar_after_wait":
4752                self.assertFalse(found_bar_after_wait)
4753                found_bar_after_wait = True
4754        self.assertTrue(found_foo)
4755        self.assertTrue(found_bar)
4756        self.assertTrue(found_bar_after_wait)
4757
4758    def test_record_function_callbacks(self):
4759        x = torch.randn(10, 10)
4760        with profile(use_kineto=kineto_available()) as p:
4761            with record_function("foo"):
4762                y = x * 2 + 4
4763
4764        function_events = p.function_events
4765        foo_event = next(event for event in function_events if "foo" in event.name)
4766        self.assertEqual(foo_event.count, 1)
4767
4768    def test_record_function_legacy(self):
4769        # Test the new _record_function ops work
4770        # Note: Remove once record_function uses these directly
4771        x = torch.randn(10, 10)
4772        with profile(use_kineto=kineto_available()) as p:
4773            handle = torch.ops.profiler._record_function_enter("bar", None)
4774            try:
4775                y = x * 2 + 4
4776            finally:
4777                torch.ops.profiler._record_function_exit(handle)
4778
4779        function_events = p.function_events
4780        foo_event = next(event for event in function_events if "bar" in event.name)
4781        self.assertEqual(foo_event.count, 1)
4782
4783    def test_profiler_aggregation_fake(self):
4784        events = EventList()
4785        id = [0]
4786
4787        def get_id():
4788            id[0] = id[0] + 1
4789            return id[0]
4790
4791        # [[thread_id, [(start, end, id), ....]], ...]
4792        # Using list instead of a dict so order is guaranteed for any Python
4793        # version
4794        threads = [
4795            [1, [(0, 1, get_id()), (1, 2, get_id())]],
4796            [0, [(0, 2, get_id()), (1, 2, get_id()), (1, 3, get_id())]],
4797        ]
4798        for thread, ranges in threads:
4799            for range in ranges:
4800                assert len(range) == 3
4801                events.append(
4802                    FunctionEvent(
4803                        id=range[2],
4804                        node_id=0,
4805                        name="",
4806                        thread=thread,
4807                        start_us=range[0],
4808                        end_us=range[1],
4809                    )
4810                )
4811
4812        events._populate_cpu_children()
4813
4814        # Note that [1, 3] pushes out [0, 2] first. Then we record [1, 2]
4815        # as a child of [1, 3]
4816        res = [[], [], [], [], [4]]
4817
4818        def get_children_ids(event):
4819            return [child.id for child in event.cpu_children]
4820
4821        assert [get_children_ids(event) for event in events] == res
4822
4823    def test_profiler_aggregation_table(self):
4824        """
4825        Test if the profiling result is aggregated for `str(prof)`
4826
4827        See: https://github.com/pytorch/pytorch/issues/37500
4828        """
4829
4830        x = torch.randn(1024)
4831        with torch.autograd.profiler.profile(use_kineto=kineto_available()) as prof:
4832            torch.einsum("i->", x)
4833
4834        prof_str = str(prof)
4835        prof_table = prof.table()
4836
4837        self.assertEqual(prof_table, prof_str)
4838
4839    def test_profiler_function_event_avg(self):
4840        avg = FunctionEventAvg()
4841        avg.add(
4842            FunctionEvent(id=0, node_id=0, name="foo", thread=0, start_us=10, end_us=15)
4843        )
4844        avg.add(
4845            FunctionEvent(id=1, node_id=0, name="foo", thread=0, start_us=20, end_us=30)
4846        )
4847        avg.add(avg)
4848        self.assertEqual(avg.key, "foo")
4849
4850        # aggregate stats
4851        self.assertEqual(avg.count, 4)
4852        self.assertEqual(avg.cpu_time_total, 30)
4853        self.assertEqual(avg.self_cpu_time_total, 30)
4854        self.assertEqual(avg.device_time_total, 0)
4855
4856        # average stats
4857        self.assertEqual(avg.cpu_time, 7.5)
4858        self.assertEqual(avg.device_time_total, 0)
4859
4860    def test_profiler_shapes(self):
4861        print()
4862        layer1 = torch.nn.Linear(20, 30)
4863        layer2 = torch.nn.Linear(30, 40)
4864        input = torch.randn(128, 20)
4865        with profile(record_shapes=True, use_kineto=kineto_available()) as prof:
4866            layer2(layer1(input))
4867
4868        print(prof.function_events)
4869
4870        linear_expected_shapes = [
4871            [[128, 20], [30, 20], [30]],
4872            [[128, 30], [40, 30], [40]],
4873        ]
4874
4875        found_indices = set()
4876        for event in prof.function_events:
4877            if event.name == "aten::linear":
4878                self.assertTrue(event.input_shapes in linear_expected_shapes)
4879                found_indices.add(linear_expected_shapes.index(event.input_shapes))
4880        self.assertEqual(len(found_indices), len(linear_expected_shapes))
4881
4882    def test_profiler_aggregation_lstm(self):
4883        print()
4884        rnn = torch.nn.LSTM(10, 20, 2)
4885        total_time_s = 0
4886        with profile(record_shapes=True, use_kineto=kineto_available()) as prof:
4887            for i in range(20):
4888                input = torch.randn(5, 3, 10)
4889                h = torch.randn(2, 3, 20)
4890                c = torch.randn(2, 3, 20)
4891                start = time.time()
4892                rnn(input, (h, c))
4893                end = time.time()
4894                total_time_s += end - start
4895
4896        print(prof.table(sort_by="self_cpu_time_total", row_limit=10, header="TEST"))
4897        print(
4898            prof.key_averages(group_by_input_shape=True).table(
4899                sort_by="self_cpu_time_total", row_limit=10
4900            )
4901        )
4902        print(
4903            prof.table(
4904                sort_by="self_cpu_time_total",
4905                row_limit=10,
4906                max_src_column_width=300,
4907                header="TEST",
4908                top_level_events_only=True,
4909            )
4910        )
4911        print(
4912            prof.key_averages(group_by_input_shape=True).table(
4913                sort_by="self_cpu_time_total", row_limit=10, top_level_events_only=True
4914            )
4915        )
4916
4917        total_time_us = (
4918            total_time_s * 1000.0 * 1000.0
4919        )  # make it us which is profiler default
4920        print("Total time based on python measurements: ", _format_time(total_time_us))
4921        print(
4922            f"CPU time measurement python side overhead: {(total_time_us / prof.self_cpu_time_total - 1.0) * 100.0:.2f}%"
4923        )
4924
4925        if sys.platform != "win32":
4926            with tempfile.NamedTemporaryFile() as trace_file:
4927                prof.export_chrome_trace(trace_file.name)
4928
4929    def test_record_function(self):
4930        x = torch.randn(10, 10)
4931
4932        def forward(x):
4933            with record_function("outer"):
4934                y = x * 2 + 4
4935                with record_function("inner"):
4936                    y = y - 1
4937            y = y / 1
4938
4939        forward(x)
4940
4941        with profile(use_kineto=kineto_available()) as p:
4942            forward(x)
4943
4944        events = p.function_events
4945        important_events = [
4946            "outer",
4947            "aten::mul",
4948            "aten::add",
4949            "inner",
4950            "aten::sub",
4951            "aten::div",
4952        ]
4953        idx = 0
4954        for info in events:
4955            if info.name == important_events[idx]:
4956                idx = idx + 1
4957            if idx == len(important_events):
4958                break
4959        self.assertEqual(idx, len(important_events))
4960
4961        # We can also use record_function to decorate arbitrary function
4962        @record_function("my_func")
4963        def f(x, y):
4964            return x + y
4965
4966        with profile(use_kineto=kineto_available()) as p:
4967            f(1, 2)
4968
4969        self.assertTrue("my_func" in str(p))
4970
4971    def test_record_function_multithreaded(self):
4972        rf = record_function("outer")
4973        rf.__enter__()
4974        with record_function("inner"):
4975            # test that exiting the record function after starting another one
4976            # doesn't throw.
4977            rf.__exit__(None, None, None)
4978
4979        with record_function("inner"):
4980            rf.__enter__()
4981        # test that exiting the record function after ending another one
4982        # doesn't throw.
4983        rf.__exit__(None, None, None)
4984
4985    def test_dir(self):
4986        x = torch.randn(10, 10)
4987        keys = dir(x)
4988        self.assertIn("shape", keys)
4989
4990        # real and imag are only implemented for complex tensors.
4991        y = torch.randn(10, 10, dtype=torch.cfloat)
4992        imag_key = "imag"
4993        self.assertRaises(RuntimeError, lambda: hasattr(x, imag_key))
4994        self.assertTrue(hasattr(y, imag_key))
4995        keys.remove(imag_key)
4996
4997        for key in keys:
4998            self.assertTrue(hasattr(x, key))
4999
5000    def test_inplace_on_view_saved_output(self):
5001        # Test an in-place operation on a view in which the in-place op saves
5002        # its output. Previously, this created a reference cycle.
5003        dealloc = [0]
5004
5005        class IncrementOnDelete:
5006            def __del__(self):
5007                dealloc[0] += 1
5008
5009        def test():
5010            root = torch.randn(3, 3, requires_grad=True)
5011            copy = root.clone()
5012            copy.grad_fn.register_hook(IncrementOnDelete())
5013            view = copy.view(9)
5014            torch.nn.functional.relu(view, inplace=True)
5015
5016        test()
5017        self.assertEqual(dealloc[0], 1)
5018
5019    def test_inplace_on_view_leaf_errors(self):
5020        # Issue #21875: Fail faster (when we try to modify the view vs. in backward())
5021        x = torch.zeros(1, requires_grad=True)
5022        y = x.view_as(x)
5023        with self.assertRaisesRegex(
5024            RuntimeError,
5025            "a view of a leaf Variable that "
5026            "requires grad is being used in "
5027            "an in-place operation.",
5028        ):
5029            y.add_(1)
5030
5031    def test_inplace_on_view_backward(self):
5032        # Issue #10532: Make sure that this does not raise RuntimeError.
5033        net = nn.Sequential(nn.InstanceNorm2d(2), nn.ReLU(True))
5034
5035        x = torch.tensor([[[[1.0, 1.0]]]], requires_grad=True)
5036        (g,) = torch.autograd.grad(
5037            net(x).pow(2), [x], grad_outputs=x.new_ones(x.shape), create_graph=True
5038        )
5039        torch.autograd.grad(g.sum(), [x])
5040        self.assertEqual(x, torch.tensor([[[[1.0, 1.0]]]]))
5041
5042        # https://discuss.pytorch.org/t/freeing-buffer-strange-behavior/31955/8
5043        inputs = torch.ones((1, 3, 256, 256), requires_grad=True)
5044
5045        tmp1 = (inputs + 1).view_as(inputs)
5046        tmp2 = torch.nn.functional.threshold(tmp1, 0.0, 0.0, True)
5047        prob_interpolated = torch.sigmoid(tmp2)
5048
5049        gradients = torch.autograd.grad(
5050            outputs=prob_interpolated,
5051            inputs=inputs,
5052            grad_outputs=torch.ones(prob_interpolated.size()),
5053            create_graph=True,
5054            retain_graph=True,
5055        )[0]
5056
5057        gradient_penalty = gradients.sum()
5058        gradient_penalty.backward()
5059
5060        fn = gradient_penalty.grad_fn.next_functions[0][0].next_functions[1][0]
5061        self.assertEqual(fn.name(), "ThresholdBackwardBackward0")
5062
5063    def test_inplace_on_view_weak_grad_fn(self):
5064        # Issue 23502: Test that b's grad_fn is preserved.
5065        a = torch.arange(10.0, requires_grad=True)
5066
5067        b = a.narrow(0, 0, 2).clone().view(-1)
5068        b.relu_()
5069
5070        c = b.clone()
5071        del b
5072        gc.collect()
5073
5074        s = c.sum()
5075        s.backward()
5076        self.assertEqual(s, torch.tensor(1.0))
5077
5078        # Issue #21875: Fail faster (when we try to modify the view vs. in backward())
5079        a = torch.rand(10, requires_grad=True).narrow(0, 0, 10)
5080        with self.assertRaises(RuntimeError):
5081            b = a.relu_()
5082
5083    def test_out_variant_raises_when_inputs_require_grad(self):
5084        a = torch.randn(2, 2, requires_grad=True)
5085        b = torch.randn(2, 2, requires_grad=True)
5086        x = torch.zeros_like(a)
5087
5088        # out=... functions don't support automatic differentiation currently
5089        self.assertRaisesRegex(RuntimeError, "out=", lambda: torch.mul(a, b, out=x))
5090
5091        # the inputs can require grad if we're in no_grad() mode
5092        with torch.no_grad():
5093            torch.mul(a, b, out=x)
5094            self.assertEqual(x, a * b)
5095
5096        a = torch.randn(2, 2)
5097        b = torch.randn(2, 2)
5098        x = torch.zeros(2, 2, requires_grad=True)
5099        # we should throw an exception if the output requires grad
5100        self.assertRaisesRegex(RuntimeError, "out=", lambda: torch.mul(a, b, out=x))
5101
5102    def test_anomaly_detect_nan(self):
5103        size = 10
5104
5105        class MyFunc(Function):
5106            @staticmethod
5107            def forward(ctx, inp1, inp2, fail_0th):
5108                ctx.fail_0th = fail_0th
5109                return inp1.sum(0, keepdim=True)
5110
5111            @staticmethod
5112            def backward(ctx, gO):
5113                gI = gO.clone().expand(size)
5114                gI[0] = 0
5115                gI[0] /= 0  # Generate a nan
5116                if ctx.fail_0th:
5117                    return gI, None, None
5118                else:
5119                    return None, gI, None
5120
5121        inp = torch.rand(size, requires_grad=True)
5122        out = MyFunc.apply(inp, inp, True)
5123        out.backward()  # Should not fail
5124
5125        inp = torch.rand(size, requires_grad=True)
5126        out = MyFunc.apply(inp, inp, True)
5127        with self.assertRaisesRegex(
5128            RuntimeError,
5129            "Function 'MyFuncBackward' returned nan values in its 0th output.",
5130        ):
5131            with warnings.catch_warnings(record=True) as w:
5132                with detect_anomaly():
5133                    out.backward()
5134            self.assertIn("No forward pass information", str(w[0].message))
5135
5136        inp = torch.rand(size, requires_grad=True)
5137        with self.assertRaisesRegex(
5138            RuntimeError,
5139            "Function 'MyFuncBackward' returned nan values in its 1th output.",
5140        ):
5141            with warnings.catch_warnings(record=True) as w:
5142                with detect_anomaly():
5143                    out = MyFunc.apply(inp, inp, False)
5144                    out.backward()
5145            self.assertIn("MyFunc.apply", str(w[0].message))
5146
5147    def test_calculate_shape_util(self):
5148        out = torch.randn(10, 5, requires_grad=True)
5149        grad = torch.randn(5, 10, requires_grad=True)
5150        out_shape, grad_shape = _calculate_shape(out, grad, False)
5151
5152        assert out_shape == torch.Size([10, 5])
5153        assert grad_shape == torch.Size([5, 10])
5154
5155        out = torch.nested.as_nested_tensor(
5156            [
5157                torch.randn(10, 5, requires_grad=True),
5158                torch.randn(10, 5, requires_grad=True),
5159                torch.randn(10, 5, requires_grad=True),
5160            ]
5161        )
5162        grad = torch.nested.as_nested_tensor(
5163            [
5164                torch.randn(5, 10, requires_grad=True),
5165                torch.randn(5, 10, requires_grad=True),
5166            ]
5167        )
5168        out_shape, grad_shape = _calculate_shape(out, grad, False)
5169
5170        assert torch.equal(out_shape, torch.tensor([[10, 5], [10, 5], [10, 5]]))
5171        assert torch.equal(grad_shape, torch.tensor([[5, 10], [5, 10]]))
5172
5173    def test_nested_anomaly_detect_nan(self):
5174        size = 10
5175
5176        class MyFunc(Function):
5177            @staticmethod
5178            def forward(ctx, inp1, fail_0th):
5179                ctx.fail_0th = fail_0th
5180                ctx.save_for_backward(inp1)
5181                return inp1.sum(0, keepdim=True)
5182
5183            @staticmethod
5184            def backward(ctx, gO):
5185                (inp,) = ctx.saved_tensors
5186                fail_0th = ctx.fail_0th
5187                g = gO.clone().expand(size)
5188                gI = MyFunc2.apply(g * inp, g + inp, fail_0th)
5189                return gI, None
5190
5191        class MyFunc2(Function):
5192            @staticmethod
5193            def forward(ctx, inp1, inp2, fail_0th):
5194                ctx.fail_0th = fail_0th
5195                return inp1 * 2.0 + inp2
5196
5197            @staticmethod
5198            def backward(ctx, gO):
5199                fail_0th = ctx.fail_0th
5200                g1 = gO.clone()
5201                g2 = gO.clone()
5202                g1[0] = 0
5203                g2[0] = 0
5204                # generate a nan
5205                if fail_0th:
5206                    g1[0] /= 0
5207                else:
5208                    g2[0] /= 0
5209                return g1, g2, None
5210
5211        inp = torch.rand(size, requires_grad=True)
5212        out = MyFunc.apply(inp, True)
5213        (ginp,) = torch.autograd.grad(out, (inp,), create_graph=True)
5214        gsum = ginp.sum()
5215        gsum.backward()  # should not fail
5216
5217        inp = torch.rand(size, requires_grad=True)
5218        out = MyFunc.apply(inp, True)
5219        (ginp,) = torch.autograd.grad(out, (inp,), create_graph=True)
5220        gsum = ginp.sum()
5221        with warnings.catch_warnings(record=True) as w:
5222            with self.assertRaisesRegex(
5223                RuntimeError,
5224                "Function 'MyFunc2Backward' returned nan values in its 0th output.",
5225            ):
5226                with detect_anomaly():
5227                    gsum.backward()
5228        self.assertIn("No forward pass information", str(w[1].message))
5229
5230        inp = torch.rand(size, requires_grad=True)
5231        with warnings.catch_warnings(record=True) as w:
5232            with self.assertRaisesRegex(
5233                RuntimeError,
5234                "Function 'MyFunc2Backward' returned nan values in its 1th output.",
5235            ):
5236                with detect_anomaly():
5237                    out = MyFunc.apply(inp, False)
5238                    (ginp,) = torch.autograd.grad(out, (inp,), create_graph=True)
5239                    gsum = ginp.sum()
5240                    gsum.backward()
5241        self.assertIn("MyFunc2.apply", str(w[1].message))
5242        self.assertIn("MyFunc.apply", str(w[2].message))
5243
5244    def test_anomaly_grad_warnings(self):
5245        # PyTorch won't throw warnings if there is an error
5246        # but we'd want to at least see them in stderr
5247
5248        class StdErrDiverter:
5249            def __enter__(self):
5250                self.stderr_orig = sys.stderr
5251                self.stderr_new = io.StringIO()
5252                sys.stderr = self.stderr_new
5253                return self
5254
5255            def __exit__(self, *args):
5256                self.captured = self.stderr_new.getvalue()
5257                sys.stderr = self.stderr_orig
5258
5259        # if the warnings don't throw, they will be handled as regular warnings
5260        with self.assertRaisesRegex(
5261            RuntimeError,
5262            "one of the variables needed for gradient computation has been "
5263            "modified by an inplace operation",
5264        ):
5265            with warnings.catch_warnings(record=True) as w:
5266                with detect_anomaly():
5267                    a = torch.randn(5, requires_grad=True)
5268                    d1 = a + 1
5269                    d2 = d1**2
5270                    d1 += 1
5271                    torch.autograd.grad(d2.sum(), a)
5272
5273        self.assertEqual(len(w), 2)
5274        self.assertIn("Anomaly Detection has been enabled", str(w[0].message))
5275        self.assertIn("Error detected in PowBackward0", str(w[1].message))
5276
5277        # if the warning throws, it will be printed to sys.stderr
5278        with self.assertRaisesRegex(
5279            RuntimeError,
5280            "one of the variables needed for gradient computation has been "
5281            "modified by an inplace operation",
5282        ):
5283            with warnings.catch_warnings(record=True) as w:
5284                with detect_anomaly():
5285                    warnings.simplefilter("error")
5286                    with StdErrDiverter() as s:
5287                        a = torch.randn(5, requires_grad=True)
5288                        d1 = a + 1
5289                        d2 = d1**2
5290                        d1 += 1
5291                        torch.autograd.grad(d2.sum(), a)
5292
5293        self.assertEqual(len(w), 1)
5294        self.assertIn("Anomaly Detection has been enabled", str(w[0].message))
5295        self.assertIn("Error detected in PowBackward0", s.captured)
5296
5297    def test_anomaly_assign_parent_cleanup(self):
5298        # Test that python objects created are properly cleaned up when assign_parent is called
5299
5300        def get_ref():
5301            # we use torch.exp here but any function that will construct a new node in its
5302            # backward call in grad mode will work
5303            x = torch.randn(2, 2, requires_grad=True)
5304            t = x.exp()
5305
5306            # ExpBackward calls mul, creating the MulBackward node when create_graph=True.
5307            # In anomaly mode, a PyObject referencing MulBackward's "parent" ExpBackward is added to
5308            # MulBackward's anomaly metadata dict, creating the following reference chain:
5309            #
5310            # grad -> MulBackward -> PyObject -> ExpBackward
5311            #
5312            with detect_anomaly():
5313                grad = torch.autograd.grad(t, x, torch.ones_like(t), create_graph=True)
5314
5315            # We add a weak reference to a new Foo object, which we insert into ExpBackward's metadata dict
5316            #
5317            # (PyObject) -> ExpBackward -> dict -> *Foo*
5318            #            t ----^        WeakRef ---^
5319            #
5320            # We want to test that when grad goes out of scope at the end of this function that PyObject is destroyed
5321            # We can test this by seeing whether Foo is not kept alive once t is destroyed
5322            class Foo:
5323                pass
5324
5325            my_obj = Foo()
5326            meta_dict = t.grad_fn.metadata
5327            meta_dict[0] = my_obj
5328            ref = weakref.ref(my_obj)
5329            return t, ref
5330
5331        t, ref = get_ref()
5332        self.assertIsNotNone(ref())
5333        del t
5334        self.assertIsNone(ref())
5335
5336    def test_nested_anomaly_printstack_cleanup(self):
5337        # Test if metadata dict PyObject is properly destroyed
5338        def get_ref():
5339            # This is similar to the construction in test_anomaly_assign_parent_cleanup:
5340            #
5341            # MyFuncBackward2 -> PyObject -> MyFuncBackward -> dict -> Foo
5342            #                               out ---^         WeakRef ---^
5343            #
5344            # We want to check that Foo is still properly destroyed even when MyFunc2Backward's
5345            # AnomalyMetadata calls printstack, which does some python object manipulation.
5346            #
5347            # You might be wondering why we still have to test_anomaly_assign_parent_cleanup,
5348            # since if PyObject is not destroyed here, wouldn't this test would detect that also?
5349            # The answer is that custom function's PyObject (THPFunction) actually only hold
5350            # a weak reference to the c++ node!
5351            class MyFunc(Function):
5352                @staticmethod
5353                def forward(ctx, x):
5354                    ctx.save_for_backward(x)
5355                    return x
5356
5357                @staticmethod
5358                def backward(ctx, gO):
5359                    (x,) = ctx.saved_tensors
5360                    return MyFunc2.apply(x)
5361
5362            class MyFunc2(Function):
5363                @staticmethod
5364                def forward(ctx, x):
5365                    return x
5366
5367                @staticmethod
5368                def backward(ctx, gO):
5369                    return gO + float("NaN")
5370
5371            inp = torch.rand(1, requires_grad=True)
5372            out = MyFunc.apply(inp)
5373            (ginp,) = torch.autograd.grad(out, (inp,), create_graph=True)
5374
5375            with warnings.catch_warnings(record=True) as w:
5376                with self.assertRaisesRegex(
5377                    RuntimeError,
5378                    "Function 'MyFunc2Backward' returned nan values in its 0th output.",
5379                ):
5380                    with detect_anomaly():
5381                        ginp.backward()
5382
5383            class Foo:
5384                pass
5385
5386            my_obj = Foo()
5387            meta_dict = out.grad_fn.metadata
5388            meta_dict[0] = my_obj
5389            ref = weakref.ref(my_obj)
5390            return out, ref
5391
5392        t, ref = get_ref()
5393        self.assertIsNotNone(ref())
5394        del t
5395        self.assertIsNone(ref())
5396
5397    def test_anomaly_mode_no_check_nan(self):
5398        class MyFunc(torch.autograd.Function):
5399            @staticmethod
5400            def forward(ctx, inp):
5401                return inp.clone()
5402
5403            @staticmethod
5404            def backward(ctx, gO):
5405                return torch.tensor(float("nan")).expand(10, 10)
5406
5407        def run_fn(a):
5408            out = MyFunc.apply(a)
5409            return out.sum()
5410
5411        with warnings.catch_warnings(record=True) as w:
5412            with torch.autograd.detect_anomaly(check_nan=False):
5413                inp = torch.rand(10, 10, requires_grad=True)
5414                out = run_fn(inp)
5415                out.backward(retain_graph=True)
5416
5417                with torch.autograd.detect_anomaly(check_nan=True):
5418                    with self.assertRaisesRegex(
5419                        RuntimeError,
5420                        "Function 'MyFuncBackward' returned nan values in its 0th output.",
5421                    ):
5422                        out.backward(retain_graph=True)
5423
5424                out.backward()
5425
5426    def test_no_grad_copy(self):
5427        # create autograd function that saves grad pointer as class static
5428        class MyFunc(Function):
5429            static_grad_ptr = None
5430
5431            @staticmethod
5432            def forward(ctx, inp1, inp2):
5433                return inp1 + inp2
5434
5435            @staticmethod
5436            def backward(ctx, grad):
5437                MyFunc.static_grad_ptr = grad.data_ptr()
5438                return grad, grad
5439
5440        class NonContGradFunc(Function):
5441            @staticmethod
5442            def forward(ctx, inp1):
5443                ctx.size = inp1.size()
5444                return torch.tensor([1.0])
5445
5446            @staticmethod
5447            def backward(ctx, grad):
5448                return torch.ones(1).expand(ctx.size)
5449
5450        a = torch.randn(5, 6, requires_grad=True)
5451        b = torch.randn(5, 6, requires_grad=True)
5452        # non-contiguous grad should be copied
5453        NonContGradFunc.apply(MyFunc.apply(a, b)).backward()
5454        self.assertFalse(a.grad.data_ptr() == MyFunc.static_grad_ptr)
5455        self.assertFalse(b.grad.data_ptr() == MyFunc.static_grad_ptr)
5456        # test case that should trigger no copy for one of a,b
5457        a.grad = b.grad = None
5458        MyFunc.apply(a, b)[1][0].backward()
5459        p_g = MyFunc.static_grad_ptr
5460        p_a = a.grad.data_ptr()
5461        p_b = b.grad.data_ptr()
5462        # check a,b uses different grad buffer
5463        self.assertFalse(p_a == p_b)
5464        # check one of them is using the computed buffer
5465        self.assertTrue(p_a == p_g or p_b == p_g)
5466
5467    def test_no_grad_copy_sparse(self):
5468        # create autograd function that saves grad pointer as class static
5469        class MyFunc(Function):
5470            static_grad_ptr = None
5471
5472            @staticmethod
5473            def forward(ctx, inp1, inp2):
5474                return inp1 + inp2
5475
5476            @staticmethod
5477            def backward(ctx, grad):
5478                MyFunc.static_grad_ptr = grad._values().data_ptr()
5479                return grad, grad
5480
5481        class NonContGradFunc(Function):
5482            static_grad_ptr = None
5483
5484            @staticmethod
5485            def forward(ctx, inp1, inp2):
5486                return inp1 + inp2
5487
5488            @staticmethod
5489            def backward(ctx, grad):
5490                # Create a sparse tensor with non-contigous indices and values
5491                # and return as grad.
5492                v = torch.rand(1, 3)
5493                i = torch.ones(1, 1, dtype=torch.long)
5494                nv = v.expand(8, 3)
5495                ni = i.expand(1, 8)
5496                ngrad = torch.sparse_coo_tensor(ni, nv, (10, 3), dtype=torch.float32)
5497                NonContGradFunc.static_grad_ptr = ngrad._values().data_ptr()
5498                return ngrad, ngrad
5499
5500        a = torch.randn(10, 3, requires_grad=True)
5501        b = torch.randn(10, 3, requires_grad=True)
5502        input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
5503        offsets = torch.tensor([0, 4])
5504        import torch.nn.functional as F
5505
5506        # test case that should trigger no copy for one of a,b
5507        emb_matrix = MyFunc.apply(a, b)
5508        loss = F.embedding_bag(emb_matrix, input, offsets, sparse=True).sum()
5509        loss.backward(retain_graph=True)
5510        p_g = MyFunc.static_grad_ptr
5511        p_a = a.grad._values().data_ptr()
5512        p_b = b.grad._values().data_ptr()
5513        # check a,b uses different grad buffer
5514        self.assertFalse(p_a == p_b)
5515        # check one of them is using the computed buffer
5516        self.assertTrue(p_a == p_g or p_b == p_g)
5517
5518        # Run backwards multiple times to ensure accumulation works.
5519        for i in range(10):
5520            loss.backward(retain_graph=True)
5521
5522        # non-contiguous indices and value, we should trigger a copy.
5523        a.grad = b.grad = None
5524        emb_matrix = NonContGradFunc.apply(a, b)
5525        loss = F.embedding_bag(emb_matrix, input, offsets, sparse=True).sum()
5526        loss.backward(retain_graph=True)
5527        p_g = NonContGradFunc.static_grad_ptr
5528        p_a = a.grad._values().data_ptr()
5529        p_b = b.grad._values().data_ptr()
5530        # check a,b uses different grad buffer
5531        self.assertFalse(p_a == p_b)
5532        # Verify we cloned both grads.
5533        self.assertFalse(p_a == p_g)
5534        self.assertFalse(p_b == p_g)
5535
5536        # Run backwards multiple times to ensure accumulation works.
5537        for i in range(10):
5538            loss.backward(retain_graph=True)
5539
5540    def test_gradcheck_single_input(self):
5541        def check(fast_mode):
5542            def f(inp):
5543                return inp.mul(5)
5544
5545            gradcheck(
5546                f,
5547                torch.rand(10, dtype=torch.float64, requires_grad=True),
5548                fast_mode=fast_mode,
5549            )
5550            gradgradcheck(
5551                f,
5552                torch.rand(10, dtype=torch.float64, requires_grad=True),
5553                fast_mode=fast_mode,
5554            )
5555
5556        check(fast_mode=True)
5557        check(fast_mode=False)
5558
5559    @parametrize(
5560        "layout",
5561        (
5562            torch.sparse_coo,
5563            torch.sparse_csr,
5564            torch.sparse_csc,
5565            torch.sparse_bsr,
5566            torch.sparse_bsc,
5567        ),
5568    )
5569    def test_gradcheck_input(self, layout):
5570        if layout in {torch.sparse_bsr, torch.sparse_bsc}:
5571            blocksize = (2, 2)
5572            size = (4, 8)
5573        else:
5574            blocksize = None
5575            size = (2, 2)
5576
5577        def check(fast_mode, masked):
5578            def fn(sparse):
5579                return torch.sum(sparse)
5580
5581            gradcheck(
5582                fn,
5583                torch.rand(size, dtype=torch.double)
5584                .to_sparse(layout=layout, blocksize=blocksize)
5585                .requires_grad_(),
5586                masked=masked,
5587                check_batched_grad=False,
5588                fast_mode=fast_mode,
5589            )
5590
5591        for fast_mode, masked in product(*[(True, False)] * 2):
5592            check(fast_mode=fast_mode, masked=masked)
5593
5594    def test_gradcheck_nondeterministic(self):
5595        class NonDetFunc(Function):
5596            @staticmethod
5597            def forward(ctx, x, jitter=0.0):
5598                ctx._jitter = jitter
5599                return x
5600
5601            @staticmethod
5602            def backward(ctx, grad_out):
5603                return (
5604                    NonDetFunc.apply(grad_out, ctx._jitter)
5605                    * (1 + torch.rand_like(grad_out) * ctx._jitter),
5606                    None,
5607                )
5608
5609        def check(fast_mode):
5610            inp = torch.randn(5, 5, dtype=torch.double, requires_grad=True)
5611            gradcheck(
5612                lambda x: NonDetFunc.apply(x, 0.0),
5613                inp,
5614                check_batched_grad=False,
5615                fast_mode=fast_mode,
5616            )
5617            with self.assertRaisesRegex(RuntimeError, "Backward is not reentrant"):
5618                gradcheck(
5619                    lambda x: NonDetFunc.apply(x, 1e-6),
5620                    inp,
5621                    check_batched_grad=False,
5622                    fast_mode=fast_mode,
5623                )
5624            with self.assertRaisesRegex(RuntimeError, "Backward is not reentrant"):
5625                gradgradcheck(
5626                    lambda x: NonDetFunc.apply(x, 1e-12),
5627                    inp,
5628                    check_batched_grad=False,
5629                    fast_mode=fast_mode,
5630                )
5631            gradcheck(
5632                lambda x: NonDetFunc.apply(x, 0.0),
5633                inp,
5634                nondet_tol=1e-5,
5635                check_batched_grad=False,
5636                fast_mode=fast_mode,
5637            )
5638            gradcheck(
5639                lambda x: NonDetFunc.apply(x, 1e-6),
5640                inp,
5641                nondet_tol=1e-5,
5642                check_batched_grad=False,
5643                fast_mode=fast_mode,
5644            )
5645            gradgradcheck(
5646                lambda x: NonDetFunc.apply(x, 1e-12),
5647                inp,
5648                nondet_tol=1e-5,
5649                check_batched_grad=False,
5650                fast_mode=fast_mode,
5651            )
5652
5653        check(fast_mode=True)
5654        check(fast_mode=False)
5655
5656    def test_gradcheck_validates_inputs(self):
5657        def check(fast_mode):
5658            x = torch.rand(10, requires_grad=True).to_sparse()
5659            self.assertTrue(
5660                gradcheck(
5661                    lambda x: x.to_dense(),
5662                    (x,),
5663                    check_batched_grad=False,
5664                    atol=1e-1,
5665                    fast_mode=fast_mode,
5666                    masked=True,
5667                )
5668            )
5669            self.assertFalse(
5670                gradcheck(
5671                    lambda x: x.to_dense(),
5672                    (x,),
5673                    masked=False,
5674                    check_batched_grad=False,
5675                    raise_exception=False,
5676                    fast_mode=fast_mode,
5677                )
5678            )
5679            self.assertTrue(
5680                gradcheck(
5681                    lambda x: x.to_dense(masked_grad=False),
5682                    (x,),
5683                    masked=False,
5684                    atol=1e-1,
5685                    check_batched_grad=False,
5686                    raise_exception=False,
5687                    fast_mode=fast_mode,
5688                )
5689            )
5690
5691            # when none of the inputs require grad (always raises even if raise_exception=False)
5692            x = torch.rand(10, requires_grad=False)
5693            with self.assertRaisesRegex(
5694                ValueError, "at least one input tensor to require gradient"
5695            ):
5696                gradcheck(lambda x: x, (x,), raise_exception=False, fast_mode=fast_mode)
5697
5698            # (warning) when inputs are not double precision
5699            x = torch.ones(1, dtype=torch.float32, requires_grad=True)
5700            with self.assertWarnsRegex(
5701                UserWarning, "Input #0 requires gradient and is not a double precision"
5702            ):
5703                self.assertTrue(
5704                    gradcheck(lambda x: x, (x,), atol=1e-1, fast_mode=fast_mode)
5705                )
5706
5707            # when layout is not mkldnn(aka has strides) and input has a dimension with stride 0. (always raises
5708            # even if raise_exception=False)
5709            x = torch.ones(1, dtype=torch.float64, requires_grad=True)
5710            x = x.expand((2, 2))
5711            with self.assertRaisesRegex(
5712                RuntimeError, "The 0th input has a dimension with stride 0"
5713            ):
5714                gradcheck(lambda x: x, (x,), raise_exception=False, fast_mode=fast_mode)
5715
5716        check(fast_mode=True)
5717        check(fast_mode=False)
5718
5719    @unittest.skipIf(
5720        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
5721    )
5722    def test_gradcheck_validates_input_mkldnn(self):
5723        # when mkldnn inputs, forward mode testing is not allowed
5724        # Update tolerances below to make sure the gradient match even in single precision floats
5725        # Use the warning assert to hide the float32 warning
5726        x = torch.ones(1).to_mkldnn().requires_grad_()
5727        with self.assertWarnsRegex(
5728            UserWarning, "Input #0 requires gradient and is not a double precision"
5729        ):
5730            with self.assertRaisesRegex(
5731                ValueError, "MKLDNN inputs are not support for forward AD gradcheck."
5732            ):
5733                gradcheck(
5734                    lambda x: x.to_dense(),
5735                    (x,),
5736                    raise_exception=False,
5737                    fast_mode=False,
5738                    check_forward_ad=True,
5739                    atol=1e-1,
5740                    rtol=1e-1,
5741                )
5742
5743        with self.assertWarnsRegex(
5744            UserWarning, "Input #0 requires gradient and is not a double precision"
5745        ):
5746            with self.assertRaisesRegex(
5747                ValueError, "MKLDNN inputs are not support for forward AD gradcheck."
5748            ):
5749                gradcheck(
5750                    lambda x: x.to_dense(),
5751                    (x,),
5752                    raise_exception=False,
5753                    fast_mode=True,
5754                    check_forward_ad=True,
5755                    atol=1e-1,
5756                    rtol=1e-1,
5757                )
5758
5759    @unittest.skipIf(
5760        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
5761    )
5762    def test_gradcheck_test_outputs(self):
5763        def check(fast_mode):
5764            # when sparse outputs (always raise even if raise_exception=False)
5765            x = torch.rand(10, requires_grad=True).to_sparse()
5766            with self.assertRaisesRegex(
5767                ValueError, "Sparse output is not supported at gradcheck yet"
5768            ):
5769                gradcheck(
5770                    lambda x: x,
5771                    (x,),
5772                    masked=True,
5773                    check_batched_grad=False,
5774                    raise_exception=False,
5775                    fast_mode=fast_mode,
5776                )
5777
5778            # when mkldnn outputs (always raise even if raise_exception=False)
5779            root = torch.randn(4, 5, dtype=torch.float32, requires_grad=True)
5780            with self.assertRaisesRegex(
5781                ValueError, "MKLDNN output is not supported at gradcheck yet"
5782            ):
5783                gradcheck(
5784                    lambda x: x.to_mkldnn(),
5785                    (root,),
5786                    check_batched_grad=False,
5787                    raise_exception=False,
5788                    fast_mode=fast_mode,
5789                )
5790
5791        check(fast_mode=True)
5792        check(fast_mode=False)
5793
5794    def test_gradcheck_check_no_differentiable_outputs(self):
5795        def check(fast_mode):
5796            # When none of the outputs are differentiable, but numerical gradient is not zero
5797            x = torch.ones((1,), requires_grad=True)
5798            with self.assertRaisesRegex(
5799                RuntimeError, "Numerical gradient for function expected to be zero"
5800            ):
5801                gradcheck(lambda x: torch.tensor([x]), x)
5802            self.assertFalse(
5803                gradcheck(
5804                    lambda x: torch.tensor([x]),
5805                    x,
5806                    raise_exception=False,
5807                    fast_mode=fast_mode,
5808                )
5809            )
5810
5811            # succeed when no outputs at all
5812            self.assertTrue(gradcheck(lambda x: (), (x,), fast_mode=fast_mode))
5813
5814        check(fast_mode=True)
5815        check(fast_mode=False)
5816
5817    def test_gradcheck_check_batched_grad(self):
5818        def check(fast_mode):
5819            x = torch.rand(10, dtype=torch.double, requires_grad=True).to_sparse()
5820            # runtime error while compute batched grad (print big error)
5821            with self.assertRaisesRegex(
5822                RuntimeError,
5823                "gradcheck or gradgradcheck failed while testing batched gradient",
5824            ):
5825                gradcheck(
5826                    lambda x: x.to_dense(),
5827                    (x,),
5828                    masked=True,
5829                    check_batched_grad=True,
5830                    fast_mode=fast_mode,
5831                )
5832            self.assertFalse(
5833                gradcheck(
5834                    lambda x: x.to_dense(),
5835                    (x,),
5836                    masked=True,
5837                    check_batched_grad=True,
5838                    raise_exception=False,
5839                    fast_mode=fast_mode,
5840                )
5841            )
5842
5843        check(fast_mode=True)
5844        check(fast_mode=False)
5845
5846    def test_gradcheck_backward_mul_by_grad_output(self):
5847        # when grad_input is sparse and has incorrect sparse_dim/dense_dim
5848        def check(fast_mode):
5849            def fn(x):
5850                def hook(grad):
5851                    if grad is not None:
5852                        return grad.to_dense().to_sparse(1)
5853                    return grad
5854
5855                y = x.clone()
5856                y.register_hook(hook)
5857                return y.to_dense()
5858
5859            x = torch.ones((2, 2), dtype=torch.double, requires_grad=True).to_sparse()
5860            with self.assertRaisesRegex(
5861                RuntimeError, "grad is sparse tensor, but has incorrect sparse_dim"
5862            ):
5863                gradcheck(
5864                    fn,
5865                    (x,),
5866                    atol=1e-1,
5867                    masked=True,
5868                    check_batched_grad=False,
5869                    fast_mode=fast_mode,
5870                )
5871            self.assertFalse(
5872                gradcheck(
5873                    fn,
5874                    (x,),
5875                    atol=1e-1,
5876                    masked=True,
5877                    check_batched_grad=False,
5878                    raise_exception=False,
5879                    fast_mode=fast_mode,
5880                )
5881            )
5882
5883            # when backward not multiplied by grad_output (non-sparse case)
5884            def fn2(x):
5885                y = x.clone()
5886                y.register_hook(lambda x: x + 1e-2)
5887                return y
5888
5889            x = torch.ones(1, dtype=torch.double, requires_grad=True)
5890            with self.assertRaisesRegex(
5891                RuntimeError, "backward not multiplied by grad_output"
5892            ):
5893                gradcheck(fn2, (x,), atol=1e-1, fast_mode=fast_mode)
5894            self.assertFalse(
5895                gradcheck(
5896                    fn2, (x,), atol=1e-1, raise_exception=False, fast_mode=fast_mode
5897                )
5898            )
5899
5900            # when backward not multiplied by grad_output (sparse case)
5901            def fn3(x):
5902                y = x.clone().to_dense()
5903                y.register_hook(lambda x: x + 1e-2)
5904                return y
5905
5906            x = torch.ones(1, dtype=torch.double, requires_grad=True).to_sparse()
5907            with self.assertRaisesRegex(
5908                RuntimeError, "backward not multiplied by grad_output"
5909            ):
5910                gradcheck(
5911                    fn3,
5912                    (x,),
5913                    atol=1e-1,
5914                    masked=True,
5915                    check_batched_grad=False,
5916                    fast_mode=fast_mode,
5917                )
5918            self.assertFalse(
5919                gradcheck(
5920                    fn3,
5921                    (x,),
5922                    atol=1e-1,
5923                    masked=True,
5924                    check_batched_grad=False,
5925                    raise_exception=False,
5926                    fast_mode=fast_mode,
5927                )
5928            )
5929
5930            # when layout of grad_input is not the same as input
5931            class Test(Function):
5932                @staticmethod
5933                def forward(ctx, x):
5934                    return x
5935
5936                @staticmethod
5937                def backward(ctx, x):
5938                    return x.to_sparse()
5939
5940            x = torch.ones(1, dtype=torch.double, requires_grad=True)
5941            with self.assertRaisesRegex(RuntimeError, "grad is incorrect layout"):
5942                gradcheck(
5943                    Test.apply, (x,), check_batched_grad=False, fast_mode=fast_mode
5944                )
5945            self.assertFalse(
5946                gradcheck(
5947                    Test.apply,
5948                    (x,),
5949                    check_batched_grad=False,
5950                    raise_exception=False,
5951                    fast_mode=fast_mode,
5952                )
5953            )
5954
5955        check(fast_mode=True)
5956        check(fast_mode=False)
5957
5958    def test_gradcheck_undefined_grad(self):
5959        def check(fast_mode):
5960            # when encounter runtime error while running backward
5961            def fn(x):
5962                def hook(x):
5963                    if x is None:
5964                        raise RuntimeError("x is undefined")
5965
5966                y = x.clone()
5967                y.register_hook(hook)
5968                return y
5969
5970            x = torch.ones(1, dtype=torch.double, requires_grad=True)
5971            with self.assertWarnsRegex(
5972                UserWarning,
5973                "Backwards compatibility: New undefined gradient support checking feature",
5974            ):
5975                with self.assertRaisesRegex(
5976                    RuntimeError,
5977                    "Expected backward function to handle undefined output grads",
5978                ):
5979                    gradcheck(fn, (x,), fast_mode=fast_mode)
5980                self.assertFalse(
5981                    gradcheck(fn, (x,), raise_exception=False, fast_mode=fast_mode)
5982                )
5983
5984        check(fast_mode=True)
5985        check(fast_mode=False)
5986
5987    def test_gradcheck_jacobian_mismatch(self):
5988        def check(fast_mode):
5989            def fn(x):  # R -> R, C -> C
5990                y = x.clone()
5991                y.register_hook(lambda x: x + 1e-2)
5992                return y
5993
5994            x = torch.ones(2, 2, requires_grad=True)
5995            with self.assertRaisesRegex(
5996                RuntimeError, "Jacobian mismatch for output 0 with respect to input 0"
5997            ):
5998                gradcheck(fn, (x,), fast_mode=fast_mode)
5999            self.assertFalse(
6000                gradcheck(fn, (x,), raise_exception=False, fast_mode=fast_mode)
6001            )
6002
6003            x_c = torch.ones(2, 2, requires_grad=True, dtype=torch.complex128)
6004            with self.assertRaisesRegex(
6005                RuntimeError,
6006                "While considering the imaginary part of complex outputs only",
6007            ):
6008                gradcheck(fn, (x_c,), fast_mode=False)
6009            self.assertFalse(
6010                gradcheck(fn, (x_c,), raise_exception=False, fast_mode=False)
6011            )
6012
6013            def fn2(x):  # R -> C
6014                y = torch.complex(x, x)
6015                y.register_hook(lambda x: x + 1e-2)
6016                return y
6017
6018            x = torch.ones(2, 2, requires_grad=True)
6019            with self.assertRaisesRegex(
6020                RuntimeError,
6021                "While considering the imaginary part of complex outputs only",
6022            ):
6023                gradcheck(fn2, (x,), fast_mode=False)
6024            self.assertFalse(
6025                gradcheck(fn2, (x,), raise_exception=False, fast_mode=False)
6026            )
6027
6028            def fn3(x):  # C -> R
6029                y = torch.real(x)
6030                y.register_hook(lambda x: x + 1e-2)
6031                return y
6032
6033            with self.assertRaisesRegex(
6034                RuntimeError, "Jacobian mismatch for output 0 with respect to input 0"
6035            ):
6036                gradcheck(fn3, (x_c,), fast_mode=False)
6037            self.assertFalse(
6038                gradcheck(fn3, (x_c,), raise_exception=False, fast_mode=False)
6039            )
6040
6041        check(fast_mode=True)
6042        check(fast_mode=False)
6043
6044    def test_gradcheck_dense_and_sparse_inputs(self):
6045        def check(fast_mode):
6046            def fn(x, y):
6047                return x * y.coalesce().to_dense()
6048
6049            a = torch.rand(2, 2, dtype=torch.double, requires_grad=True)
6050            b = torch.rand(2, 2, dtype=torch.double).to_sparse().requires_grad_(True)
6051            self.assertTrue(
6052                gradcheck(
6053                    fn,
6054                    (a, b),
6055                    masked=True,
6056                    check_batched_grad=False,
6057                    fast_mode=fast_mode,
6058                )
6059            )
6060
6061        check(fast_mode=True)
6062        check(fast_mode=False)
6063
6064    @unittest.skipIf(
6065        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
6066    )
6067    def test_gradcheck_multiple_mkldnn_inputs(self):
6068        def check(fast_mode):
6069            def fn(x, y):
6070                return x + y.to_dense()
6071
6072            a = torch.rand(10, requires_grad=True)
6073            b = torch.rand(10, dtype=torch.float32).to_mkldnn().requires_grad_(True)
6074            self.assertTrue(
6075                gradcheck(
6076                    fn, (a, b), atol=1e-1, check_batched_grad=False, fast_mode=fast_mode
6077                )
6078            )
6079
6080            def fn2(x, y):
6081                return x.to_dense() + y.to_dense()
6082
6083            c = torch.rand(10, dtype=torch.float32).to_mkldnn().requires_grad_(True)
6084            self.assertTrue(
6085                gradcheck(
6086                    fn, (a, c), atol=1e-1, check_batched_grad=False, fast_mode=fast_mode
6087                )
6088            )
6089
6090        check(fast_mode=True)
6091        check(fast_mode=False)
6092
6093    def test_gradcheck_output_shape_or_dtype_depend_on_values(self):
6094        def check(fast_mode):
6095            def fn(x):
6096                if torch.all(x >= 1):
6097                    return torch.cat([x, x])
6098                else:
6099                    return x
6100
6101            a = torch.ones(1, dtype=torch.double, requires_grad=True)
6102            with self.assertRaisesRegex(
6103                AssertionError,
6104                "return outputs with the same shape when inputs are perturbed",
6105            ):
6106                self.assertTrue(gradcheck(fn, (a,), fast_mode=fast_mode))
6107
6108            def fn2(x):
6109                if torch.all(x >= 1):
6110                    return x.to(torch.float32)
6111                else:
6112                    return x
6113
6114            with self.assertRaisesRegex(
6115                AssertionError,
6116                "return outputs with the same dtype when inputs are perturbed",
6117            ):
6118                self.assertTrue(gradcheck(fn2, (a,), fast_mode=fast_mode))
6119
6120        check(fast_mode=True)
6121        check(fast_mode=False)
6122
6123    def test_gradcheck_complex_non_complex_outputs(self):
6124        def fn(x, y):
6125            z = torch.complex(x, y)
6126            return z, x + 1
6127
6128        a = torch.ones(2, 2, requires_grad=True, dtype=torch.float64)
6129        b = torch.ones(2, 2, requires_grad=True, dtype=torch.float64)
6130        self.assertTrue(gradcheck(fn, (a, b)))
6131
6132        def fn2(z):
6133            return z, torch.real(z)
6134
6135        c = torch.ones(2, 2, requires_grad=True, dtype=torch.complex128)
6136        self.assertTrue(gradcheck(fn2, (c)))
6137
6138    def test_gradcheck_get_numerical_jacobian(self):
6139        # get_numerical_jacobian is deprecated and no longer used internally by gradcheck
6140        from torch.autograd.gradcheck import get_numerical_jacobian
6141
6142        def fn(inputs):
6143            # get_numerical_jacobian requires fn to take inputs as a tuple
6144            # and returns the jacobian wrt the first output
6145            x = inputs[0]
6146            y = inputs[1]
6147            return 2 * x + y, x + 2 * y
6148
6149        a = torch.rand(2, 2, requires_grad=True, dtype=torch.float64)
6150        b = torch.rand(2, 2, requires_grad=True, dtype=torch.float64)
6151
6152        with self.assertWarnsRegex(
6153            FutureWarning, "`get_numerical_jacobian` was part of PyTorch's private API"
6154        ):
6155            jacobian = get_numerical_jacobian(fn, (a, b), target=a, eps=1e-6)
6156        self.assertEqual(jacobian[0], 2 * torch.eye(4, dtype=torch.double))
6157
6158        with self.assertWarnsRegex(
6159            FutureWarning, "`get_numerical_jacobian` was part of PyTorch's private API"
6160        ):
6161            jacobian = get_numerical_jacobian(fn, (a, b), eps=1e-6)
6162        self.assertEqual(jacobian[0], 2 * torch.eye(4, dtype=torch.double))
6163        self.assertEqual(jacobian[1], 1 * torch.eye(4, dtype=torch.double))
6164
6165        with self.assertRaisesRegex(ValueError, "Expected grad_out to be 1.0"):
6166            jacobian = get_numerical_jacobian(fn, (a, b), eps=1e-6, grad_out=2.0)
6167
6168    def test_gradcheck_get_analytical_jacobian(self):
6169        from torch.autograd.gradcheck import get_analytical_jacobian
6170
6171        def fn(x, y):
6172            return 2 * x + y, x + 2 * y
6173
6174        a = torch.rand(2, 2, requires_grad=True, dtype=torch.float64)
6175        b = torch.rand(2, 2, requires_grad=True, dtype=torch.float64)
6176
6177        outputs = fn(a, b)
6178        with self.assertWarnsRegex(
6179            FutureWarning, "`get_analytical_jacobian` was part of PyTorch's private API"
6180        ):
6181            (
6182                jacobians,
6183                reentrant,
6184                correct_grad_sizes,
6185                correct_grad_types,
6186            ) = get_analytical_jacobian((a, b), outputs[0])
6187        self.assertEqual(jacobians[0], 2 * torch.eye(4, dtype=torch.double))
6188        self.assertEqual(jacobians[1], 1 * torch.eye(4, dtype=torch.double))
6189        self.assertTrue(reentrant)
6190
6191        class NonDetFunc(Function):
6192            @staticmethod
6193            def forward(ctx, x, jitter=0.0):
6194                ctx._jitter = jitter
6195                return x
6196
6197            @staticmethod
6198            def backward(ctx, grad_out):
6199                return (
6200                    NonDetFunc.apply(grad_out, ctx._jitter)
6201                    * (1 + torch.rand_like(grad_out) * ctx._jitter),
6202                    None,
6203                )
6204
6205        outputs = NonDetFunc.apply(a, 1e-6)
6206        with self.assertWarnsRegex(
6207            FutureWarning, "`get_analytical_jacobian` was part of PyTorch's private API"
6208        ):
6209            (
6210                jacobians,
6211                reentrant,
6212                correct_grad_sizes,
6213                correct_grad_types,
6214            ) = get_analytical_jacobian((a,), outputs)
6215        self.assertFalse(reentrant)
6216
6217        with self.assertRaisesRegex(ValueError, "Expected grad_out to be 1.0"):
6218            jacobians, _, _, _ = get_analytical_jacobian((a,), outputs, grad_out=2.0)
6219
6220    def test_gradcheck_custom_error(self):
6221        from torch.autograd.gradcheck import GradcheckError
6222
6223        def check(fast_mode):
6224            def fn(x):
6225                y = x.clone()
6226                y.register_hook(lambda x: x + 1e-2)
6227                return y
6228
6229            x = torch.ones(2, 2, requires_grad=True)
6230            with self.assertRaisesRegex(
6231                GradcheckError, "Jacobian mismatch for output 0 with respect to input 0"
6232            ):
6233                gradcheck(fn, (x,), fast_mode=fast_mode)
6234            with self.assertRaisesRegex(
6235                RuntimeError, "Jacobian mismatch for output 0 with respect to input 0"
6236            ):
6237                gradcheck(fn, (x,), fast_mode=fast_mode)
6238            self.assertFalse(
6239                gradcheck(fn, (x,), raise_exception=False, fast_mode=fast_mode)
6240            )
6241
6242            def fn2(x):
6243                raise RuntimeError("Not a GradcheckError!")
6244
6245            # Checks that when raise_exception=False, non-GradcheckErrors are not caught by gradcheck
6246            with self.assertRaisesRegex(RuntimeError, "Not a GradcheckError!"):
6247                gradcheck(fn2, (x,), fast_mode=fast_mode, raise_exception=False)
6248
6249        check(fast_mode=True)
6250        check(fast_mode=False)
6251
6252    def test_gradcheck_forward_ad(self):
6253        def fn(x, y):
6254            return x + y, y
6255
6256        def bad_fn(x, y):
6257            # Hacky way to check if we're currently inside a forward ad level
6258            is_running_forward_ad = fwAD._current_level >= 0
6259
6260            if is_running_forward_ad:
6261                y_p, y_d = fwAD.unpack_dual(y)
6262                y = fwAD.make_dual(y_p, y_d * 1.1)
6263
6264            return x + y, y
6265
6266        err_msg = "Jacobian computed with forward mode mismatch for output 0 with respect to input 1"
6267
6268        for fast_mode in [True, False]:
6269            # Test for all inputs and outputs being real
6270            x = torch.rand(2, dtype=torch.double, requires_grad=True)
6271            y = torch.rand(2, dtype=torch.double, requires_grad=True)
6272
6273            gradcheck(fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
6274            with self.assertRaisesRegex(RuntimeError, err_msg):
6275                gradcheck(bad_fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
6276
6277            def basic_mul(x):
6278                return torch.view_as_real(torch.resolve_conj(x * 1j))
6279
6280            gradcheck(basic_mul, x, check_forward_ad=True, fast_mode=fast_mode)
6281
6282            # Test for one input and one output being complex
6283            x = torch.rand(2, dtype=torch.cdouble, requires_grad=True)
6284
6285            gradcheck(fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
6286            with self.assertRaisesRegex(RuntimeError, err_msg):
6287                gradcheck(bad_fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
6288
6289            # Test for all inputs and outputs being complex
6290            y = torch.rand(2, dtype=torch.cdouble, requires_grad=True)
6291
6292            gradcheck(fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
6293            with self.assertRaisesRegex(RuntimeError, err_msg):
6294                gradcheck(bad_fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
6295
6296    def test_gradcheck_forward_ad_runs_with_no_requires_grad(self):
6297        # Currently requires_grad is used as a easy way for gradcheck to know
6298        # which inputs of the function are meant to be differentiable
6299        # This test checks that when the inputs are passed to the function they should not have
6300        # requires_grad=True even though they may have requires_grad=True when passed
6301        # to gradcheck
6302        class UserFn(Function):
6303            @staticmethod
6304            def forward(ctx, x, y):
6305                if fwAD._current_level >= 0:
6306                    self.assertFalse(x.requires_grad)
6307                    self.assertFalse(y.requires_grad)
6308                return x.clone(), y.clone()
6309
6310            @staticmethod
6311            def jvp(ctx, x_t, y_t):
6312                return x_t, y_t
6313
6314        x = torch.rand(2, dtype=torch.double, requires_grad=True)
6315        y = torch.rand(2, dtype=torch.double, requires_grad=True)
6316
6317        gradcheck(
6318            UserFn.apply,
6319            (x, y),
6320            check_forward_ad=True,
6321            check_undefined_grad=False,
6322            check_backward_ad=False,
6323            check_batched_grad=False,
6324            check_batched_forward_grad=False,
6325        )
6326
6327        gradcheck(
6328            UserFn.apply,
6329            (x, y),
6330            check_forward_ad=True,
6331            check_undefined_grad=True,
6332            check_backward_ad=False,
6333            check_batched_grad=False,
6334            check_batched_forward_grad=False,
6335        )
6336
6337        gradcheck(
6338            UserFn.apply,
6339            (x, y),
6340            check_forward_ad=True,
6341            check_undefined_grad=True,
6342            check_backward_ad=False,
6343            check_batched_grad=False,
6344            check_batched_forward_grad=True,
6345        )
6346
6347        x = torch.rand(2, dtype=torch.double, requires_grad=True)
6348        y = torch.rand(2, dtype=torch.double, requires_grad=False)
6349        gradcheck(
6350            UserFn.apply,
6351            (x, y),
6352            check_forward_ad=True,
6353            check_undefined_grad=True,
6354            check_backward_ad=False,
6355            check_batched_grad=False,
6356            check_batched_forward_grad=True,
6357        )
6358
6359    def test_gradcheck_forward_ad_respects_requires_grad(self):
6360        # Currently requires_grad is used as a easy way for gradcheck to know
6361        # which inputs of the function are meant to be differentiable
6362        jvp_count = [0]
6363
6364        class UserFn(Function):
6365            @staticmethod
6366            def forward(ctx, x, y):
6367                return x.clone(), y.clone()
6368
6369            @staticmethod
6370            def jvp(ctx, x_t, y_t):
6371                jvp_count[0] += 1
6372                return x_t, y_t
6373
6374        # NB: In slow gradcheck we need to loop through numel times so use numel = 1 to ensure
6375        #     that fast and slow have the same counts
6376        x = torch.rand(1, dtype=torch.double, requires_grad=True)
6377        y = torch.rand(1, dtype=torch.double, requires_grad=True)
6378        gradcheck(
6379            UserFn.apply,
6380            (x, y),
6381            check_forward_ad=True,
6382            check_undefined_grad=False,
6383            check_backward_ad=False,
6384            check_batched_grad=False,
6385            check_batched_forward_grad=False,
6386        )
6387        self.assertEqual(jvp_count[0], 2)  # (2) once per input
6388        jvp_count = [0]
6389
6390        gradcheck(
6391            UserFn.apply,
6392            (x, y),
6393            check_forward_ad=True,
6394            check_undefined_grad=True,
6395            check_backward_ad=False,
6396            check_batched_grad=False,
6397            check_batched_forward_grad=False,
6398        )
6399        self.assertEqual(
6400            jvp_count[0], 6
6401        )  # (+4): (once with normal ZT (+1), once with efficient ZT (+1)) for each input (x2)
6402        jvp_count = [0]
6403
6404        gradcheck(
6405            UserFn.apply,
6406            (x, y),
6407            check_forward_ad=True,
6408            check_undefined_grad=True,
6409            check_backward_ad=False,
6410            check_batched_grad=False,
6411            check_batched_forward_grad=True,
6412        )
6413        self.assertEqual(
6414            jvp_count[0], 12
6415        )  # (+6): (compute batch of 2 with vmap (+1), with a loop (+2)) for each input (x2)
6416        jvp_count = [0]
6417
6418        # Repeat the previous test except we mark one input with requires_grad=False
6419        # NB: _test_undefined_forward_mode is only (+1), when function has single differentiable input, not (+2)!
6420        #     Otherwise, other counts are halved.
6421        x = torch.rand(1, dtype=torch.double, requires_grad=True)
6422        y = torch.rand(1, dtype=torch.double, requires_grad=False)
6423        gradcheck(
6424            UserFn.apply,
6425            (x, y),
6426            check_forward_ad=True,
6427            check_undefined_grad=True,
6428            check_backward_ad=False,
6429            check_batched_grad=False,
6430            check_batched_forward_grad=True,
6431        )
6432        self.assertEqual(jvp_count[0], 5)  # 1 + 1 + 3
6433
6434    def test_gradcheck_check_forward_or_backward_only(self):
6435        """Depending on settings for check_forward_ad and check_backward_ad, the
6436        correct codepaths should be reached (or not reached)
6437        """
6438        fwd_fail_err_msg = "FAIL FWD"
6439        bwd_fail_err_msg = "FAIL BWD"
6440
6441        class UserFn(Function):
6442            @staticmethod
6443            def forward(ctx, foo, fwd_bad, bwd_bad):
6444                ctx.fwd_bad = fwd_bad
6445                ctx.bwd_bad = bwd_bad
6446                return foo * 2
6447
6448            @staticmethod
6449            def vjp(ctx, gO):
6450                if ctx.bwd_bad:
6451                    raise RuntimeError(bwd_fail_err_msg)
6452                else:
6453                    return 2 * gO, None, None
6454
6455            @staticmethod
6456            def jvp(ctx, gI, _1, _2):
6457                if ctx.fwd_bad:
6458                    raise RuntimeError(fwd_fail_err_msg)
6459                else:
6460                    return 2 * gI
6461
6462        for fast_mode in (True, False):
6463            for check_forward_ad in (True, False):
6464                for check_backward_ad in (True, False):
6465                    for fwd_bad in (True, False):
6466                        for bwd_bad in (True, False):
6467                            fwd_should_fail = fwd_bad and check_forward_ad
6468                            bwd_should_fail = bwd_bad and check_backward_ad
6469
6470                            def run():
6471                                gradcheck(
6472                                    UserFn.apply,
6473                                    (x, fwd_bad, bwd_bad),
6474                                    check_forward_ad=check_forward_ad,
6475                                    check_backward_ad=check_backward_ad,
6476                                    check_undefined_grad=check_backward_ad,
6477                                    check_batched_grad=check_backward_ad,
6478                                    fast_mode=fast_mode,
6479                                )
6480
6481                            x = torch.rand(2, dtype=torch.double, requires_grad=True)
6482
6483                            if not check_forward_ad and not check_backward_ad:
6484                                with self.assertRaisesRegex(
6485                                    AssertionError, "Expected at least one of"
6486                                ):
6487                                    run()
6488                                continue
6489
6490                            if not fwd_should_fail and not bwd_should_fail:
6491                                run()
6492                            else:
6493                                # If both fail, backward AD failure "hides" forward AD failure
6494                                if fwd_should_fail:
6495                                    fail_msg = fwd_fail_err_msg
6496                                if bwd_should_fail:
6497                                    fail_msg = bwd_fail_err_msg
6498                                with self.assertRaisesRegex(RuntimeError, fail_msg):
6499                                    run()
6500
6501    def test_gradcheck_forward_ad_batched_grad(self):
6502        x = torch.rand(2, dtype=torch.double, requires_grad=True)
6503
6504        # multiple inputs and outputs with non-tensors inputs
6505        def fn1(a: torch.Tensor, b: int):
6506            return a.clone(), a + 1
6507
6508        gradcheck(
6509            fn1,
6510            (x, 1),
6511            check_forward_ad=True,
6512            check_backward_ad=False,
6513            check_batched_grad=False,
6514            check_undefined_grad=False,
6515            check_batched_forward_grad=True,
6516        )
6517
6518        # unrelated inputs: tangent for c is None
6519        def fn2(a: torch.Tensor, c: torch.Tensor):
6520            return a.clone()
6521
6522        gradcheck(
6523            fn2,
6524            (x, x.clone()),
6525            check_forward_ad=True,
6526            check_backward_ad=False,
6527            check_batched_grad=False,
6528            check_undefined_grad=False,
6529            check_batched_forward_grad=True,
6530        )
6531
6532        class Fn(Function):
6533            @staticmethod
6534            def forward(ctx, foo):
6535                return foo * 2
6536
6537            @staticmethod
6538            def vjp(ctx, gO):
6539                return gO * 2
6540
6541            @staticmethod
6542            def jvp(ctx, gI):
6543                torch.randn_like(gI)
6544                return gI * 2
6545
6546        msg = "vmap: We do not yet support calling random operations inside of vmap"
6547        with self.assertRaisesRegex(RuntimeError, msg):
6548            gradcheck(
6549                Fn.apply, (x,), check_forward_ad=True, check_batched_forward_grad=True
6550            )
6551
6552    def test_version_counter(self):
6553        x = torch.randn(1, 2)
6554
6555        # In-place op bumps version
6556        x_saved_version = x._version
6557        x.add_(1).add_(1)
6558        self.assertTrue(x._version > x_saved_version)
6559
6560        # Differentiable view shares version counter
6561        xz = x[:]
6562        self.assertTrue(x._version == xz._version)
6563        xz.add_(1)
6564        self.assertTrue(x._version == xz._version)
6565
6566        # `x.data = y` preserves version counter of `x`
6567        x_saved_version = x._version
6568        x.data = torch.randn(2, 3)
6569        self.assertTrue(x._version == x_saved_version)
6570        x.add_(1)
6571        self.assertTrue(x._version > x_saved_version)
6572        # Make sure `x` is still using the same version counter it shares with `xz`
6573        self.assertTrue(x._version == xz._version)
6574
6575        # In-place op on `xz` also updates version of `x`,
6576        # because they share the version counter
6577        xz.add_(1)
6578        self.assertTrue(x._version == xz._version)
6579
6580    def test_set_data_tensorimpl_type(self):
6581        # Dense tensor has impl of type `TensorImpl`, while sparse tensor has impl
6582        # of type `SparseTensorImpl`.
6583        x = torch.randn(1, 2)
6584        x_s = torch.sparse_coo_tensor(torch.zeros([1, 1]), torch.ones([1]))
6585        with self.assertRaisesRegex(RuntimeError, "incompatible tensor type"):
6586            x.data = x_s
6587
6588    def test_set_data_preserve_pyobj(self):
6589        a = torch.randn(1, 2)
6590        b = torch.randn(1, 2)
6591        b_id_saved = id(b)
6592        b.data = a
6593        self.assertTrue(b_id_saved == id(b))
6594
6595    def test_set_data_self_requires_grad(self):
6596        a = torch.tensor(1.0, requires_grad=True)
6597        b = torch.tensor(2.0)
6598        c = torch.tensor(3, dtype=torch.int64)
6599        a.data = b
6600        with self.assertRaisesRegex(
6601            RuntimeError, "must be floating point or complex dtype"
6602        ):
6603            a.data = c
6604
6605    @unittest.skipIf(IS_WINDOWS, "Skipping because doesn't work for windows")
6606    def test_thread_shutdown(self):
6607        code = """import torch
6608from torch.autograd import Function
6609class MyFunction(Function):
6610    @staticmethod
6611    def forward(ctx, x):
6612        return x
6613
6614    @staticmethod
6615    def backward(ctx, grad):
6616        return grad
6617
6618# Run on cuda if it is available to ensure that the worker thread
6619# is properly initialized by the time we exit.
6620device = "cuda" if torch.cuda.is_available() else "cpu"
6621
6622for shape in [(1,), ()]:
6623    v = torch.ones(shape, requires_grad=True, device=device)
6624    MyFunction.apply(v).backward()
6625"""
6626        s = TestCase.runWithPytorchAPIUsageStderr(code)
6627        # The autograd engine creates worker threads only when GPU devices are present.
6628        # So make sure that we do shutdown threads when we're testing cuda and make sure
6629        # that there is no thread to shutdown when we're not using cuda.
6630        if TEST_CUDA or torch.backends.mps.is_available() or torch.xpu.is_available():
6631            self.assertRegex(s, "PYTORCH_API_USAGE torch.autograd.thread_shutdown")
6632        else:
6633            self.assertNotRegex(s, "PYTORCH_API_USAGE torch.autograd.thread_shutdown")
6634
6635    @unittest.skipIf(
6636        IS_MACOS,
6637        "Fails with SIGBUS on macOS; https://github.com/pytorch/pytorch/issues/25941",
6638    )
6639    def test_deep_reentrant(self):
6640        class DeepReentrant(Function):
6641            @staticmethod
6642            def forward(ctx, x):
6643                with torch.enable_grad():
6644                    ctx.x = Variable(x.detach(), requires_grad=True)
6645                    ctx.x = ctx.x - 1
6646                return ctx.x.detach()
6647
6648            @staticmethod
6649            def backward(ctx, x):
6650                if ctx.x < 0:
6651                    return x
6652                with torch.enable_grad():
6653                    DeepReentrant.apply(ctx.x).sum().backward()
6654                return x
6655
6656        # Test stack overflow escape mechanism
6657        v = torch.tensor(2000.0, requires_grad=True)
6658        # This will cause stack overflow if reentrant calls are handled
6659        # in the same thread recursively
6660        DeepReentrant.apply(v).sum().backward()
6661
6662        # Test stack overflow escape mechanism multiple times
6663        # to ensure reusing workers in the pool works fine
6664        v2 = torch.tensor(200.0, requires_grad=True)
6665        DeepReentrant.apply(v2).sum().backward()
6666
6667    def test_reentrant_priority(self):
6668        order = []
6669
6670        class MyFunction(Function):
6671            @staticmethod
6672            def forward(ctx, x):
6673                return x
6674
6675            @staticmethod
6676            def backward(ctx, x):
6677                order.append("MyFunction")
6678                return x
6679
6680        class Reentrant(Function):
6681            @staticmethod
6682            def forward(ctx, x):
6683                with torch.enable_grad():
6684                    ctx.x = Variable(x.detach(), requires_grad=True)
6685                    ctx.x = ctx.x - 1
6686                return ctx.x.detach()
6687
6688            @staticmethod
6689            def backward(ctx, x):
6690                order.append("Reentrant")
6691                if ctx.x < 0:
6692                    return x
6693                with torch.enable_grad():
6694                    Reentrant.apply(ctx.x).backward()
6695                return x
6696
6697        a = MyFunction.apply(torch.tensor(6.0, requires_grad=True))
6698        b = Reentrant.apply(torch.tensor(9.0, requires_grad=True))
6699        v = a * b
6700        v.backward()
6701        # The tasks for the Reentrant and MyFunction backward() will be added
6702        # to the queue in the autograd engine at the same time. The backward
6703        # for Reentrant will be executed first, which will then add other
6704        # backward tasks to the queue. We want to ensure all the reentrant tasks
6705        # are prioritized over the MyFunction backward task regardless of their
6706        # sequence numbers
6707        self.assertEqual(len(order), 11)
6708        self.assertEqual(order.count("Reentrant"), 10)
6709        self.assertEqual(order[-1], "MyFunction")
6710
6711    @slowTest
6712    def test_checkpointing(self):
6713        num_inp = 2000
6714        nz_inp = 10
6715        nz_out = 10
6716        nz_bottleneck = 1000
6717
6718        # small proxy network for some complex reasoning we want to do per input
6719        module = nn.Sequential(
6720            nn.Linear(nz_inp, nz_bottleneck),
6721            nn.ReLU(),
6722            nn.Linear(nz_bottleneck, nz_inp),
6723        )
6724
6725        feat_combined = []
6726        for r in range(num_inp):
6727            data_r = torch.empty(1, nz_inp)
6728            data_r.uniform_()
6729            data_r.requires_grad = True
6730            feat_r = checkpoint(module, data_r, use_reentrant=True)
6731            feat_combined.append(feat_r)
6732
6733        # compute mean as a proxy for some joint reasoning
6734        mean_combined = torch.stack(feat_combined).mean()
6735        mean_combined.backward()
6736
6737    def _test_checkpointing_non_reentrant_autocast(self, device_type):
6738        for enabled in [True, False]:
6739
6740            def foo(x, y, z):
6741                # torch.mm is on autocast's list of ops that should run in
6742                # the autocast precision
6743                x = torch.mm(x, y)
6744                y = torch.mm(x, z)
6745                z = torch.mm(z, z)
6746                expected_dtype = torch.float32 if not enabled else torch.bfloat16
6747                self.assertEqual(expected_dtype, z.dtype)
6748                return z
6749
6750            x = torch.randn(3, 3, requires_grad=True)
6751            y = torch.randn(3, 3, requires_grad=True)
6752            z = torch.randn(3, 3, requires_grad=True)
6753            if device_type == "cuda":
6754                x = x.cuda()
6755                y = y.cuda()
6756                z = z.cuda()
6757
6758            with torch.autocast(
6759                enabled=enabled, device_type=device_type, dtype=torch.bfloat16
6760            ):
6761                loss = checkpoint(foo, x, y, z, use_reentrant=False)
6762                loss = loss.sum()
6763
6764            # Without saving + recasting the autocast type, would raise error in autograd
6765            # about mismatched dtypes.
6766            loss.backward()  # triggers recomputation to check it runs in bfloat
6767
6768    def test_checkpointing_non_reentrant_autocast_cpu(self):
6769        """
6770        Test that autocast args such as the dtype are preserved during non-reentrant
6771        checkpoint recomputation on CPU.
6772        """
6773        self._test_checkpointing_non_reentrant_autocast(device_type="cpu")
6774
6775    @unittest.skipIf(
6776        not torch.cuda.is_available() or not torch.cuda.is_bf16_supported(),
6777        "Test requires CUDA bf16 support",
6778    )
6779    def test_checkpointing_non_reentrant_autocast_gpu(self):
6780        """
6781        Test that autocast args/kwargs such as the dtype are preserved during
6782        non-reentrant checkpoint recomputation on GPU.
6783        """
6784        self._test_checkpointing_non_reentrant_autocast(device_type="cuda")
6785
6786    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
6787    @slowTest
6788    def test_checkpointing_without_reentrant_memory_savings(self):
6789        class MyModel(nn.Module):
6790            def __init__(self, n, use_checkpoint, use_reentrant):
6791                super().__init__()
6792                self.n = n
6793                self.use_checkpoint = use_checkpoint
6794                self.use_reentrant = use_reentrant
6795                self.layers = nn.ModuleList()
6796                for i in range(self.n):
6797                    layer = nn.Sequential(
6798                        nn.Linear(256, 256), nn.Linear(256, 256), nn.Linear(256, 256)
6799                    )
6800                    self.layers.append(layer)
6801                # pre-allocate the grad so that increased memory usage is mainly
6802                # due to activations.
6803                for layer in self.layers:
6804                    for lin in layer:
6805                        lin.weight.grad = torch.ones_like(lin.weight)
6806                        lin.bias.grad = torch.ones_like(lin.bias)
6807
6808            def forward(self, x):
6809                for i in range(self.n):
6810                    if not self.use_checkpoint:
6811                        x = self.layers[i](x)
6812                    else:
6813                        x = checkpoint(
6814                            self.layers[i], x, use_reentrant=self.use_reentrant
6815                        )
6816
6817                return x
6818
6819        model_no_checkpoint = MyModel(
6820            8, use_checkpoint=False, use_reentrant=False
6821        ).cuda()
6822        model_reentrant_checkpoint = MyModel(
6823            8, use_checkpoint=True, use_reentrant=True
6824        ).cuda()
6825        model_no_reentrant_checkpoint = MyModel(
6826            8, use_checkpoint=True, use_reentrant=False
6827        ).cuda()
6828
6829        x = torch.randn(100, 256, requires_grad=True, device="cuda")
6830
6831        torch.cuda.reset_peak_memory_stats()
6832        loss = model_no_checkpoint(x.clone()).sum()
6833        loss.backward()
6834        mem_no_checkpoint = torch.cuda.max_memory_allocated()
6835
6836        torch.cuda.reset_peak_memory_stats()
6837        loss = model_reentrant_checkpoint(x.clone()).sum()
6838        loss.backward()
6839        mem_reentrant_checkpoint = torch.cuda.max_memory_allocated()
6840
6841        torch.cuda.reset_peak_memory_stats()
6842        loss = model_no_reentrant_checkpoint(x.clone()).sum()
6843        loss.backward()
6844        mem_no_reentrant_checkpoint = torch.cuda.max_memory_allocated()
6845
6846        self.assertTrue(mem_reentrant_checkpoint < mem_no_checkpoint)
6847        self.assertTrue(mem_no_reentrant_checkpoint < mem_no_checkpoint)
6848
6849    def test_checkpointing_without_reentrant_custom_function_works(self):
6850        msg = "Unpack is being triggered for a tensor that was already unpacked once"
6851
6852        class MyFunc(torch.autograd.Function):
6853            @staticmethod
6854            def forward(ctx, x, y, z):
6855                w = x * y * z
6856                out = w + w
6857                ctx.save_for_backward(x, y, z, w, out)
6858                return out
6859
6860            @staticmethod
6861            def backward(ctx, grad_out):
6862                x, y, z, w, out = ctx.saved_tensors
6863                # Accessing the saved Tensors a second time will raise because
6864                # recomputed tensors get cleared as soon as they are unpacked.
6865                # A recomputation is only triggered if your backward has a new
6866                # graph-task id.
6867                with self.assertRaisesRegex(RuntimeError, msg):
6868                    x_2, y_2, z_2, w_2, out_2 = ctx.saved_tensors
6869                return x, y, z
6870
6871        x = torch.tensor(1.0, requires_grad=True)
6872        y = torch.tensor(2.0, requires_grad=True)
6873        z = torch.tensor(3.0, requires_grad=True)
6874
6875        def foo(x, y, z):
6876            x = x * y * z
6877            y = y * y * z
6878            z = z * z
6879            out = MyFunc.apply(x, y, z)
6880            return out
6881
6882        out = checkpoint(foo, x, y, z, use_reentrant=False)
6883        out.sum().backward()
6884
6885    def test_checkpointing_without_reentrant_with_context_fn(self):
6886        class VerboseTorchDispatchMode(TorchDispatchMode):
6887            def __init__(self) -> None:
6888                self.operators = []
6889
6890            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
6891                if kwargs is None:
6892                    kwargs = {}
6893                self.operators.append(func.__name__)
6894                return func(*args, **kwargs)
6895
6896        x = torch.tensor(1.0, requires_grad=True)
6897        verbose_mode = VerboseTorchDispatchMode()
6898
6899        def context_fn():
6900            return verbose_mode, contextlib.nullcontext()
6901
6902        out = checkpoint(
6903            lambda x: x.exp(), x, use_reentrant=False, context_fn=context_fn
6904        )
6905        self.assertEqual(verbose_mode.operators, ["exp.default"])
6906
6907        verbose_mode.operators = []
6908
6909        def context_fn():
6910            return contextlib.nullcontext(), verbose_mode
6911
6912        out = checkpoint(
6913            lambda x: x.exp(), x, use_reentrant=False, context_fn=context_fn
6914        )
6915        out.backward()
6916        self.assertEqual(
6917            verbose_mode.operators, ["exp.default", "detach.default", "detach.default"]
6918        )
6919
6920        with self.assertRaisesRegex(
6921            Exception, "only supported when use_reentrant=False"
6922        ):
6923            out = checkpoint(
6924                lambda x: x.sin(), x, use_reentrant=True, context_fn=context_fn
6925            )
6926
6927    def test_checkpoint_warns_if_use_reentrant_not_passed_explcitly(self):
6928        a = torch.randn(1, requires_grad=True)
6929
6930        # Passing explicitly should not warn
6931        self.assertNotWarn(lambda: checkpoint(lambda x: x, a, use_reentrant=False))
6932
6933        # Not passing explicitly warns
6934        with self.assertWarnsOnceRegex(
6935            UserWarning, ".*the use_reentrant parameter should be passed explicitly.*"
6936        ):
6937            checkpoint(lambda x: x, a)
6938
6939    def test_checkpoint_sequential_warns_if_use_reentrant_not_passed_explcitly(self):
6940        a = torch.randn(3, requires_grad=True)
6941        modules_list = [
6942            torch.nn.Linear(3, 3),
6943            torch.nn.Linear(3, 3),
6944            torch.nn.Linear(3, 3),
6945        ]
6946
6947        # Passing explicitly should not warn
6948        self.assertNotWarn(
6949            lambda: checkpoint_sequential(modules_list, 3, a, use_reentrant=False)
6950        )
6951
6952        # Not passing explicitly warns
6953        with self.assertWarnsOnceRegex(
6954            UserWarning, ".*the use_reentrant parameter should be passed explicitly.*"
6955        ):
6956            checkpoint_sequential(modules_list, 3, a)
6957
6958    def test_checkpoint_detects_non_determinism(self):
6959        def save_3_tensors(x):
6960            out = x.sin().exp()
6961            out = out.sin()
6962            return out
6963
6964        def save_2_tensors(x):
6965            return x.sin().exp()
6966
6967        def save_2_tensors_alt(x):
6968            return x.sin() * torch.tensor([1.0, 2.0])
6969
6970        def get_non_det_fn(orig_fn, recompute_fn):
6971            counter = [0]
6972
6973            def fn(x):
6974                if counter[0] == 0:
6975                    counter[0] += 1
6976                    return orig_fn(x)
6977                else:
6978                    return recompute_fn(x)
6979
6980            return fn
6981
6982        a = torch.randn(1, requires_grad=True)
6983
6984        # Save fewer tensors during recompute
6985        fn = get_non_det_fn(orig_fn=save_3_tensors, recompute_fn=save_2_tensors)
6986        with self.assertRaisesRegex(
6987            RuntimeError, "A different number of tensors was saved"
6988        ):
6989            out = checkpoint(fn, a, use_reentrant=False)
6990            out.backward()
6991
6992        # Save more tensors during recompute
6993        fn = get_non_det_fn(orig_fn=save_2_tensors, recompute_fn=save_3_tensors)
6994        with torch.utils.checkpoint.set_checkpoint_early_stop(False):
6995            with self.assertRaisesRegex(
6996                RuntimeError, "trying to save more tensors during recomputation"
6997            ):
6998                out = checkpoint(fn, a, use_reentrant=False)
6999                out.backward()
7000
7001        fn = get_non_det_fn(orig_fn=save_2_tensors, recompute_fn=save_3_tensors)
7002        # If early stopping is enabled, we would not raise (the results would be correct anyway)
7003        out = checkpoint(fn, a, use_reentrant=False)
7004        out.backward()
7005
7006        # Save the same number of tensors but the shape is different
7007        fn = get_non_det_fn(orig_fn=save_2_tensors, recompute_fn=save_2_tensors_alt)
7008        with self.assertRaisesRegex(RuntimeError, "tensors have different metadata"):
7009            out = checkpoint(fn, a, use_reentrant=False)
7010            out.backward()
7011
7012        # Get the debug message if debug=True
7013        fn = get_non_det_fn(orig_fn=save_2_tensors, recompute_fn=save_2_tensors_alt)
7014
7015        with self.assertRaisesRegex(
7016            RuntimeError,
7017            "You are seeing this error because you passed `debug=True` to checkpoint",
7018        ):
7019            out = checkpoint(fn, a, use_reentrant=False, debug=True)
7020            out.backward()
7021
7022        fn = get_non_det_fn(orig_fn=save_2_tensors, recompute_fn=save_2_tensors_alt)
7023
7024        with self.assertRaisesRegex(
7025            RuntimeError,
7026            "You are seeing this error because you passed `debug=True` to checkpoint",
7027        ):
7028            with torch.utils.checkpoint.set_checkpoint_debug_enabled(True):
7029                out = checkpoint(fn, a, use_reentrant=False, debug=False)
7030                out.backward()
7031
7032        fn = get_non_det_fn(orig_fn=save_2_tensors, recompute_fn=save_2_tensors_alt)
7033
7034        with self.assertRaisesRegex(
7035            RuntimeError, "Recomputed values for the following tensors have different"
7036        ):
7037            with torch.utils.checkpoint.set_checkpoint_debug_enabled(False):
7038                out = checkpoint(fn, a, use_reentrant=False, debug=True)
7039                out.backward()
7040
7041    def test_access_saved_tensor_twice_without_recomputation_works(self):
7042        count = [0]
7043
7044        def foo(a):
7045            count[0] += 1
7046            b = a * a
7047            c = a * b
7048            d = torch.exp(a)
7049            return d
7050
7051        a = torch.randn(5, requires_grad=True)
7052        d = checkpoint(foo, a, use_reentrant=False)
7053        self.assertEqual(count[0], 1)
7054        # Recomputed variables only persist within a particular backward call.
7055        # If _saved_result is accessed outside of a backward, it will trigger
7056        # a recompute. And afterwards, those recomputed results are immediately
7057        # cleared.
7058        d.grad_fn._saved_result
7059        self.assertEqual(count[0], 2)
7060        # Second access will trigger another recompute
7061        d.grad_fn._saved_result
7062        self.assertEqual(count[0], 3)
7063        # Backward clears the saved variable
7064        d.sum().backward()
7065        self.assertEqual(count[0], 4)
7066        # Now it raises an error
7067        with self.assertRaisesRegex(
7068            RuntimeError,
7069            "or directly access saved tensors after they have already been freed",
7070        ):
7071            d.grad_fn._saved_result
7072
7073    @slowTest
7074    @parametrize("input_requires_grad", [True, False])
7075    def test_checkpointing_without_reentrant(self, input_requires_grad):
7076        """
7077        Basic test for checkpoint without reentrant autograd.
7078        """
7079        num_inp = 2000
7080        nz_inp = 10
7081        nz_out = 10
7082        nz_bottleneck = 1000
7083
7084        # small proxy network for some complex reasoning we want to do per input
7085        module = nn.Sequential(
7086            nn.Linear(nz_inp, nz_bottleneck),
7087            nn.ReLU(),
7088            nn.Linear(nz_bottleneck, nz_inp),
7089        )
7090
7091        # Module holder for testing activation checkpointing with no_reentrant
7092        # supports kwargs.
7093        class MyModule(nn.Module):
7094            def __init__(self, mod):
7095                super().__init__()
7096                self.module = mod
7097
7098            def forward(self, data):
7099                return self.module(data)
7100
7101        module = MyModule(mod=module)
7102
7103        # Run model with and without checkpointing and verify gradients are
7104        # equivalent, regardless of if inputs require grads or not.
7105        module_copy = deepcopy(module)
7106
7107        feat_combined = []
7108        feat_combined_no_checkpoint = []
7109        for r in range(num_inp):
7110            data_r = torch.empty(1, nz_inp)
7111            data_r.uniform_()
7112            data_r.requires_grad = input_requires_grad
7113            data_r_copy = data_r.clone()
7114            feat_r = checkpoint(module, data=data_r, use_reentrant=False)
7115            feat_combined.append(feat_r)
7116            feat_r_no_checkpoint = module_copy(data_r)
7117            feat_combined_no_checkpoint.append(feat_r_no_checkpoint)
7118
7119        # compute mean as a proxy for some joint reasoning
7120        mean_combined = torch.stack(feat_combined).mean()
7121        mean_combined.backward()
7122        mean_combined_no_checkpoint = torch.stack(feat_combined_no_checkpoint).mean()
7123        mean_combined_no_checkpoint.backward()
7124
7125        for checkpoint_param, param in zip(
7126            module.parameters(), module_copy.parameters()
7127        ):
7128            self.assertEqual(checkpoint_param.grad, param.grad)
7129
7130    def test_checkpoint_valid_reset_on_error(self):
7131        a = torch.randn(2, 2, requires_grad=True)
7132
7133        with self.assertRaisesRegex(
7134            Exception, "torch.utils.checkpoint is incompatible"
7135        ):
7136            b = checkpoint(torch.exp, a, use_reentrant=True).sum()
7137            torch.autograd.grad(b, (a,))
7138
7139        c = checkpoint(torch.exp, a, use_reentrant=True).sum()
7140        c.backward()
7141
7142    @parametrize("use_reentrant", [True, False])
7143    def test_checkpointing_without_reentrant_detached_tensor(self, use_reentrant):
7144        class NoGradModule(torch.nn.Module):
7145            def __init__(self) -> None:
7146                super().__init__()
7147                self.linear = nn.Linear(2, 2, bias=False)
7148                self.lin2 = nn.Linear(2, 2, bias=False)
7149
7150            def forward(self, x):
7151                with torch.no_grad():
7152                    return self.lin2(self.linear(x))
7153
7154        module = NoGradModule()
7155
7156        err_ctx = (
7157            self.assertRaisesRegex(
7158                RuntimeError, "none of output has requires_grad=True"
7159            )
7160            if use_reentrant
7161            else contextlib.nullcontext()
7162        )
7163
7164        a = torch.randn(2, 2, requires_grad=True)
7165        for _ in range(3):
7166            with err_ctx:
7167                # out does not require grad
7168                out = checkpoint(module, a, use_reentrant=use_reentrant)
7169                # Make loss require grad, otherwise we would run into
7170                # "element 0 of tensors does not require grad and does not have a grad_fn"
7171                out += a
7172                out.sum().backward()
7173
7174    def test_checkpointing_without_reentrant_saved_object_identity(self):
7175        x_backward = None
7176
7177        class Test(torch.autograd.Function):
7178            @staticmethod
7179            def forward(ctx, x, y):
7180                ctx.save_for_backward(y)
7181                return x
7182
7183            @staticmethod
7184            def backward(ctx, x):
7185                nonlocal x_backward
7186                (x_backward,) = ctx.saved_tensors
7187                return x, None
7188
7189        a = torch.tensor(1.0, requires_grad=True)
7190        b = torch.tensor(1.0, requires_grad=False)
7191
7192        Test.apply(a, b).backward()
7193        self.assertIs(b, x_backward)
7194
7195        x_backward = None
7196        checkpoint(Test.apply, a, b, use_reentrant=False).backward()
7197        self.assertIs(b, x_backward)
7198
7199    def test_checkpointing_without_reentrant_correct_grad(self):
7200        """
7201        Verifies that correct gradients are calculated for checkpoint
7202        without reentrant autograd, for both backward() and autograd.grad().
7203        """
7204        a = torch.randn(2, 2, requires_grad=True)
7205
7206        b = torch.exp(a).sum()
7207        b.backward()
7208        b_grad = a.grad
7209
7210        a.grad = None
7211        c = checkpoint(torch.exp, a, use_reentrant=False).sum()
7212        c.backward()
7213        c_grad = a.grad
7214
7215        a.grad = None
7216        d = checkpoint(torch.exp, a, use_reentrant=False).sum()
7217        (d_grad,) = torch.autograd.grad(d, (a,))
7218
7219        self.assertEqual(b_grad, c_grad)
7220        self.assertEqual(b_grad, d_grad)
7221
7222    # PYTORCH_TEST_WITH_DYNAMO=1 test fails on CI but can't repro locally
7223    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/127115")
7224    def test_checkpointing_without_reentrant_dataparallel(self):
7225        """
7226        Verifies gradient correctness when checkpoint without reentrant autograd
7227        is used in conjunction with DataParallel.
7228        """
7229
7230        class LinearModule(torch.nn.Module):
7231            def __init__(self) -> None:
7232                super().__init__()
7233                self.linear = nn.Linear(2, 2, bias=False)
7234
7235            def forward(self, inp):
7236                return self.linear(inp)
7237
7238        a = torch.randn(2, 2, requires_grad=True)
7239        if torch.cuda.is_available():
7240            a = a.cuda()
7241
7242        model = LinearModule()
7243        if torch.cuda.is_available():
7244            model = model.cuda()
7245
7246        b = deepcopy(model)(a).sum()
7247        b.backward()
7248        b_grad = a.grad
7249
7250        a.grad = None
7251
7252        module = torch.nn.DataParallel(deepcopy(model))
7253        c = checkpoint(module, a, use_reentrant=False).sum()
7254        c.backward()
7255        c_grad = a.grad
7256
7257        self.assertEqual(b_grad, c_grad)
7258
7259    def test_checkpointing_without_reentrant_parameter_used_in_an_out(self):
7260        """
7261        Ensures that gradient hooks are only called once per tensor.
7262        """
7263        w = torch.randn(10, 10, requires_grad=True)
7264        count = 0
7265
7266        def hook(grad):
7267            nonlocal count
7268            count += 1
7269
7270        w.register_hook(hook)
7271        x = torch.rand(10, 10, requires_grad=True)
7272        h = w * x  # Using w outside the checkpoint
7273        out = checkpoint(
7274            lambda x: w * x, h, use_reentrant=False
7275        )  # Using w inside the checkpoint
7276
7277        out.sum().backward()
7278        # should only call hook once
7279        self.assertEqual(count, 1)
7280
7281    # https://github.com/pytorch/pytorch/issues/127115
7282    @xfailIfTorchDynamo
7283    def test_checkpointing_without_reentrant_arbitrary_input_output(self):
7284        """
7285        Ensures checkpointing without reentrant autograd works with functions
7286        with arbitrary input/output structures.
7287        """
7288
7289        class MyModel(torch.nn.Module):
7290            def __init__(self) -> None:
7291                super().__init__()
7292                self.layer = torch.nn.Linear(5, 5, bias=False)
7293
7294            def forward(self, dict_input):
7295                tensor = dict_input["tensor"]
7296                return {"result": self.layer(tensor)}
7297
7298        model_no_checkpoint = MyModel()
7299        model_checkpoint_without_reentrant = deepcopy(model_no_checkpoint)
7300
7301        inp = {"tensor": torch.randn(5, 5)}
7302
7303        out_no_checkpoint = model_no_checkpoint(inp)["result"].sum()
7304
7305        out_checkpoint = checkpoint(
7306            model_checkpoint_without_reentrant, inp, use_reentrant=False
7307        )["result"].sum()
7308
7309        self.assertEqual(out_checkpoint, out_no_checkpoint)
7310
7311        out_no_checkpoint.backward()
7312        out_checkpoint.backward()
7313
7314        for param, checkpoint_param in zip(
7315            model_no_checkpoint.parameters(),
7316            model_checkpoint_without_reentrant.parameters(),
7317        ):
7318            self.assertEqual(param.grad, checkpoint_param.grad)
7319
7320    def test_callback_adds_callback(self):
7321        called = [0]
7322
7323        def callback_final():
7324            called[0] += 1
7325
7326        def callback_adds_callback():
7327            called[0] += 1
7328            Variable._execution_engine.queue_callback(callback_final)
7329
7330        class MyFunc(Function):
7331            @staticmethod
7332            def forward(ctx, input):
7333                return input
7334
7335            @staticmethod
7336            @once_differentiable
7337            def backward(ctx, grad):
7338                Variable._execution_engine.queue_callback(callback_adds_callback)
7339                return grad
7340
7341        a = torch.rand((3, 3), requires_grad=True)
7342        b = MyFunc.apply(a)
7343        b.sum().backward()
7344
7345        self.assertEqual(called[0], 2)
7346
7347    @unittest.skipIf(not TEST_CUDA, "test requires CUDA")
7348    def test_callback_propagates_errors_from_device_thread(self):
7349        def callback():
7350            raise RuntimeError("blah")
7351
7352        def hook_with_callback(*args):
7353            torch.autograd.Variable._execution_engine.queue_callback(callback)
7354
7355        t = torch.tensor([1.0, 2.0], requires_grad=True, device=torch.device("cuda"))
7356        t.register_hook(hook_with_callback)
7357        output = t**2
7358        loss = output.sum()
7359
7360        with self.assertRaisesRegex(RuntimeError, "blah"):
7361            loss.backward()
7362
7363    def _test_reentrant_with_callbacks(self, install_callbacks_in_depths):
7364        counter = {}
7365        counter["inner"] = 0
7366        counter["outer"] = 0
7367
7368        def inc_inner_counter():
7369            counter["inner"] += 1
7370
7371        def inc_outer_counter():
7372            counter["outer"] += 1
7373
7374        class MyFunc(Function):
7375            @staticmethod
7376            def forward(ctx, input):
7377                return input
7378
7379            @staticmethod
7380            @once_differentiable
7381            def backward(ctx, input):
7382                if 1 in install_callbacks_in_depths:
7383                    # Add a callback to execute.
7384                    Variable._execution_engine.queue_callback(inc_inner_counter)
7385
7386                return input
7387
7388        class MyReentrantFunc(Function):
7389            @staticmethod
7390            def forward(ctx, input):
7391                return input
7392
7393            @staticmethod
7394            @once_differentiable
7395            def backward(ctx, input):
7396                if 0 in install_callbacks_in_depths:
7397                    # Add a callback to execute.
7398                    Variable._execution_engine.queue_callback(inc_outer_counter)
7399                # Reentrant backward call.
7400                tmp_inp = input.detach().requires_grad_()
7401                with torch.enable_grad():
7402                    tmp_out = (MyFunc.apply(tmp_inp)).sum()
7403                tmp_out.backward()
7404                return input
7405
7406        t1 = torch.rand((3, 3), requires_grad=True)
7407        t2 = MyReentrantFunc.apply(t1)
7408        t3 = t2.sum()
7409        torch.autograd.backward([t3])
7410
7411        return counter
7412
7413    def test_reentrant_with_callbacks_depth_0(self):
7414        # Verify callback is called only once.
7415        ret = self._test_reentrant_with_callbacks([0])
7416        self.assertEqual(1, ret["outer"])
7417        self.assertEqual(0, ret["inner"])
7418
7419    def test_reentrant_with_callbacks_depth_1(self):
7420        # Verify callback is called only once.
7421        ret = self._test_reentrant_with_callbacks([1])
7422        self.assertEqual(0, ret["outer"])
7423        self.assertEqual(1, ret["inner"])
7424
7425    def test_reentrant_with_callbacks_both_depths(self):
7426        # Verify callback is called twice.
7427        ret = self._test_reentrant_with_callbacks([0, 1])
7428        self.assertEqual(1, ret["outer"])
7429        self.assertEqual(1, ret["inner"])
7430
7431    def test_reentrant_with_leaf_variable_hook(self):
7432        handle = None
7433        param = torch.rand(10, requires_grad=True)
7434
7435        def add_gradient_penalty_to_grad(grad):
7436            handle.remove()
7437            old_param_grad = grad
7438            param.grad = None
7439            # Add some sort of gradient penalty by directly updating the gradients
7440            with torch.enable_grad():
7441                g = grad.detach().requires_grad_()
7442                new_param = param.detach().requires_grad_()
7443                out = ((g * 2) + new_param).sum()
7444                out.backward()
7445            res = g.grad + grad
7446            param.grad = old_param_grad
7447            return res
7448
7449        handle = param.register_hook(add_gradient_penalty_to_grad)
7450        # Forward pass
7451        tmp = param * param
7452        loss = tmp.sum()
7453        # Compute the gradients
7454        loss.backward()
7455
7456    def test_reentrant_with_non_leaf_variable_hook(self):
7457        handle = None
7458        param = torch.rand(10, requires_grad=True)
7459
7460        def manual_increase_gradient(grad):
7461            handle.remove()
7462            # Add some sort of gradient penalty by directly updating the gradients
7463            with torch.enable_grad():
7464                g = grad.detach().requires_grad_()
7465                out = ((g * 2) + 5).sum()
7466                out.backward()
7467            res = g.grad + grad
7468            return res
7469
7470        # Forward pass
7471        tmp = param * param
7472        handle = tmp.register_hook(manual_increase_gradient)
7473        loss = tmp.sum()
7474        # Compute the gradients
7475        loss.backward()
7476        self.assertEqual(param.grad, 6 * param)
7477
7478    def test_grad_fn_attr_bindings(self):
7479        # Check that the getter of each type returns what we want
7480        # See `gen_autograd_functions.py` for how the getters are generated
7481        #
7482        # This test is only meant to check if the codegen'd bindings work
7483        # Please help update this test if you update the names of any the fields we check!
7484        #
7485        a = torch.ones(1, requires_grad=True)
7486        b = torch.zeros(1, requires_grad=True)
7487        out1 = torch.stack([a, b], dim=0)
7488        out2 = (a * 2) * b
7489        # TODO: I don't think we have a backward saving a list of tensors
7490        #       at the moment. It used to be stack, but for no reason...
7491        #       see discussion in #84993
7492        # self.assertEqual(out.grad_fn._saved_tensors, (a, b))              # TewnsorList -> Tuple[Tensor]
7493        self.assertEqual(out2.grad_fn._saved_self, a * 2)
7494        self.assertIsInstance(out2.grad_fn._saved_self, torch.Tensor)
7495        self.assertIsInstance(
7496            out2.grad_fn._raw_saved_self, torch._C._autograd.SavedTensor
7497        )
7498        self.assertEqual(out1.grad_fn._saved_dim, 0)  # int64_t -> int
7499        self.assertIsInstance(out1.grad_fn._saved_dim, int)
7500
7501        out2.grad_fn._raw_saved_self.register_hooks(lambda x: x, lambda x: x)
7502
7503        out2.sum().backward()
7504        with self.assertRaisesRegex(RuntimeError, "after they have already been freed"):
7505            out2.grad_fn._saved_self
7506        # TODO: interestingly, this only happens if indexing into a list grad_fn._raw_saved_tensors[0],
7507        #       not when using a saved tensor, see discussion in #84993
7508        # with self.assertRaisesRegex(RuntimeError, "after they have already been freed"):
7509        #     out2.grad_fn._raw_saved_self
7510        self.assertEqual(out1.grad_fn._saved_dim, 0)
7511
7512        a = torch.ones(2, 2, requires_grad=True)
7513        indices = torch.tensor([0, 1])
7514        out = a[:, indices]
7515        self.assertEqual(
7516            out.grad_fn._saved_indices, (None, indices)
7517        )  # c10::List<std::optional<Tensor>> -> Tuple[Tensor?]
7518        self.assertIsInstance(out.grad_fn._saved_indices[1], torch.Tensor)
7519        self.assertIsInstance(
7520            out.grad_fn._raw_saved_indices[1], torch._C._autograd.SavedTensor
7521        )
7522        self.assertEqual(
7523            out.grad_fn._saved_self_sym_sizes, a.shape
7524        )  # SymIntArrayRef -> Tuple[SymInt]
7525        self.assertIsInstance(out.grad_fn._saved_self_sym_sizes[0], int)
7526
7527        out.grad_fn._raw_saved_indices[1].register_hooks(lambda x: x, lambda x: x)
7528        with self.assertRaisesRegex(RuntimeError, "None is forbidden"):
7529            out.grad_fn._raw_saved_indices[0].register_hooks(lambda x: x, lambda x: x)
7530
7531        out = a.mean()
7532        self.assertEqual(
7533            out.grad_fn._saved_self_sym_sizes, a.shape
7534        )  # IntArrayRef -> Tuple[int]
7535
7536        a = torch.ones(2, 2, requires_grad=True)
7537        out = a * a
7538        out.grad_fn._raw_saved_self.register_hooks(lambda x: x, lambda x: x)
7539        out.sum().backward()
7540        with self.assertRaisesRegex(RuntimeError, "after it has been freed"):
7541            out.grad_fn._raw_saved_self.register_hooks(lambda x: x, lambda x: x)
7542
7543        a = torch.ones(1, 1, 2, requires_grad=True)
7544        out = torch.nn.functional.interpolate(a, 4, mode="linear")
7545        self.assertEqual(
7546            out.grad_fn._saved_output_size, (4,)
7547        )  # std::optional<IntArrayRef> -> int[]?
7548        self.assertIsInstance(out.grad_fn._saved_output_size[0], int)
7549        self.assertEqual(out.grad_fn._saved_align_corners, False)  # bool -> bool
7550        self.assertIsInstance(out.grad_fn._saved_align_corners, bool)
7551        if hasattr(out.grad_fn, "_saved_scale_factors"):
7552            self.assertIsNone(
7553                out.grad_fn._saved_scale_factors
7554            )  # std::optional<ArrayRef<double>> -> float[]?
7555        else:
7556            self.assertIsNone(
7557                out.grad_fn._saved_scales
7558            )  # std::optional<ArrayRef<double>> -> float[]?
7559
7560        a = torch.ones(1, 1, 3, 3, requires_grad=True)
7561        out = nn.Conv2d(1, 1, 3)(a)
7562        self.assertEqual(
7563            out.grad_fn._saved_bias_sym_sizes_opt, (1,)
7564        )  # std::optional<SymIntArrayRef> -> SymInt[]?
7565        out = nn.Conv2d(1, 1, 3, bias=False)(a)
7566        # TODO: This is BAD! we converted a std::nullopt into a (0,)
7567        self.assertEqual(out.grad_fn._saved_bias_sym_sizes_opt, (0,))
7568
7569        a = torch.ones(1, 3, 3, requires_grad=True)
7570        out = torch.addbmm(a.squeeze(0), a, a)
7571        self.assertEqual(out.grad_fn._saved_batch1_sym_argsize_0, 1)  # int64_t
7572        self.assertEqual(out.grad_fn._saved_batch1_sym_argsize_1, 3)  # int64_t
7573
7574        a = torch.ones(1, 1, 3, 3, requires_grad=True)
7575        out = torch.nn.functional.unfold(a, 3)
7576        self.assertEqual(out.grad_fn._saved_self_sym_argsize_minus_2, 3)  # SymInt
7577        self.assertEqual(out.grad_fn._saved_self_sym_argsize_minus_1, 3)  # SymInt
7578
7579        a = torch.ones(1, 1, 2, requires_grad=True)
7580        out = torch.nn.functional.interpolate(a, scale_factor=0.5, mode="linear")
7581        self.assertEqual(out.grad_fn._saved_scales, 0.5)
7582
7583        a = torch.ones(2, 2, requires_grad=True)
7584        out = torch.pdist(a, p=1)
7585        self.assertEqual(out.grad_fn._saved_p, 1.0)  # double -> float
7586        self.assertIsInstance(out.grad_fn._saved_p, float)
7587
7588        a = torch.ones(1, 1, 2, requires_grad=True)
7589        out = torch.logit(a, 1.0)
7590        self.assertEqual(out.grad_fn._saved_eps, 1.0)  # c10:optional<double> -> float?
7591        self.assertIsInstance(out.grad_fn._saved_eps, float)
7592        out = torch.logit(a)
7593        self.assertIsNone(out.grad_fn._saved_eps)
7594
7595        if torch._C.has_lapack:
7596            a = torch.ones(1, 1, requires_grad=True)
7597            q, r = torch.linalg.qr(a, mode="reduced")
7598            self.assertEqual(q.grad_fn._saved_mode, "reduced")  # std::string -> str
7599
7600        a = torch.tensor([1.0], requires_grad=True)
7601        out = torch.div(a, 2.0, rounding_mode="trunc")
7602        self.assertEqual(
7603            out.grad_fn._saved_rounding_mode, "trunc"
7604        )  # std::optional<std::string> -> str?
7605        out = torch.div(a, 2.0, rounding_mode=None)
7606        self.assertIsNone(
7607            out.grad_fn._saved_rounding_mode
7608        )  # std::optional<std::string> -> str?
7609
7610        x = torch.zeros(5, requires_grad=True)
7611        out = torch.threshold(x, threshold=(1 + 0j), value=(1 + 0j))
7612        self.assertIsInstance(
7613            out.grad_fn._saved_threshold, complex
7614        )  # Scalar(complex double) -> complex
7615        cfloat = torch.tensor(1 + 0j, dtype=torch.complex64)
7616        out = torch.threshold(x, threshold=cfloat, value=(1 + 0j))
7617        self.assertIsInstance(
7618            out.grad_fn._saved_threshold, complex
7619        )  # Scalar(complex float) -> complex
7620        out = torch.threshold(x, threshold=1.0, value=1.0)
7621        self.assertIsInstance(
7622            out.grad_fn._saved_threshold, float
7623        )  # Scalar(floating point) -> float
7624        out = torch.threshold(x, threshold=1, value=1)
7625        self.assertIsInstance(
7626            out.grad_fn._saved_threshold, int
7627        )  # Scalar(integral) -> int
7628        out = torch.threshold(x, threshold=False, value=False)
7629        self.assertIsInstance(
7630            out.grad_fn._saved_threshold, bool
7631        )  # Scalar(bool) -> bool
7632
7633        a = torch.ones(2, 2, requires_grad=True)
7634        out = a.as_strided((3,), (1,), 1)
7635        self.assertEqual(
7636            out.grad_fn._saved_storage_offset, 1
7637        )  # c10:optional<int64_t> -> int?
7638        self.assertIsInstance(out.grad_fn._saved_storage_offset, int)
7639        out = a.as_strided((3,), (1,))
7640        self.assertIsNone(out.grad_fn._saved_storage_offset)
7641
7642        a = torch.ones(2, requires_grad=True)
7643        out = torch.tanh(a)
7644        self.assertEqual(out, out.grad_fn._saved_result)  # saved variable when output
7645
7646        a = torch.randn(3, 5, requires_grad=True)
7647        b = torch.tensor([1, 0, 4])
7648        loss = nn.NLLLoss()
7649        out = loss(a, b)
7650        self.assertIsNone(out.grad_fn._saved_weight)
7651        loss = nn.NLLLoss(weight=torch.ones((5,)))
7652        out = loss(a, b)
7653        self.assertEqual(
7654            out.grad_fn._saved_weight, torch.ones((5,))
7655        )  # c10:optional<Tensor> -> Tensor?
7656
7657        out.sum().backward()
7658        with self.assertRaisesRegex(RuntimeError, "after they have already been freed"):
7659            out.grad_fn._saved_weight
7660
7661        num_tensors = 3
7662        input_tensors = [
7663            torch.ones(2, 2, requires_grad=True) for _ in range(num_tensors)
7664        ]
7665        scalars = [
7666            0.0 for _ in range(num_tensors)
7667        ]  # ArrayRef<Scalar> -> Tuple[Scalar, ...]
7668        results = torch._foreach_maximum(input_tensors, scalars)
7669        for t in results:
7670            self.assertEqual(t.grad_fn._saved_scalars, scalars)
7671
7672    def test_cant_create_saved_tensors(self):
7673        with self.assertRaisesRegex(
7674            RuntimeError,
7675            "Trying to create a SavedTensor object from Python is forbidden",
7676        ):
7677            torch.autograd.SavedTensor()
7678
7679    def test_custom_function_saved_tensors(self):
7680        def getFn(save=True):
7681            class MyFn(Function):
7682                @staticmethod
7683                def forward(ctx, x):
7684                    if save:
7685                        ctx.save_for_backward(x, None)
7686                    return x
7687
7688                @staticmethod
7689                def backward(ctx, g):
7690                    return g
7691
7692            return MyFn
7693
7694        a = torch.randn(5, requires_grad=True)
7695
7696        y = getFn(True).apply(a)
7697
7698        self.assertEqual((a, None), y.grad_fn.saved_tensors)
7699        saved = y.grad_fn._raw_saved_tensors
7700        self.assertIsInstance(saved[0], torch._C._autograd.SavedTensor)
7701        # We can't tell the underlying tensor is None without unpacking it
7702        self.assertIsInstance(saved[1], torch._C._autograd.SavedTensor)
7703
7704        # We catch that error when the user calls register_hooks on it
7705        with self.assertRaisesRegex(RuntimeError, "None is forbidden"):
7706            saved[1].register_hooks(lambda x: x, lambda x: x)
7707
7708        with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
7709            saved[0].register_hooks(lambda x: x)
7710        with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
7711            saved[0].register_hooks(1, 1)
7712        saved[0].register_hooks(lambda x: x, lambda x: x)
7713        with self.assertRaisesRegex(RuntimeError, "already been set"):
7714            saved[0].register_hooks(lambda x: x, lambda x: x)
7715        y.sum().backward()
7716
7717        # Using a reference to the SavedTensor object after the
7718        # saved variables have been released can lead to undefined behavior
7719        del saved
7720        with self.assertRaisesRegex(RuntimeError, "after they have already been freed"):
7721            y.grad_fn._raw_saved_tensors
7722        with self.assertRaisesRegex(RuntimeError, "after they have already been freed"):
7723            y.grad_fn.saved_tensors
7724
7725        y = getFn(False).apply(a)
7726        self.assertEqual(y.grad_fn.saved_tensors, ())
7727        self.assertEqual(y.grad_fn._raw_saved_tensors, ())
7728
7729    def test_autograd_node_isinstance(self):
7730        # Node is a "virtual" base class of codegen'd nodes. This means that
7731        # isinstance and issubclass are overridden, but mro is unchanged
7732        Node = torch.autograd.graph.Node
7733
7734        a = torch.rand(3, 3, requires_grad=True)
7735        b = a.exp()
7736
7737        # Some nodes have codegened registrations to the torch._C._function module
7738        self.assertIsInstance(b.grad_fn, Node)
7739        self.assertTrue(issubclass(type(b.grad_fn), Node))
7740        self.assertTrue(Node not in type(b.grad_fn).mro())
7741
7742        # Other nodes have manual registrations to the torch._C._function module
7743        self.assertNotIsInstance(torch._C._functions.AccumulateGrad, Node)
7744        self.assertTrue(issubclass(torch._C._functions.AccumulateGrad, Node))
7745        self.assertIsInstance(b.grad_fn.next_functions[0][0], Node)
7746        self.assertTrue(issubclass(torch._C._functions.DelayedError, Node))
7747
7748        # Special cases
7749        self.assertNotIsInstance(None, Node)
7750        self.assertNotIsInstance(1, Node)
7751        self.assertNotIsInstance(Node, Node)
7752        self.assertTrue(issubclass(Node, Node))
7753
7754        # Custom function case
7755        self.assertTrue(issubclass(torch.autograd.function.BackwardCFunction, Node))
7756
7757        class Func(torch.autograd.Function):
7758            @staticmethod
7759            def forward(ctx, x):
7760                self.assertIsInstance(ctx, Node)
7761                return x
7762
7763            @staticmethod
7764            def backward(ctx, x):
7765                self.assertIsInstance(ctx, Node)
7766                return x
7767
7768        out = Func.apply(a)
7769        self.assertIsInstance(out.grad_fn, Node)
7770        self.assertTrue(issubclass(type(out.grad_fn), Node))
7771        self.assertTrue(Node not in type(out.grad_fn).mro())
7772        out.sum().backward()
7773
7774    def test_autograd_views_codegen(self):
7775        # This is not necessarily the absolute correct behavior, but this is the current
7776        # one. This test is here to make sure that any change to this behavior is detected
7777        # and not silent. The TODOs below mark the places with unexpected behavior.
7778        # Note that any change in these test will be BC-breaking and should be done carefully.
7779
7780        # This test checks the behavior of two codegen functions (view_as and unbind)
7781        # with respect to view tracking and inplace operation on the output.
7782
7783        def run_test(grad_mode, requires_grad, is_view, should_raise_tuple):
7784            def maybe_check_raise(fn, should_raise):
7785                self.assertTrue(should_raise is None or isinstance(should_raise, str))
7786                if should_raise is not None:
7787                    with self.assertRaisesRegex(RuntimeError, should_raise):
7788                        fn()
7789                else:
7790                    fn()
7791
7792            inp = torch.rand(2, requires_grad=requires_grad).clone()
7793            with torch.set_grad_enabled(grad_mode):
7794                out = inp.view_as(inp)
7795            # Are they differentiable views?
7796            self.assertTrue(out._is_view() == is_view)
7797            # Are inplace allowed?
7798            maybe_check_raise(lambda: out.add_(1), should_raise_tuple[0])
7799
7800            inp = torch.rand(2, requires_grad=requires_grad).clone()
7801            with torch.set_grad_enabled(grad_mode):
7802                out = inp.unbind()
7803            # Are they differentiable views?
7804            self.assertTrue(out[0]._is_view() == is_view)
7805            self.assertTrue(out[1]._is_view() == is_view)
7806            # Are inplace allowed?
7807            maybe_check_raise(lambda: out[0].add_(1), should_raise_tuple[1])
7808            maybe_check_raise(lambda: out[1].add_(1), should_raise_tuple[2])
7809
7810        # should_raise contains None if it should not raise
7811        # should_raise contains a string of the error if it should raise
7812        # The 3 elements are for view_as, first output of unbind and second output of unbind
7813        run_test(
7814            grad_mode=True,
7815            requires_grad=False,
7816            is_view=True,
7817            should_raise_tuple=(None, None, None),
7818        )
7819        inp_change_err = (
7820            "Output {} of UnbindBackward0 is a view and is being modified inplace."
7821        )
7822        run_test(
7823            grad_mode=True,
7824            requires_grad=True,
7825            is_view=True,
7826            should_raise_tuple=(
7827                None,
7828                inp_change_err.format("0"),
7829                inp_change_err.format("1"),
7830            ),
7831        )
7832        leaf_grad_err = (
7833            "A view was created in no_grad mode and is being modified inplace"
7834        )
7835        run_test(
7836            grad_mode=False,
7837            requires_grad=True,
7838            is_view=True,
7839            should_raise_tuple=(leaf_grad_err, leaf_grad_err, leaf_grad_err),
7840        )
7841        run_test(
7842            grad_mode=False,
7843            requires_grad=False,
7844            is_view=True,
7845            should_raise_tuple=(None, None, None),
7846        )
7847
7848    def test_inplace_not_requires_grad(self):
7849        class MyFn(torch.autograd.Function):
7850            @staticmethod
7851            def forward(ctx, inp):
7852                return inp.view_as(inp)
7853
7854            @staticmethod
7855            def backward(ctx, grad):
7856                return grad
7857
7858        # Original Tensor does not require grad
7859        a = torch.rand(1, 2)
7860
7861        # Tensor being written does require grad
7862        b = torch.rand(1, requires_grad=True)
7863
7864        # Take an invalid view on 'a' that should raise an error (warns during deprecation)
7865        view_a = MyFn.apply(a)
7866
7867        with self.assertRaisesRegex(
7868            RuntimeError, "This view was created inside a custom Function"
7869        ):
7870            view_a += b
7871
7872        # Extra test for copy_ that is a manual implementation and could be easily
7873        # forgotten when the codegen is updated (warns during deprecation)
7874        a = torch.rand(1, 2)
7875        b = torch.rand(1, requires_grad=True)
7876        view_a = MyFn.apply(a)
7877
7878        with self.assertRaisesRegex(
7879            RuntimeError, "This view was created inside a custom Function"
7880        ):
7881            view_a.copy_(b)
7882
7883        # Functions that should throw must properly throw
7884        a = torch.rand(1, 2)
7885        b = torch.rand(1, requires_grad=True)
7886        view_a = a.unbind()[0]
7887        with self.assertRaisesRegex(
7888            RuntimeError,
7889            "This view is the output of a function that returns " "multiple views.",
7890        ):
7891            view_a.copy_(b)
7892
7893        # Sanity check that views that should work still work
7894        a = torch.rand(1, 2)
7895        b = torch.rand(1, requires_grad=True)
7896        a.select(1, 0).copy_(b)
7897
7898    def _do_test_autograd_simple_views_python(self, dtype):
7899        # This is not necessarily the absolute correct behavior, but this is the current
7900        # one. This test is here to make sure that any change to this behavior is detected
7901        # and not silent. The TODOs below mark the places with unexpected behavior.
7902        # Note that any change in these test will be BC-breaking and should be done carefully.
7903
7904        # This checks the autograd.Function behavior when we return one or multiple outputs
7905        # while one of these is an input, a view of an input or of a temporary tensor.
7906
7907        # This indicator is used to track how many times the backward function was called
7908        bw_called = [0]
7909        # This indicator is used to check if the argument `ga` contains non-zero values
7910        ga_nz = [False]
7911
7912        class IdOneOutput(Function):
7913            @staticmethod
7914            def forward(ctx, a, b, make_view):
7915                if make_view:
7916                    a = a.narrow(0, 0, 2)
7917                else:
7918                    a = a.clone()
7919                return a
7920
7921            @staticmethod
7922            def backward(ctx, ga):
7923                bw_called[0] += 1
7924                return ga, None, None
7925
7926        class IdTwoOutput(Function):
7927            @staticmethod
7928            def forward(ctx, a, b, make_view):
7929                if make_view:
7930                    a = a.narrow(0, 0, 2)
7931                else:
7932                    a = a.clone()
7933                return a, a + b
7934
7935            @staticmethod
7936            def backward(ctx, ga, gab):
7937                bw_called[0] += 1
7938                if ga.eq(0).all():
7939                    ga_nz[0] = False
7940                else:
7941                    ga_nz[0] = True
7942                return ga + gab, gab, None
7943
7944        class ViewOfTemp(Function):
7945            @staticmethod
7946            def forward(ctx, a, make_view):
7947                ctx.save_for_backward(a)
7948                if make_view:
7949                    a = a.narrow(0, 0, 2)
7950                else:
7951                    a = a.clone()
7952                b = a.clone()
7953                return b.select(0, 0)
7954
7955            @staticmethod
7956            def backward(ctx, grad):
7957                bw_called[0] += 1
7958                (a,) = ctx.saved_tensors
7959                res = torch.zeros_like(a)
7960                res.select(0, 0).copy_(grad)
7961                return res, None
7962
7963        fn_id_to_inplace_on_view_err_msg = {
7964            "one_output": (
7965                "Output 0 of IdOneOutputBackward is a view and is being "
7966                "modified inplace. This view was created inside a custom Function"
7967            ),
7968            "two_output": (
7969                "Output 0 of IdTwoOutputBackward is a view and is being modified inplace."
7970                " This view is the output of a function that returns multiple views."
7971            ),
7972            "view_of_temp": (
7973                "Output 0 of ViewOfTempBackward is a view and is being "
7974                "modified inplace. This view was created inside a custom Function"
7975            ),
7976        }
7977
7978        for fn_id in ["one_output", "two_output", "view_of_temp"]:
7979            for inplace in [True, False]:
7980                for make_view in [True, False]:
7981                    # Used for special casing the tests below
7982                    output_is_a_view = make_view or fn_id == "view_of_temp"
7983
7984                    def fn(a, b):
7985                        # never modify a, b inplace for gracheck
7986                        a = a.clone()
7987                        b = b.clone()
7988                        if fn_id == "two_output":
7989                            tmp1, tmp2 = IdTwoOutput.apply(a, b, make_view)
7990                            if inplace:
7991                                tmp1 += 3
7992                                tmp2 += 3
7993                            else:
7994                                tmp1 = tmp1 + 3
7995                                tmp2 = tmp2 + 3
7996                            tmp = tmp1 * tmp2
7997                        else:
7998                            if fn_id == "one_output":
7999                                tmp = IdOneOutput.apply(a, b, make_view)
8000                            else:
8001                                tmp = ViewOfTemp.apply(a + b, make_view)
8002                            if inplace:
8003                                tmp += 3
8004                            else:
8005                                tmp = tmp + 3
8006
8007                        return tmp.sum()
8008
8009                    a = torch.ones(2, dtype=dtype, requires_grad=True)
8010                    b = torch.ones(2, dtype=dtype, requires_grad=True)
8011
8012                    err_msg = fn_id_to_inplace_on_view_err_msg[fn_id]
8013
8014                    if not inplace or not output_is_a_view:
8015                        gradcheck(fn, (a, b), check_batched_grad=False)
8016
8017                    # Was the custom backward called properly
8018                    bw_called[0] = 0
8019                    ga_nz[0] = True  # For the case where the backward is called
8020
8021                    if inplace and output_is_a_view:
8022                        with self.assertRaisesRegex(RuntimeError, err_msg):
8023                            fn(a, b)
8024                    else:
8025                        fn(a, b).abs().backward()
8026
8027                    expected_called = 1
8028                    expected_ga_nz = True
8029
8030                    if output_is_a_view and inplace:
8031                        expected_called = 0
8032
8033                    self.assertTrue(bw_called[0] == expected_called)
8034                    self.assertTrue(ga_nz[0] == expected_ga_nz)
8035
8036    def test_autograd_simple_views_python(self):
8037        self._do_test_autograd_simple_views_python(torch.double)
8038        self._do_test_autograd_simple_views_python(torch.cdouble)
8039
8040    def test_autograd_inplace_views_creation_meta(self):
8041        # Tests creation_meta properly handled for inplace views
8042
8043        class Func(torch.autograd.Function):
8044            @staticmethod
8045            def forward(ctx, x):
8046                return x.view_as(x)
8047
8048            @staticmethod
8049            def backward(ctx, x):
8050                return x
8051
8052        view_custom = Func.apply
8053
8054        def run_test(
8055            fn, fn_type, grad_mode_view, grad_mode_iview, requires_grad, error1, error2
8056        ):
8057            # This test checks the behavior of inplace-view functions when
8058            # the views are created in grad mode or not
8059            base = torch.rand(2, 3, requires_grad=requires_grad).clone()
8060            # 1. Create a view with `grad_mode=grad_mode_view`
8061            with torch.set_grad_enabled(grad_mode_view):
8062                if fn_type == "multi_view":
8063                    inp = base.unbind()[0]
8064                elif fn_type == "custom":
8065                    inp = view_custom(base)
8066                else:
8067                    inp = base.view_as(base)
8068
8069            # 2. Perform inplace view with `grad_mode=grad_mode_iview`
8070            with torch.set_grad_enabled(grad_mode_iview):
8071                if error1 is not None:
8072                    with self.assertRaisesRegex(RuntimeError, error1):
8073                        fn(inp)
8074                    return
8075                else:
8076                    # If error is None, check that runs without error
8077                    fn(inp)
8078            # 3. Do inplace on the (new) view
8079            if error2 is not None:
8080                with self.assertRaisesRegex(RuntimeError, error2):
8081                    inp.add_(1)
8082            else:
8083                # If error is None, check that runs without error
8084                inp.add_(1)
8085
8086        no_grad_err = "A view was created in no_grad mode"
8087        multi_view_err = "function that returns multiple views"
8088        custom_err = "view was created inside a custom Function"
8089
8090        def run_tests(fn):
8091            for fn_type in ("normal", "multi_view", "custom"):
8092                for grad_mode_view in (True, False):
8093                    for grad_mode_iview in (True, False):
8094                        for requires_grad in (True, False):
8095                            error1 = None  # expected error when we do inplace_view on original view
8096                            error2 = None  # expected error when we do inplace on the resulting view
8097
8098                            if requires_grad:
8099                                if not grad_mode_view and grad_mode_iview:
8100                                    error1 = no_grad_err
8101                                if not grad_mode_view and not grad_mode_iview:
8102                                    error2 = no_grad_err
8103
8104                                if fn_type == "multi_view":
8105                                    if grad_mode_view and grad_mode_iview:
8106                                        error1 = multi_view_err
8107                                    if grad_mode_view and not grad_mode_iview:
8108                                        error2 = multi_view_err
8109
8110                                if fn_type == "custom":
8111                                    if grad_mode_view and grad_mode_iview:
8112                                        error1 = custom_err
8113                                    if grad_mode_view and not grad_mode_iview:
8114                                        error2 = custom_err
8115
8116                            run_test(
8117                                fn,
8118                                fn_type,
8119                                grad_mode_view,
8120                                grad_mode_iview,
8121                                requires_grad,
8122                                error1,
8123                                error2,
8124                            )
8125
8126        # This list was created by logging gen_inplace_or_view_type.py
8127        #   detach_ is excluded for this test because it cannot be applied to
8128        #   views and thus does not return a view
8129        run_tests(lambda v: v.as_strided_((1, 0), (2, 2)))
8130        run_tests(lambda v: v.transpose_(0, 0))
8131        run_tests(lambda v: v.t_())
8132        run_tests(lambda v: v.squeeze_(0))
8133        run_tests(lambda v: v.unsqueeze_(0))
8134        run_tests(lambda v: v.swapdims_(0, 0))
8135        run_tests(lambda v: v.swapaxes_(0, 0))
8136
8137    def test_autograd_print_tensor(self):
8138        a = torch.ones(1, requires_grad=True)
8139        a_clone = a.clone()
8140        self.assertEqual(repr(a), "tensor([1.], requires_grad=True)")
8141        self.assertEqual(repr(a_clone), "tensor([1.], grad_fn=<CloneBackward0>)")
8142
8143        with torch.no_grad():
8144            b = a[:]
8145            b *= 2
8146
8147        # Special handling for printing view created in no-grad and modified
8148        # in-placed in no-grad.
8149        self.assertEqual(repr(b), "tensor([2.], grad_fn=<Invalid>)")
8150
8151        class Func(torch.autograd.Function):
8152            @staticmethod
8153            def forward(ctx, x):
8154                return x
8155
8156            @staticmethod
8157            def backward(ctx, x):
8158                return x
8159
8160        c = Func.apply(a)
8161        self.assertEqual(repr(c), "tensor([2.], grad_fn=<FuncBackward>)")
8162
8163    def test_autograd_inplace_view_of_view(self):
8164        x = torch.zeros(2)
8165        with torch.no_grad():
8166            y = x.view(2)
8167        y.requires_grad_(True)
8168        z = y.view(2)
8169        with self.assertRaisesRegex(
8170            RuntimeError, "a view of a view .* is being .* inside the no_grad block"
8171        ):
8172            z /= 2
8173
8174        x = torch.zeros(2)
8175        with torch.inference_mode():
8176            y = x.view(2)
8177        y.requires_grad_(True)
8178        z = y.view(2)
8179        with self.assertRaisesRegex(
8180            RuntimeError, "a view of a view .* is being .* inside the inference_mode"
8181        ):
8182            z /= 2
8183
8184    # TODO This is not the correct behavior -
8185    # See https://github.com/pytorch/pytorch/issues/49825#issuecomment-794466627
8186    def test_autograd_inplace_views_cross_dtype(self):
8187        # This test is here to make sure that any change to this behavior is detected
8188        # and not silent. The TODOs below mark the places with unexpected behavior.
8189        a_orig = torch.rand(3, 3, requires_grad=True, dtype=torch.complex64)
8190        a = a_orig.clone()
8191        b = torch.view_as_real(a)
8192        b = b.transpose(0, 1)
8193        b += 1
8194        b.backward(torch.arange(0, 18, dtype=torch.float).view(3, 3, 2))
8195        non_inplace_grad = a_orig.grad
8196
8197        a_orig = torch.rand(3, 3, requires_grad=True, dtype=torch.complex64)
8198        a = a_orig.clone()
8199        b = torch.view_as_real(a)
8200        b.transpose_(0, 1)
8201        b += 1
8202        b.backward(torch.arange(0, 18, dtype=torch.float).view(3, 3, 2))
8203        inplace_grad = a_orig.grad
8204
8205        # TODO: this is a bug!
8206        # once this is fixed, it should have the transpose removed:
8207        # self.assertEqual(non_inplace_grad, inplace_grad)
8208        self.assertEqual(non_inplace_grad.T, inplace_grad)
8209
8210    def test_autograd_multiple_views_python(self):
8211        # This is not necessarily the absolute correct behavior, but this is the current
8212        # one. This test is here to make sure that any change to this behavior is detected
8213        # and not silent. The TODOs below mark the places with unexpected behavior.
8214        # Note that any change in these test will be BC-breaking and should be done carefully.
8215
8216        # This checks that multiples views in the forward are properly traced and how they
8217        # behave with respect to inplace operations.
8218
8219        # This indicator is used to track how many times the backward function was called
8220        bw_called = [0]
8221
8222        class ComplexView(Function):
8223            @staticmethod
8224            def forward(ctx, a, idx):
8225                res = a.narrow(0, idx, 1)
8226                res = a.select(0, idx)
8227                ctx.save_for_backward(a)
8228                ctx.idx = idx
8229                return res
8230
8231            @staticmethod
8232            def backward(ctx, grad):
8233                bw_called[0] += 1
8234                (a,) = ctx.saved_tensors
8235                res = torch.zeros_like(a)
8236                res.select(0, ctx.idx).copy_(grad)
8237                return res, None
8238
8239        a = torch.ones(2, requires_grad=True)
8240        idx = 1
8241
8242        bw_called[0] = 0
8243        out = ComplexView.apply(a.clone(), idx)
8244        out.sum().backward()
8245        self.assertTrue(bw_called[0] == 1)
8246
8247        out = ComplexView.apply(a.clone(), idx)
8248        with self.assertRaisesRegex(
8249            RuntimeError,
8250            "Output 0 of ComplexViewBackward is a view and is being modified inplace",
8251        ):
8252            out += 1
8253
8254    def test_autograd_python_custom_function_inplace(self):
8255        # This is not necessarily the absolute correct behavior, but this is the current
8256        # one. This test is here to make sure that any change to this behavior is detected
8257        # and not silent. The TODOs below mark the places with unexpected behavior.
8258        # Note that any change in these test will be BC-breaking and should be done carefully.
8259
8260        # This test checks custom autograd.Function that perform inplace operations
8261
8262        bw_called = [0]
8263
8264        # I) Single output
8265        class MyAdder(Function):
8266            @staticmethod
8267            def forward(ctx, a, b):
8268                a.add_(b)
8269                ctx.mark_dirty(a)
8270                return a
8271
8272            @staticmethod
8273            def backward(ctx, grad):
8274                bw_called[0] += 1
8275                return grad, grad
8276
8277        a = torch.ones(2, requires_grad=True)
8278        b = torch.ones(2, requires_grad=True)
8279
8280        # No extra inplace
8281        c = MyAdder.apply(a.clone(), b)
8282        c.sum().backward()
8283        self.assertTrue(bw_called[0] == 1)
8284
8285        # With extra inplace on the output
8286        bw_called[0] = 0
8287        c = MyAdder.apply(a.clone(), b)
8288        c += 2
8289        c.sum().backward()
8290        self.assertTrue(bw_called[0] == 1)
8291
8292        # The input is a view
8293        bw_called[0] = 0
8294        c = MyAdder.apply(a.clone().view_as(a), b)
8295        c.sum().backward()
8296        self.assertTrue(bw_called[0] == 1)
8297
8298        # Should not give non-inputs to mark_dirty
8299        class MyAdderBad(Function):
8300            @staticmethod
8301            def forward(ctx, a, b):
8302                c = 3 * a
8303                c.add_(b)
8304                ctx.mark_dirty(c)
8305                return c
8306
8307            @staticmethod
8308            def backward(ctx, grad):
8309                bw_called[0] += 1
8310                grad = 3 * grad
8311                return grad, grad
8312
8313        a = torch.ones(2, requires_grad=True)
8314        b = torch.ones(2, requires_grad=True)
8315
8316        with warnings.catch_warnings(record=True) as w:
8317            MyAdderBad.apply(a.clone(), b)
8318        self.assertEqual(len(w), 1)
8319
8320        # II) Multiple outputs
8321        class MyBadAdder(Function):
8322            @staticmethod
8323            def forward(ctx, a, b):
8324                a.add_(b)
8325                ctx.mark_dirty(a)
8326                return a, a + b
8327
8328            @staticmethod
8329            def backward(ctx, ga, gab):
8330                bw_called[0] += 1
8331                return ga + gab, ga + gab
8332
8333        # No extra inplace
8334        bw_called[0] = 0
8335        c, d = MyBadAdder.apply(a.clone(), b)
8336        (c * d).sum().backward()
8337        self.assertTrue(bw_called[0] == 1)
8338
8339        # With extra inplace on the output
8340        bw_called[0] = 0
8341        c, d = MyBadAdder.apply(a.clone(), b)
8342        c += 2
8343        (c * d).sum().backward()
8344        self.assertTrue(bw_called[0] == 1)
8345
8346        # The input is a view
8347        inplace_on_view_err = (
8348            "your Function modifies inplace an input that is a view of another Tensor"
8349        )
8350        with self.assertRaisesRegex(RuntimeError, inplace_on_view_err):
8351            c, d = MyBadAdder.apply(a.clone().view_as(a), b)
8352
8353        # III) Inplace + other op
8354        class MyOutPlaceAdder(Function):
8355            @staticmethod
8356            def forward(ctx, a, b):
8357                a.add_(b)
8358                ctx.mark_dirty(a)
8359                return a.clone(), a + b
8360
8361            @staticmethod
8362            def backward(ctx, ga, gab):
8363                bw_called[0] += 1
8364                return ga + gab, ga + 2 * gab
8365
8366        # We don't reuse the input
8367        def fn(a, b):
8368            orig_a = a.clone().view_as(a)
8369            c, d = MyOutPlaceAdder.apply(orig_a, b)
8370            return (c * d).sum()
8371
8372        bad_mark_dirty_err = "Some elements marked as dirty during the forward method were not returned as output."
8373        with self.assertRaisesRegex(RuntimeError, bad_mark_dirty_err):
8374            fn(a, b)
8375
8376    def test_custom_function_mark_dirty_not_differentiable(self):
8377        def get_custom_fn(jvp_err):
8378            class InplaceMul(torch.autograd.Function):
8379                @staticmethod
8380                def forward(ctx, x):
8381                    result = x.mul_(2)
8382                    ctx.mark_dirty(result)
8383                    return result
8384
8385                @staticmethod
8386                def backward(ctx, grad_output):
8387                    pass
8388
8389                @staticmethod
8390                def jvp(ctx, x_t):
8391                    if jvp_err:
8392                        return x_t
8393                    else:
8394                        return x_t.mul_(2)
8395
8396            return InplaceMul
8397
8398        for requires_grad, jvp_err in product([True, False], repeat=2):
8399            InplaceMul = get_custom_fn(jvp_err)
8400            # Make sure that tensor is always returned as-is if marked dirty
8401            z = torch.tensor(1.0, requires_grad=requires_grad)
8402            x = z.clone()
8403            y = InplaceMul.apply(x)
8404            self.assertTrue(x is y)
8405            self.assertEqual(x, z * 2)
8406
8407            # jvp must properly modify the input grad if mark_dirty is set
8408            with fwAD.dual_level():
8409                x_tangent = torch.ones_like(x)
8410                x_dual = fwAD.make_dual(x, x_tangent)
8411
8412                if jvp_err:
8413                    bad_mark_dirty_err = (
8414                        "jvp function must modify the corresponding gradient inplace"
8415                    )
8416                    with self.assertRaisesRegex(RuntimeError, bad_mark_dirty_err):
8417                        InplaceMul.apply(x_dual)
8418                else:
8419                    out_dual = InplaceMul.apply(x_dual)
8420                    _, out_tangent = fwAD.unpack_dual(out_dual)
8421                    self.assertTrue(out_dual is x_dual)
8422                    self.assertTrue(out_tangent is x_tangent)
8423
8424    def test_named_tensor_for_complex_views(self):
8425        names = ["batch", "height", "width", "complex"]
8426        z = torch.ones((2, 1, 2, 2), requires_grad=True)
8427        z_named = z.refine_names(*names)
8428        z_complex = torch.view_as_complex(z_named.rename(None)).refine_names(
8429            *names[:-1]
8430        )
8431        z_complex.sum().abs().backward()
8432        expected = torch.ones_like(z_complex).rename(None)
8433        abs_1_1j = abs(1 + 1j)
8434        expected.fill_(complex(abs_1_1j / 2, abs_1_1j / 2))
8435        self.assertEqual(z.grad, torch.view_as_real(expected))
8436
8437    def test_custom_function_return_view_in_nograd(self):
8438        class Alias(Function):
8439            @staticmethod
8440            def forward(ctx, x):
8441                return x[:]
8442
8443            @staticmethod
8444            def backward(ctx, gx):
8445                return gx
8446
8447        inp = torch.rand(2, requires_grad=True)
8448
8449        with torch.no_grad():
8450            output = Alias.apply(inp)
8451
8452        with torch.no_grad():
8453            expected_output = inp[:]
8454
8455        # Calling the custom function should operate as if we called an equivalent op
8456        self.assertEqual(output.requires_grad, expected_output.requires_grad)
8457
8458        # Check that in-place modification on view throws
8459        leaf_grad_err = (
8460            "A view was created in no_grad mode and is being modified inplace"
8461        )
8462        with self.assertRaisesRegex(RuntimeError, leaf_grad_err):
8463            output.zero_()
8464
8465    def test_custom_function_preserve_torch_function_when_return_as_is(self):
8466        class Custom(torch.Tensor):
8467            def __init__(self, data):
8468                super().__init__()
8469                self._data = data
8470
8471            @classmethod
8472            def __torch_function__(cls, func, types, args=(), kwargs=None):
8473                kwargs = {} if kwargs is None else kwargs
8474                args = tuple(a._data if isinstance(a, cls) else a for a in args)
8475                out = func(*args, **kwargs)
8476                if isinstance(out, torch.Tensor):
8477                    out = cls(out)
8478                return out
8479
8480        class Fn(torch.autograd.Function):
8481            @staticmethod
8482            def forward(ctx, input):
8483                return input
8484
8485            @staticmethod
8486            def backward(ctx):
8487                pass
8488
8489        x = Custom(torch.randn(2, 3))
8490        y = Fn.apply(x)
8491        self.assertTrue(isinstance(y, Custom))
8492
8493    def test_grad_mode_restored_reentrant(self):
8494        class MyFunction(Function):
8495            @staticmethod
8496            def forward(ctx, inp):
8497                return inp.clone()
8498
8499            @staticmethod
8500            def backward(ctx, go):
8501                original = torch._C.is_grad_enabled()
8502                with torch.enable_grad():
8503                    self.assertTrue(torch._C.is_grad_enabled())
8504                    foo = torch.rand(go.size(), requires_grad=True)
8505                    (grad,) = torch.autograd.grad(foo**3, foo, grad_outputs=go)
8506                    self.assertTrue(torch._C.is_grad_enabled())
8507                self.assertTrue(torch._C.is_grad_enabled() == original)
8508                return grad
8509
8510        inp = torch.rand(3, requires_grad=True)
8511
8512        # Case where original==False
8513        MyFunction.apply(inp).sum().backward()
8514        # Case where original==True
8515        MyFunction.apply(inp).sum().backward(create_graph=True)
8516
8517    def test_power_function(self):
8518        a = torch.tensor([0.0, 0.0, 0.0])
8519        b = torch.tensor([-1.0, 0.0, 1.0], requires_grad=True)
8520        c = torch.sum(a**b)
8521        c.backward()
8522        self.assertEqual(b.grad, torch.tensor([-inf, 0.0, 0.0]))
8523
8524        s = 0
8525        b = torch.tensor([-1.0, 0.0, 1.0], requires_grad=True)
8526        c = torch.sum(s**b)
8527        c.backward()
8528        self.assertEqual(b.grad, torch.tensor([-inf, 0.0, 0.0]))
8529
8530    def test_custom_function_error(self):
8531        class BadFw(Function):
8532            @staticmethod
8533            def backward(ctx, foo):
8534                return foo
8535
8536        class BadBw(Function):
8537            @staticmethod
8538            def forward(ctx, foo):
8539                return foo.clone()
8540
8541        class BadBw2(Function):
8542            @staticmethod
8543            def forward(ctx, foo):
8544                return foo.clone()
8545
8546            @staticmethod
8547            def backward(ctx, foo):
8548                return foo
8549
8550            @staticmethod
8551            def vjp(ctx, foo):
8552                return foo
8553
8554        class BadJvp(Function):
8555            @staticmethod
8556            def forward(ctx, foo):
8557                return foo.clone()
8558
8559        inp = torch.rand(1, requires_grad=True)
8560        with self.assertRaisesRegex(NotImplementedError, "must implement the forward"):
8561            BadFw.apply(inp)
8562
8563        with self.assertRaisesRegex(RuntimeError, "must implement either the backward"):
8564            BadBw.apply(inp).sum().backward()
8565
8566        with self.assertRaisesRegex(
8567            RuntimeError, "Implementing both 'backward' and 'vjp'"
8568        ):
8569            BadBw2.apply(inp).sum().backward()
8570
8571        with self.assertRaisesRegex(RuntimeError, "must implement the jvp function"):
8572            with fwAD.dual_level():
8573                d = fwAD.make_dual(inp, torch.rand_like(inp))
8574                res = BadJvp.apply(d)
8575
8576    def test_custom_function_forward_mode_view_checks(self):
8577        flag_to_error = {
8578            "ok": None,
8579            "not_a_view": "jvp is not returning a view",
8580            "not_a_view_of_inp": "jvp is not returning a view of the given",
8581            "not_a_view_of_inp_base": "jvp is not returning a view of the same base",
8582        }
8583
8584        class ViewFn(Function):
8585            @staticmethod
8586            def forward(ctx, foo, flag):
8587                ctx.flag = flag
8588                ctx.size = foo.size()
8589                return foo.narrow(0, 0, 2)
8590
8591            @staticmethod
8592            def vjp(ctx, gO):
8593                gI = gO.new_zeros(ctx.size)
8594                gI.narrow(0, 0, 2).copy_(gO)
8595                return gI, None
8596
8597            @staticmethod
8598            def jvp(ctx, gI, _):
8599                res = gI.narrow(0, 0, 2)
8600                if ctx.flag != "ok":
8601                    # Break the view in the gradients!
8602                    res = res.clone()
8603                if ctx.flag in ["not_a_view_of_inp", "not_a_view_of_inp_base"]:
8604                    # Result should be a view, just of the wrong thing
8605                    res = res.view_as(res)
8606                return res
8607
8608        inp = torch.rand(4, 4, dtype=torch.double, requires_grad=True)
8609
8610        for flag, msg in flag_to_error.items():
8611
8612            def test_fn(inp):
8613                if flag == "not_a_view_of_inp_base":
8614                    inp = inp.view_as(inp)
8615                return ViewFn.apply(inp, flag)
8616
8617            if msg is None:
8618                gradcheck(test_fn, inp, check_forward_ad=True)
8619            else:
8620                with self.assertRaisesRegex(RuntimeError, msg):
8621                    gradcheck(test_fn, inp, check_forward_ad=True)
8622
8623    def test_custom_function_forward_mode_inplace_checks(self):
8624        class InplaceFn(Function):
8625            @staticmethod
8626            def forward(ctx, foo, flag):
8627                ctx.mark_dirty(foo)
8628                ctx.flag = flag
8629                foo.mul_(2)
8630                return foo
8631
8632            @staticmethod
8633            def vjp(ctx, gO):
8634                return 2 * gO, None
8635
8636            @staticmethod
8637            def jvp(ctx, gI, _):
8638                if ctx.flag:
8639                    # Don't do the change inplace
8640                    return 2 * gI
8641                else:
8642                    gI.mul_(2)
8643                    return gI
8644
8645        inp = torch.rand(4, 4, dtype=torch.double, requires_grad=True)
8646
8647        def test_fn(inp, flag):
8648            inp = inp.clone()
8649            return InplaceFn.apply(inp, flag)
8650
8651        gradcheck(test_fn, (inp, False), check_forward_ad=True)
8652
8653        with self.assertRaisesRegex(
8654            RuntimeError,
8655            "inplace custom Function is not modifying the forward mode gradients inplace",
8656        ):
8657            gradcheck(test_fn, (inp, True), check_forward_ad=True)
8658
8659    def test_custom_function_forward_mode_wrong_formula(self):
8660        class UserFn(Function):
8661            @staticmethod
8662            def forward(ctx, foo, should_fail):
8663                ctx.should_fail = should_fail
8664                return foo * 2
8665
8666            @staticmethod
8667            def vjp(ctx, gO):
8668                return 2 * gO, None
8669
8670            @staticmethod
8671            def jvp(ctx, gI, _):
8672                if ctx.should_fail:
8673                    # Wrong gradient formula
8674                    return 3 * gI
8675                else:
8676                    return 2 * gI
8677
8678        inp = torch.rand(10, dtype=torch.double, requires_grad=True)
8679        gradcheck(UserFn.apply, (inp, False), check_forward_ad=True)
8680
8681        with self.assertRaisesRegex(
8682            RuntimeError, "Jacobian computed with forward mode mismatch for output 0"
8683        ):
8684            gradcheck(UserFn.apply, (inp, True), check_forward_ad=True)
8685
8686    def test_custom_function_forward_mode_non_tensor_before_tensor_args(self):
8687        class MyFn(torch.autograd.Function):
8688            @staticmethod
8689            def forward(ctx, nt, x, nt2, y):
8690                return x * 2 + y * 3
8691
8692            @staticmethod
8693            def jvp(ctx, nt, x_t, nt2, y_t):
8694                self.assertIsNone(nt)
8695                self.assertIsNone(nt2)
8696                return x_t * 2 + y_t * 3
8697
8698        x = torch.tensor(1.0, dtype=torch.double)
8699        t = torch.tensor(1.0, dtype=torch.double)
8700        y = torch.tensor(1.0, dtype=torch.double)
8701
8702        with fwAD.dual_level():
8703            dual_x = fwAD.make_dual(x, t)
8704            MyFn.apply(1, dual_x, 1, y)
8705
8706        gradcheck(
8707            MyFn.apply,
8708            (1, x.requires_grad_(True), 1, y.requires_grad_(True)),
8709            check_forward_ad=True,
8710            check_backward_ad=False,
8711            check_batched_grad=False,
8712        )
8713
8714    def test_custom_function_forward_mode_forward_is_no_op(self):
8715        error_regex = (
8716            "A custom Function's forward is returning a view \\(or an input as-is\\)"
8717        )
8718
8719        return_lambdas = {
8720            # If we return an input as-is in forward, that is treated
8721            # as if self.view_as(self) is performed. If jvp returns x.view_as(x),
8722            # this is OK.
8723            "view_as": lambda x: x.view_as(x),
8724            # Expect this to raise an error
8725            "self": lambda x: x,
8726            # Expect this to raise the same error
8727            "mul_by_2": lambda x: x * 2,
8728        }
8729
8730        for k, fn in return_lambdas.items():
8731
8732            class MyFn(torch.autograd.Function):
8733                @staticmethod
8734                def forward(ctx, x, y):
8735                    return x + y, x
8736
8737                @staticmethod
8738                def vjp(ctx, gO1, gO2):
8739                    return gO1 + gO2, gO1
8740
8741                @staticmethod
8742                def jvp(ctx, x_t, y_t):
8743                    return x_t + y_t, fn(x_t)
8744
8745            a = torch.tensor(1.0, dtype=torch.double, requires_grad=True)
8746            t = torch.tensor(1.0, dtype=torch.double)
8747            b = torch.tensor(1.0, dtype=torch.double, requires_grad=True)
8748
8749            c = torch.tensor(1.0, dtype=torch.double)
8750            t2 = torch.tensor(1.0, dtype=torch.double)
8751            d = torch.tensor(1.0, dtype=torch.double)
8752
8753            with fwAD.dual_level():
8754                a_dual = fwAD.make_dual(a, t)
8755                c_dual = fwAD.make_dual(c, t2)
8756
8757                if k == "view_as":
8758                    _, out2 = MyFn.apply(a_dual, b)
8759                    self.assertTrue(fwAD.unpack_dual(out2).tangent._base is t)
8760
8761                    _, out2 = MyFn.apply(c_dual, d)
8762                    self.assertTrue(fwAD.unpack_dual(out2).tangent._base is t2)
8763                else:
8764                    with self.assertRaisesRegex(RuntimeError, error_regex):
8765                        MyFn.apply(a_dual, b)
8766
8767                    with self.assertRaisesRegex(RuntimeError, error_regex):
8768                        MyFn.apply(c_dual, d)
8769
8770            if k == "view_as":
8771                gradcheck(MyFn.apply, (a, c), check_forward_ad=True)
8772            else:
8773                with self.assertRaisesRegex(RuntimeError, error_regex):
8774                    gradcheck(MyFn.apply, (a, c), check_forward_ad=True)
8775
8776    def test_custom_function_save_for_forward(self):
8777        class Func(torch.autograd.Function):
8778            @staticmethod
8779            def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
8780                ctx.save_for_backward(x, y)
8781                ctx.save_for_forward(x, y)
8782                ctx.z = z
8783                ctx.prod = x * y
8784                return z * ctx.prod
8785
8786            @staticmethod
8787            def jvp(ctx, x_t, y_t, _):
8788                x_p, y_p = ctx.saved_tensors
8789                z = ctx.z
8790                return z * (y_p * x_t + x_p * y_t)
8791
8792            @staticmethod
8793            def vjp(ctx, grad_out):
8794                x, y = ctx.saved_tensors
8795                z = ctx.z
8796                return z * grad_out * y, z * grad_out * x, None
8797
8798        a = torch.tensor(1.0, requires_grad=True, dtype=torch.double)
8799        t = torch.tensor(1.0, dtype=torch.double)
8800        b = torch.tensor(2.0, requires_grad=True, dtype=torch.double)
8801        c = 4
8802
8803        with fwAD.dual_level():
8804            a_dual = fwAD.make_dual(a, t)
8805            out = Func.apply(a_dual, b, c)
8806            out.backward()
8807
8808        gradcheck(Func.apply, (a, b, c), check_forward_ad=True)
8809
8810        # When saved for backward, but not saved for forward
8811        class Func(torch.autograd.Function):
8812            @staticmethod
8813            def forward(ctx, x: torch.Tensor):
8814                ctx.save_for_backward(x)
8815                return x.clone()
8816
8817            @staticmethod
8818            def jvp(ctx, x_t):
8819                self.assertEqual(len(ctx.saved_tensors), 0)
8820                return x_t
8821
8822            @staticmethod
8823            def vjp(ctx, grad_out):
8824                (x,) = ctx.saved_tensors
8825                self.assertEqual(len(ctx.saved_tensors), 1)
8826                return grad_out
8827
8828        with fwAD.dual_level():
8829            a_dual = fwAD.make_dual(a, t)
8830            out = Func.apply(a_dual)
8831            out.backward()
8832
8833        gradcheck(Func.apply, (a,), check_forward_ad=True)
8834
8835    @skipIfTorchDynamo("compile tested in test/dynamo/test_autograd_function.py")
8836    def test_custom_function_forward_mode_non_differentiable(self):
8837        # returns differentiable type, marked non-differentiable
8838        class Func(torch.autograd.Function):
8839            @staticmethod
8840            def forward(ctx, x, y):
8841                out = y.clone()
8842                ctx.mark_non_differentiable(out)
8843                return x.clone(), out
8844
8845            @staticmethod
8846            def jvp(ctx, x_tangent, y_tangent):
8847                return x_tangent, None
8848
8849        x = torch.tensor(2.0)
8850        x_tangent = torch.tensor(1.0)
8851        y = torch.tensor(3.0)
8852
8853        with fwAD.dual_level():
8854            x_dual = fwAD.make_dual(x, x_tangent)
8855            _, out2_dual = Func.apply(x_dual, y)
8856            self.assertEqual(fwAD.unpack_dual(out2_dual).tangent, None)
8857
8858        y = torch.tensor(3)
8859
8860        # returns non-differentiable type, NOT marked non-differentiable
8861        class Func(torch.autograd.Function):
8862            @staticmethod
8863            def forward(ctx, x, y):
8864                return x.clone(), y.clone()
8865
8866            @staticmethod
8867            def jvp(ctx, x_tangent, y_tangent):
8868                self.assertIsNone(y_tangent)
8869                return x_tangent, None
8870
8871        with fwAD.dual_level():
8872            x_dual = fwAD.make_dual(x, x_tangent)
8873            _, out2_dual = Func.apply(x_dual, y)
8874            self.assertEqual(fwAD.unpack_dual(out2_dual).tangent, None)
8875
8876        class FuncWrong(torch.autograd.Function):
8877            @staticmethod
8878            def forward(ctx, x, y):
8879                out = y.clone()
8880                ctx.mark_non_differentiable(out)
8881                return x.clone(), out
8882
8883            @staticmethod
8884            def jvp(ctx, x_tangent, y_tangent):
8885                return x_tangent, x_tangent.clone()
8886
8887        with fwAD.dual_level():
8888            x_dual = fwAD.make_dual(x, x_tangent)
8889            with self.assertRaisesRegex(
8890                RuntimeError, "You should return None at that position instead"
8891            ):
8892                FuncWrong.apply(x_dual, y)
8893
8894        # returns non-tensor
8895        class Func(torch.autograd.Function):
8896            @staticmethod
8897            def forward(ctx, x):
8898                return x.clone(), object(), x.clone()
8899
8900            @staticmethod
8901            def jvp(ctx, x_tangent):
8902                return x_tangent, None, x_tangent
8903
8904        with fwAD.dual_level():
8905            x_dual = fwAD.make_dual(x, x_tangent)
8906            out_dual, _, out2_dual = Func.apply(x_dual)
8907            self.assertEqual(fwAD.unpack_dual(out_dual).tangent, x_tangent)
8908            self.assertEqual(fwAD.unpack_dual(out2_dual).tangent, x_tangent)
8909
8910    def test_custom_function_local_inplace(self):
8911        class MyFn(torch.autograd.Function):
8912            @staticmethod
8913            def forward(ctx, inp, inplace):
8914                view = inp.clone()[:3]
8915                if inplace:
8916                    view += 2
8917                return view
8918
8919            @staticmethod
8920            def backward(ctx, grad):
8921                return grad, None
8922
8923        base = torch.rand(10, requires_grad=True)
8924
8925        foo = MyFn.apply(base, False)
8926        self.assertEqual(foo.grad_fn.__class__.__name__, "MyFnBackward")
8927
8928        foo = MyFn.apply(base, True)
8929        self.assertEqual(foo.grad_fn.__class__.__name__, "MyFnBackward")
8930
8931    def test_integer_outputs(self):
8932        inp = torch.rand(4, requires_grad=True)
8933
8934        out = inp.argmax()
8935        self.assertFalse(out.dtype.is_floating_point)
8936        self.assertFalse(out.requires_grad)
8937
8938        out = inp.argmin()
8939        self.assertFalse(out.dtype.is_floating_point)
8940        self.assertFalse(out.requires_grad)
8941
8942        out = inp.argsort()
8943        self.assertFalse(out.dtype.is_floating_point)
8944        self.assertFalse(out.requires_grad)
8945
8946        val = torch.rand((), requires_grad=True)
8947
8948        out = torch.searchsorted(inp, val)
8949        self.assertFalse(out.dtype.is_floating_point)
8950        self.assertFalse(out.requires_grad)
8951
8952        bins = torch.linspace(0, 1.0, steps=100, requires_grad=True)
8953        vals = torch.rand(5, 5, requires_grad=True)
8954        out = torch.bucketize(vals, bins)
8955        self.assertFalse(out.dtype.is_floating_point)
8956        self.assertFalse(out.requires_grad)
8957
8958        val = torch.empty(5).requires_grad_()
8959        out = val.count_nonzero()
8960        self.assertFalse(out.requires_grad)
8961
8962        def assert_only_first_requires_grad(res):
8963            if not isinstance(res, tuple):
8964                res = (res,)
8965            self.assertTrue(res[0].requires_grad)
8966            for out in res[1:]:
8967                if out is not None:
8968                    self.assertFalse(out.requires_grad)
8969
8970        for sort in [True, False]:
8971            for return_inverse in [True, False]:
8972                for return_counts in [True, False]:
8973                    res = torch.unique(
8974                        inp,
8975                        sorted=sort,
8976                        return_inverse=return_inverse,
8977                        return_counts=return_counts,
8978                    )
8979                    assert_only_first_requires_grad(res)
8980
8981                    res = torch.unique(
8982                        inp,
8983                        sorted=sort,
8984                        return_inverse=return_inverse,
8985                        return_counts=return_counts,
8986                        dim=0,
8987                    )
8988                    assert_only_first_requires_grad(res)
8989
8990                    res = torch.unique_consecutive(
8991                        inp, return_inverse=return_inverse, return_counts=return_counts
8992                    )
8993                    assert_only_first_requires_grad(res)
8994
8995                    res = torch.unique_consecutive(
8996                        inp,
8997                        return_inverse=return_inverse,
8998                        return_counts=return_counts,
8999                        dim=0,
9000                    )
9001                    assert_only_first_requires_grad(res)
9002
9003                    # Here we test the internal functions to make sure all of them are
9004                    # covered on top of the public API
9005                    res = torch._unique(inp, sorted=sort, return_inverse=return_inverse)
9006                    assert_only_first_requires_grad(res)
9007
9008                    # This looks public but is actually manually deleted from the
9009                    # torch namespace in torch/functional.py
9010                    res = torch._VF.unique_dim(
9011                        inp,
9012                        dim=0,
9013                        sorted=sort,
9014                        return_inverse=return_inverse,
9015                        return_counts=return_counts,
9016                    )
9017                    assert_only_first_requires_grad(res)
9018
9019                    # We don't test `unique_dim_consecutive` here.
9020                    # It looks public but the python binding is actually manually disabled in
9021                    # tools/autograd/gen_python_functions.py
9022
9023                    res = torch._unique2(
9024                        inp,
9025                        sorted=sort,
9026                        return_inverse=return_inverse,
9027                        return_counts=return_counts,
9028                    )
9029                    assert_only_first_requires_grad(res)
9030
9031    def test_custom_function_cycle(self):
9032        class MyFn(Function):
9033            @staticmethod
9034            def forward(ctx, x, metadata):
9035                x = x.clone()
9036                ctx.meta = metadata
9037                ctx.save_for_backward(x)
9038                return x
9039
9040            @staticmethod
9041            def backward(ctx, gO):
9042                (x,) = ctx.saved_tensors
9043                self.assertEqual(x, 3.14)
9044                self.assertEqual(ctx.meta["foo"], 3.14)
9045                return gO * x, None
9046
9047        def get_refs(with_backward):
9048            a = torch.tensor(3.14, requires_grad=True)
9049
9050            metadata = {}
9051            out = MyFn.apply(a, metadata)
9052
9053            metadata["foo"] = out
9054
9055            if with_backward:
9056                out.sum().backward()
9057                self.assertEqual(a.grad, a)
9058
9059            return torch._C._WeakTensorRef(out)
9060
9061        with disable_gc():
9062            ref = get_refs(False)
9063            self.assertFalse(ref.expired())
9064        gc.collect()
9065        self.assertTrue(ref.expired())
9066
9067        # The backward clears the saved_variables but not the __dict__
9068        with disable_gc():
9069            ref = get_refs(True)
9070            self.assertFalse(ref.expired())
9071        gc.collect()
9072        self.assertTrue(ref.expired())
9073
9074    def test_create_graph_and_full_backward_hook_cycle(self):
9075        # If BackwardHook saves grad_output, it can create a cycle when we perform backward
9076        # with create_graph=True
9077        #
9078        #   grad_output -> grad_output.grad_fn -> graph -> hook -> grad_output
9079        #
9080        class TestCls:
9081            # Dummy class for the purpose of creating a weakref
9082            pass
9083
9084        def get_ref(input_requires_grad, nb_hooks):
9085            t = torch.randn(10, requires_grad=input_requires_grad)
9086            a = torch.tensor(1.0, requires_grad=True)
9087
9088            class Test(nn.Module):
9089                def forward(self, x):
9090                    return x**2 * a**2
9091
9092            mod = Test()
9093
9094            for _ in range(nb_hooks):
9095                mod.register_full_backward_hook(lambda a, b, c: None)
9096
9097            tmp = mod(t)
9098
9099            # Save dummy object to graph and get a weak ref to it
9100            test = TestCls()
9101            ref = weakref.ref(test)
9102            tmp.grad_fn.metadata["a"] = test
9103
9104            with set_warn_always_context(True):
9105                with warnings.catch_warnings(record=True) as w:
9106                    tmp.exp().sum().backward(create_graph=True)
9107                    self.assertTrue(len(w) == 1)
9108                    self.assertTrue(
9109                        "Using backward() with create_graph=True" in str(w[0].message)
9110                    )
9111
9112            # Remove the backward + create_graph=True cycle
9113            a.grad = None
9114            t.grad = None
9115
9116            return ref
9117
9118        for nb_hooks in (1, 2, 3):
9119            for input_requires_grad in (True, False):
9120                ref_ = get_ref(
9121                    input_requires_grad=input_requires_grad,
9122                    nb_hooks=nb_hooks,
9123                )
9124                gc.collect()
9125                self.assertIsNone(ref_())
9126
9127    @parametrize("use_custom_function", [True, False])
9128    @parametrize("use_tensor_hook", [True, False])
9129    def test_hook_closure_cycle(self, use_custom_function, use_tensor_hook):
9130        # This creates a cycle between the hook and grad_fn_b
9131        # hook -> closure -> grad_fn_b (python) -> grad_fn (cpp) -> hook (cpp)
9132        # -> dict -> hook
9133        #
9134        # This test is testing that the grad_fn_b (python) only traverses the
9135        # dict if it is the only one holding a reference to the grad_fn_b (cpp)
9136        # shared_ptr
9137        #
9138        # See: https://github.com/pytorch/pytorch/issues/102174
9139        class Function(torch.autograd.Function):
9140            @staticmethod
9141            def forward(ctx, x):
9142                return x
9143
9144            @staticmethod
9145            def backward(ctx, grad):
9146                return grad
9147
9148        class Test:
9149            pass
9150
9151        count = [0]
9152
9153        def scope():
9154            a = torch.tensor(1.0, requires_grad=True)
9155            if use_custom_function:
9156                b = Function.apply(a)
9157            else:
9158                b = a.clone()
9159            grad_fn_b = b.grad_fn
9160            obj = Test()
9161
9162            def hook(*args):
9163                # Make sure this hook's closure holds onto grad_fn_b
9164                # This forms a cycle between the hook and grad_fn_b
9165                # We also hold onto a sentinel object 'obj' to track
9166                # whether this cycle is still alive. See 'ref' below.
9167                grad_fn_b
9168                obj
9169                count[0] += 1
9170
9171            if use_tensor_hook:
9172                b.register_hook(hook)
9173            else:
9174                b.grad_fn.register_hook(hook)
9175            c = b.clone()
9176            ref = weakref.ref(obj)
9177            return c, ref
9178
9179        with disable_gc():
9180            out, ref = scope()
9181            out.backward(retain_graph=True)
9182
9183            gc.collect()
9184
9185            # Make sure gc does not clear the cycle noted above.
9186            # e.g. the hook is alive and gets fired even after gc runs
9187            out.backward(retain_graph=True)
9188            self.assertEqual(count[0], 2)
9189
9190            # ref is still alive because the use_count of the cpp grad_fn
9191            # shared_ptr > 1 since (1) the python grad_fn is alive, and (2) the
9192            # rest of the graph holds onto the shared_ptr
9193            self.assertIsNotNone(ref())
9194
9195            # Then delete the rest of the graph and check that ref is dead
9196            del out
9197            gc.collect()
9198            self.assertIsNone(ref())
9199
9200    def test_full_backward_hook_double_backward(self):
9201        x = torch.rand(1, requires_grad=True)
9202        y = torch.rand_like(x)
9203
9204        func = torch.nn.MSELoss()
9205        counter = [0]
9206
9207        def hook(module, grad_input, grad_output):
9208            counter[0] += 1
9209
9210        func.register_full_backward_hook(hook)
9211
9212        f = func(x, y)
9213
9214        (gradx_f,) = torch.autograd.grad(f, x, create_graph=True)
9215        self.assertEqual(counter[0], 1)
9216        _ = torch.autograd.grad(gradx_f, x)
9217        # We should not error, and counter should not be incremented
9218        self.assertEqual(counter[0], 1)
9219
9220    def test_input_buffer_accum(self):
9221        leaf = torch.rand(2, 2, requires_grad=True)
9222
9223        # An op that returns sparse gradients
9224        ind = torch.tensor([[0, 0]], dtype=torch.long)
9225        out2 = leaf.gather(0, ind, sparse_grad=True)
9226
9227        # An op that returns the gradients as-is
9228        out1 = leaf.clone()
9229
9230        grad_out1_original = torch.rand_like(out1)
9231        grad_out1 = grad_out1_original.clone()
9232        grad_out2 = torch.rand_like(out2)
9233
9234        torch.autograd.backward((out1, out2), (grad_out1, grad_out2))
9235
9236        # Given gradients should not be modified inplace
9237        self.assertEqual(grad_out1, grad_out1_original)
9238
9239    def test_no_unnecessary_unwrapping(self):
9240        a = torch.randn(5, requires_grad=True)
9241        a_orig = a.detach().clone()
9242        b = a * a
9243        c = a * b
9244        d = torch.exp(a)
9245
9246        # a is leaf
9247        self.assertIs(b.grad_fn._saved_self, a)
9248        self.assertIs(b.grad_fn._saved_other, a)
9249        self.assertIs(c.grad_fn._saved_self, a)
9250
9251        # b is not an output
9252        self.assertIs(c.grad_fn._saved_other, b)
9253
9254        # d is an output
9255        self.assertEqual(d.grad_fn._saved_result, d)
9256        self.assertIsNot(d.grad_fn._saved_result, d)
9257
9258        c.sum().backward()
9259
9260        with self.assertRaisesRegex(RuntimeError, "after they have already been freed"):
9261            c.grad_fn._saved_self
9262
9263        # a is left untouched
9264        self.assertEqual(a, a_orig)
9265
9266    def test_saved_variable_version_counter(self):
9267        a = torch.rand(2, requires_grad=True)
9268
9269        b = torch.exp(a)
9270
9271        b_unpacked = b.grad_fn._saved_result
9272        self.assertEqual(b, b_unpacked)
9273        self.assertEqual(b._version, b_unpacked._version)
9274
9275        with torch.no_grad():
9276            b += 1
9277
9278        self.assertEqual(b, b_unpacked)
9279        self.assertEqual(b._version, b_unpacked._version)
9280
9281    def test_saved_variable_packing_unpacking_saved_original_with_hooks(self):
9282        # Tests that packing/unpacking a SavedVariable works correctly with user-defined hooks
9283        # The saved_original / did_not_save_original distinction corresponds to the `save_original`
9284        # attribute of `SavedVariable`.
9285
9286        def test(get_input, is_leaf):
9287            a = get_input()
9288            grad_fn = a.grad_fn
9289            y = a * a
9290            y.grad_fn._raw_saved_self.register_hooks(lambda x: 2 * x, lambda x: x / 2)
9291            self.assertEqual(a, y.grad_fn._saved_self)
9292            if not is_leaf:
9293                self.assertIs(grad_fn, y.grad_fn._saved_self.grad_fn)
9294                y.sum().backward()
9295            else:
9296                y.sum().backward()
9297                self.assertEqual(2 * a, a.grad)
9298
9299            a = get_input()
9300            grad_fn = a.grad_fn
9301            y = a * a
9302            y.grad_fn._raw_saved_self.register_hooks(lambda x: 2 * x, lambda x: x)
9303            self.assertEqual(2 * a, y.grad_fn._saved_self)
9304            if not is_leaf:
9305                self.assertIs(grad_fn, y.grad_fn._saved_self.grad_fn)
9306                y.sum().backward()
9307            else:
9308                y.sum().backward()
9309                self.assertEqual(3 * a, a.grad)
9310
9311            # double backward
9312            a = get_input()
9313            grad_fn = a.grad_fn
9314            y = a**3
9315            y.grad_fn._raw_saved_self.register_hooks(lambda x: x, lambda x: x)
9316            s = torch.sum(y)
9317            (g,) = torch.autograd.grad(s, (a,), create_graph=True)
9318            if not is_leaf:
9319                self.assertIs(grad_fn, y.grad_fn._saved_self.grad_fn)
9320                g.sum().backward()
9321            else:
9322                g.sum().backward()
9323                self.assertEqual(6 * a, a.grad)
9324
9325            a = get_input()
9326            y = a * a
9327            y.grad_fn._raw_saved_self.register_hooks(lambda x: x, lambda x: 1)
9328            with self.assertRaisesRegex(
9329                TypeError, "Output of saved tensor unpack_hook expected to be a Tensor"
9330            ):
9331                print(y.grad_fn._saved_self)
9332
9333            a = get_input()
9334            y = a * a
9335            with self.assertRaisesRegex(
9336                TypeError, "missing 1 required positional argument"
9337            ):
9338                y.grad_fn._raw_saved_self.register_hooks(lambda x, b: x, lambda x: x)
9339
9340            a = get_input()
9341            y = a * a
9342            with self.assertRaisesRegex(
9343                TypeError, "missing 1 required positional argument"
9344            ):
9345                y.grad_fn._raw_saved_self.register_hooks(
9346                    lambda x, b: (x, b), lambda x: x
9347                )
9348
9349            def inplace_double(x):
9350                x *= 2
9351                return x
9352
9353            a = get_input()
9354            t = a * a
9355
9356            with self.assertRaisesRegex(
9357                RuntimeError,
9358                "A saved tensor pack hook is modifying its input in place.",
9359            ):
9360                t.grad_fn._raw_saved_self.register_hooks(
9361                    inplace_double, lambda x: x / 2
9362                )
9363
9364        # leaf
9365        test(lambda: torch.randn(5, requires_grad=True), True)
9366
9367        # not leaf, not output
9368        test(lambda: (1 + torch.randn(5, requires_grad=True)), False)
9369
9370    def test_saved_variable_saved_original_inplace_detach(self):
9371        # Detaching a tensor that is saved input raises
9372        a = torch.tensor(1.0, requires_grad=True).clone()
9373        b = a.sin()
9374        a.detach_()
9375        with self.assertRaisesRegex(
9376            RuntimeError, "Trying to use a saved tensor that has been detached"
9377        ):
9378            b.backward()
9379
9380        # Detaching a tensor that is saved as output is OK
9381        a = torch.tensor(1.0, requires_grad=True).clone()
9382        b = a.exp()
9383        a.detach_()
9384        b.backward()
9385
9386    def test_saved_variable_packing_unpacking_did_not_save_original_with_hooks(self):
9387        # Tests that packing/unpacking a SavedVariable works correctly with user-defined hooks
9388        # The saved_original / did_not_save_original distinction corresponds to the `save_original`
9389        # attribute of `SavedVariable`.
9390
9391        a = torch.randn(5, requires_grad=True)
9392        y = torch.exp(a)
9393        y.grad_fn._raw_saved_result.register_hooks(lambda x: x, lambda x: x)
9394        self.assertEqual(y, y.grad_fn._saved_result)
9395        self.assertIs(y.grad_fn, y.grad_fn._saved_result.grad_fn)
9396        y.sum().backward()
9397        self.assertEqual(a.grad, y)
9398
9399    def test_saved_variable_packing_unpacking_saved_original_with_default_hooks(self):
9400        # Tests that default hooks are properly registered, used and reset
9401        # The saved_original / did_not_save_original distinction corresponds to the `save_original`
9402        # attribute of `SavedVariable`.
9403        # See also:
9404        #  - test_saved_variable_packing_unpacking_saved_original_with_hooks
9405
9406        def pack(x):
9407            warnings.warn("pack")
9408            return x
9409
9410        with torch.autograd.graph.saved_tensors_hooks(pack, lambda x: x):
9411            a = torch.ones(5, requires_grad=True)
9412
9413            with warnings.catch_warnings(record=True) as w:
9414                warnings.simplefilter("always")
9415                y = a * a
9416                # should raise two warnings from a being saved twice
9417                self.assertEqual(len(w), 2)
9418
9419        with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
9420            a = torch.randn(5, requires_grad=True)
9421            y = a * a
9422            self.assertEqual(a, y.grad_fn._saved_self)
9423            self.assertEqual(a, y.grad_fn._saved_other)
9424            y.sum().backward()
9425            self.assertEqual(2 * a, a.grad)
9426
9427        with torch.autograd.graph.saved_tensors_hooks(lambda x: 2 * x, lambda x: x / 2):
9428            a = torch.randn(5, requires_grad=True)
9429            y = a * a
9430            self.assertEqual(a, y.grad_fn._saved_self)
9431            self.assertEqual(a, y.grad_fn._saved_other)
9432            y.sum().backward()
9433            self.assertEqual(2 * a, a.grad)
9434
9435        with torch.autograd.graph.saved_tensors_hooks(lambda x: 2 * x, lambda x: x):
9436            a = torch.randn(5, requires_grad=True)
9437            y = a * a
9438            self.assertEqual(2 * a, y.grad_fn._saved_self)
9439            self.assertEqual(2 * a, y.grad_fn._saved_other)
9440            y.sum().backward()
9441            self.assertEqual(4 * a, a.grad)
9442
9443        # Exited hooks correctly
9444        a = torch.randn(5, requires_grad=True)
9445        y = a * a
9446        self.assertEqual(a, y.grad_fn._saved_self)
9447        self.assertEqual(a, y.grad_fn._saved_other)
9448        y.sum().backward()
9449        self.assertEqual(2 * a, a.grad)
9450
9451    def test_saved_variable_packing_unpacking_did_not_save_original_with_default_hooks(
9452        self,
9453    ):
9454        # See also test_saved_variable_packing_unpacking_did_not_save_original_with_hooks
9455
9456        with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
9457            a = torch.randn(5, requires_grad=True)
9458            y = torch.exp(a)
9459            self.assertEqual(y, y.grad_fn._saved_result)
9460            y.sum().backward()
9461            self.assertEqual(a.grad, y)
9462
9463    def test_setting_default_saved_variable_hooks_twice_should_not_fail(self):
9464        with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
9465            with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
9466                pass
9467
9468    def test_setting_default_saved_variable_hooks_twice_should_use_inner(self):
9469        with torch.autograd.graph.saved_tensors_hooks(lambda x: 3 * x, lambda x: 3 * x):
9470            b = torch.randn(5, requires_grad=True)
9471            with torch.autograd.graph.saved_tensors_hooks(
9472                lambda x: 5 * x, lambda x: 5 * x
9473            ):
9474                a = torch.randn(5, requires_grad=True)
9475                y = a * a
9476            z = b * b
9477        y.sum().backward()
9478        z.sum().backward()
9479        self.assertEqual(2 * 5 * 5 * a, a.grad)
9480        self.assertEqual(2 * 3 * 3 * b, b.grad)
9481
9482    def test_disabling_saved_tensor_hooks(self):
9483        with torch.autograd.graph.disable_saved_tensors_hooks("error message"):
9484            with self.assertRaisesRegex(RuntimeError, "error message"):
9485                with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
9486                    pass
9487
9488        self.assertTrue(torch._C._autograd._saved_tensors_hooks_is_enabled())
9489
9490        with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
9491            with self.assertRaisesRegex(RuntimeError, "error message"):
9492                with torch.autograd.graph.disable_saved_tensors_hooks("error message"):
9493                    pass
9494
9495        self.assertTrue(torch._C._autograd._saved_tensors_hooks_is_enabled())
9496
9497    def test_disabling_saved_tensor_hooks_nested(self):
9498        with torch.autograd.graph.disable_saved_tensors_hooks("outer"):
9499            with torch.autograd.graph.disable_saved_tensors_hooks("inner"):
9500                with self.assertRaisesRegex(RuntimeError, "inner"):
9501                    with torch.autograd.graph.saved_tensors_hooks(
9502                        lambda x: x, lambda x: x
9503                    ):
9504                        pass
9505
9506            self.assertFalse(torch._C._autograd._saved_tensors_hooks_is_enabled())
9507
9508        self.assertTrue(torch._C._autograd._saved_tensors_hooks_is_enabled())
9509
9510    def test_saved_tensor_hooks_custom_error_propagation(self):
9511        class CustomError(Exception):
9512            pass
9513
9514        class error_on_pack_hook(torch.autograd.graph.saved_tensors_hooks):
9515            def __init__(self) -> None:
9516                def pack_hook(x):
9517                    raise CustomError("pack")
9518
9519                super().__init__(pack_hook, lambda x: x)
9520
9521        class error_on_unpack_hook(torch.autograd.graph.saved_tensors_hooks):
9522            def __init__(self) -> None:
9523                def unpack_hook(x):
9524                    raise CustomError("unpack")
9525
9526                super().__init__(lambda x: x, unpack_hook)
9527
9528        a = torch.tensor(1.0, requires_grad=True)
9529
9530        with error_on_pack_hook():
9531            with self.assertRaisesRegex(CustomError, "pack"):
9532                out = torch.sin(a)
9533
9534        with error_on_unpack_hook():
9535            out = torch.sin(a)
9536            with self.assertRaisesRegex(CustomError, "unpack"):
9537                out.backward()
9538
9539    def test_saved_tensor_hooks_custom_function_intermediates(self):
9540        class Func(torch.autograd.Function):
9541            @staticmethod
9542            def forward(ctx, x):
9543                intermediate = x.exp()
9544                ctx.save_for_backward(
9545                    intermediate.clone().detach_().requires_grad_(True)
9546                )
9547                return x.exp()
9548
9549            @staticmethod
9550            def backward(ctx, grad_out):
9551                (intermediate,) = ctx.saved_tensors
9552                return grad_out * intermediate
9553
9554        a = torch.tensor(1.0, requires_grad=True)
9555
9556        with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
9557            out = Func.apply(a)
9558        out.backward()
9559
9560    def test_unpack_hooks_exec_count(self):
9561        def f(x, y):
9562            return x * y
9563
9564        pack_count = 0
9565        unpack_count = 0
9566
9567        def pack_hook(x):
9568            nonlocal pack_count
9569            pack_count += 1
9570            return x
9571
9572        # unpack hook shouldn't run during compilation, while we trace the forward
9573        def unpack_hook(x):
9574            nonlocal unpack_count
9575            unpack_count += 1
9576            return x
9577
9578        x = torch.ones(4, requires_grad=True)
9579        y = torch.ones(4, requires_grad=False)
9580        with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
9581            out_test = f(x, y)
9582            self.assertEqual(pack_count, 1)
9583            self.assertEqual(unpack_count, 0)
9584            out_test.sum().backward()
9585            self.assertEqual(pack_count, 1)
9586            self.assertEqual(unpack_count, 1)
9587
9588    def test_saved_tensors_hook_version_counter_not_shared(self):
9589        class Test(torch.autograd.Function):
9590            @staticmethod
9591            def forward(ctx, x):
9592                ctx.save_for_backward(x)
9593                return x.sin()
9594
9595            @staticmethod
9596            def backward(ctx, grad_output):
9597                (x,) = ctx.saved_tensors
9598                before = a._version
9599                x.add_(1)
9600                self.assertEqual(a._version, before)
9601                return grad_output
9602
9603        a = torch.tensor(1.0, requires_grad=True)
9604        a_replacement = a.clone()
9605
9606        def pack_hook(x):
9607            return a_replacement
9608
9609        def unpack_hook(x):
9610            return x
9611
9612        with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
9613            b = Test.apply(a)
9614
9615        b.backward()
9616
9617    def test_save_on_cpu_and_checkpoint(self):
9618        a = torch.randn(2, 2, requires_grad=True)
9619
9620        b = a.pow(2).pow(2).pow(2).pow(2)
9621        b.sum().backward()
9622        b_grad = a.grad.clone()
9623        a.grad.zero_()
9624
9625        with torch.autograd.graph.save_on_cpu():
9626            h = a.pow(2)
9627            h = checkpoint(lambda x: x.pow(2).pow(2), h, use_reentrant=False)
9628            c = h.pow(2)
9629        c.sum().backward()
9630        c_grad = a.grad.clone()
9631        a.grad.zero_()
9632
9633        def f(a):
9634            h = a.pow(2)
9635            with torch.autograd.graph.save_on_cpu():
9636                h = h.pow(2).pow(2)
9637            return h.pow(2)
9638
9639        d = checkpoint(f, a, use_reentrant=False)
9640        d.sum().backward()
9641        d_grad = a.grad.clone()
9642
9643        self.assertEqual(b_grad, c_grad)
9644        self.assertEqual(b_grad, d_grad)
9645
9646    def test_pack_hook_with_inplace_modification_should_fail(self):
9647        a = torch.randn(5, requires_grad=True)
9648
9649        def inc(x):
9650            x += 1
9651            return x
9652
9653        with torch.autograd.graph.saved_tensors_hooks(inc, lambda x: x):
9654            with self.assertRaisesRegex(
9655                RuntimeError,
9656                "A saved tensor pack hook is modifying its input in place.",
9657            ):
9658                y = torch.exp(a)
9659
9660        y = torch.exp(a)
9661        with self.assertRaisesRegex(
9662            RuntimeError, "A saved tensor pack hook is modifying its input in place."
9663        ):
9664            y.grad_fn._raw_saved_result.register_hooks(inc, lambda x: x)
9665
9666    def test_saving_variable_to_disk(self):
9667        with tempfile.TemporaryDirectory() as tmp_dir:
9668
9669            def pack(x):
9670                name = os.path.join(tmp_dir, str(uuid.uuid4()))
9671                torch.save(x, name)
9672                return name
9673
9674            def unpack(name):
9675                return torch.load(name)
9676
9677            with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
9678                a = torch.ones(5, requires_grad=True)
9679                y = a * a
9680                self.assertEqual(a, y.grad_fn._saved_self)
9681
9682                y.sum().backward()
9683                self.assertEqual(2 * a, a.grad)
9684
9685    def test_default_saved_tensors_hooks_double_backward(self):
9686        with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
9687            a = torch.randn(5, requires_grad=True)
9688            y = a**3
9689            s = torch.sum(y)
9690            (g,) = torch.autograd.grad(s, (a,), create_graph=True)
9691            g.sum().backward()
9692            self.assertEqual(6 * a, a.grad)
9693
9694        with torch.autograd.graph.saved_tensors_hooks(lambda x: 2 * x, lambda x: x):
9695            a = torch.randn(5, requires_grad=True)
9696            y = a**3
9697            s = torch.sum(y)
9698        (g,) = torch.autograd.grad(s, (a,), create_graph=True)
9699        g.sum().backward()
9700        # factor 2 because only a is saved once
9701        self.assertEqual(6 * 2 * a, a.grad)
9702
9703        a = torch.randn(5, requires_grad=True)
9704        y = a**3
9705        s = torch.sum(y)
9706        with torch.autograd.graph.saved_tensors_hooks(lambda x: 2 * x, lambda x: x):
9707            (g,) = torch.autograd.grad(s, (a,), create_graph=True)
9708            g.sum().backward()
9709            # factor 4 because pow_backward is grad * (exp * self.pow(exp - 1))
9710            # so grad is saved and self (i.e. a) is saved
9711            self.assertEqual(6 * 4 * a, a.grad)
9712
9713        with torch.autograd.graph.saved_tensors_hooks(lambda x: 2 * x, lambda x: x):
9714            a = torch.randn(5, requires_grad=True)
9715            y = a**3
9716            s = torch.sum(y)
9717            (g,) = torch.autograd.grad(s, (a,), create_graph=True)
9718            g.sum().backward()
9719            # combining the two above blocks: 2 * 4 = 8
9720            # note that in that sense, a is saved twice
9721            self.assertEqual(6 * 8 * a, a.grad)
9722
9723    def test_wrapped_number_saved_tensors_hooks(self):
9724        def err_hook(x):
9725            raise RuntimeError("this hook should not be called")
9726
9727        with torch.autograd.graph.saved_tensors_hooks(err_hook, err_hook):
9728            a = torch.randn(5, requires_grad=True)
9729            out = (a * 3).sum()
9730            # 3 is saved as a saved tensor because it is a wrapped number, but
9731            # wrapped numbers should be special cased to not trigger saved variable hooks
9732            torch.autograd.grad(out, (a,))
9733
9734    def test_graph_save_on_cpu(self):
9735        def test(get_input, cuda, pin_memory):
9736            with torch.autograd.graph.save_on_cpu(pin_memory):
9737                a = get_input()
9738                if cuda:
9739                    a.cuda()
9740                y = a * a
9741                self.assertEqual(a, y.grad_fn._saved_self)
9742                self.assertEqual(a, y.grad_fn._saved_other)
9743                self.assertEqual(a.dtype, y.grad_fn._saved_self.dtype)
9744                self.assertEqual(a.layout, y.grad_fn._saved_self.layout)
9745                if y.is_sparse:
9746                    y = y.to_dense()
9747                y.sum().backward()
9748
9749                actual = 2 * a
9750                expected = a.grad
9751                if a.is_sparse:
9752                    actual = actual.coalesce()
9753                    expected = expected.coalesce()
9754
9755                self.assertEqual(actual, expected)
9756
9757        for cuda in [False] + ([True] if torch.cuda.is_available() else []):
9758            for pin_memory in [True, False]:
9759                # FloatTensor
9760                test(lambda: torch.randn(5, requires_grad=True), cuda, pin_memory)
9761                # DoubleTensor
9762                test(
9763                    lambda: torch.randn(5, requires_grad=True, dtype=torch.double),
9764                    cuda,
9765                    pin_memory,
9766                )
9767                # Sparse tensor
9768                x = torch.sparse_coo_tensor(
9769                    torch.tensor([[1, 1]]).long(),
9770                    torch.tensor([1.0, 1.0]),
9771                    requires_grad=True,
9772                )
9773                test(lambda: x, cuda, pin_memory)
9774
9775    @unittest.skipIf(not TEST_CUDA, "test requires CUDA")
9776    def test_graph_save_on_cpu_cuda(self):
9777        def f(x):
9778            a = x + 1
9779            return a * a
9780
9781        # with grad
9782        a = torch.ones(1, requires_grad=True, device="cuda")
9783        y = f(a)
9784        memory_with_grad = torch.cuda.memory_allocated()
9785
9786        del a
9787        del y
9788
9789        # without grad
9790        a = torch.ones(1, requires_grad=True, device="cuda")
9791        with torch.no_grad():
9792            y = f(a)
9793        memory_without_grad = torch.cuda.memory_allocated()
9794
9795        self.assertGreater(memory_with_grad, memory_without_grad)
9796
9797        del a
9798        del y
9799
9800        # with hooks
9801        with torch.autograd.graph.save_on_cpu():
9802            a = torch.ones(1, requires_grad=True, device="cuda")
9803            y = f(a)
9804            memory_with_hooks = torch.cuda.memory_allocated()
9805            self.assertEqual(memory_with_hooks, memory_without_grad)
9806
9807    @unittest.skipIf(not TEST_CUDA, "test requires CUDA")
9808    def test_scalar_grad_mixed_device(self):
9809        x = torch.tensor(1.0, requires_grad=True)
9810        y = torch.randn(2, 2, device="cuda")
9811        out = x * y
9812        out.sum().backward()
9813
9814    def test_multi_grad_all_hooks(self):
9815        t1 = torch.rand(2, requires_grad=True)
9816        t2 = torch.rand(2, requires_grad=True)
9817        t3 = torch.rand(2, requires_grad=True)
9818        t4 = torch.rand(2, requires_grad=True)
9819
9820        # Ensure we properly detect all types of Nodes here
9821        # C++ Node
9822        t1 = t1.mul(2)
9823
9824        # Python custom Function
9825        class Foo(Function):
9826            @staticmethod
9827            def forward(ctx, a):
9828                return a.clone()
9829
9830            @staticmethod
9831            def backward(ctx, gO):
9832                return gO
9833
9834        t2 = Foo.apply(t2)
9835
9836        # C++ Node
9837        t3 = torch._C._functions.UndefinedGrad()(t3)
9838
9839        # C++ Custom Op
9840        cpp_source = """
9841struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
9842  static torch::Tensor forward(
9843      torch::autograd::AutogradContext* ctx,
9844      const torch::Tensor& x) {
9845    return x.clone();
9846  }
9847
9848  static torch::autograd::variable_list backward(
9849      torch::autograd::AutogradContext *ctx,
9850      torch::autograd::variable_list grad_output) {
9851    return grad_output;
9852  }
9853};
9854
9855torch::Tensor custom_op_backed_by_autograd_fn(torch::Tensor x) {
9856  return CustomOpAutogradFunction::apply(x);
9857}
9858
9859TORCH_LIBRARY(test_autograd_cpp_node, m) {
9860    m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn);
9861}
9862        """
9863
9864        module = load_inline(
9865            name="test_autograd_cpp_node",
9866            cpp_sources=cpp_source,
9867            functions="custom_op_backed_by_autograd_fn",
9868            verbose=True,
9869        )
9870
9871        t4 = torch.ops.test_autograd_cpp_node.custom_op_backed_by_autograd_fn(t4)
9872
9873        res = [None] * 4
9874        count = [0]
9875
9876        def hook(grads):
9877            nonlocal res
9878            count[0] += 1
9879            res = [g is not None for g in grads]
9880
9881        handle = torch.autograd.graph.register_multi_grad_hook((t1, t2, t3, t4), hook)
9882
9883        out = t2 * t3
9884
9885        out.sum().backward(inputs=(t2, t3), retain_graph=True)
9886        self.assertEqual(count[0], 1)
9887        self.assertEqual(res, [False, True, True, False])
9888
9889        out.sum().backward(inputs=(t1, t4), retain_graph=True)
9890        self.assertEqual(count[0], 1)
9891
9892        out.sum().backward(inputs=(t1, t3), retain_graph=True)
9893        self.assertEqual(count[0], 2)
9894        self.assertEqual(res, [False, False, True, False])
9895
9896        class Func(torch.autograd.Function):
9897            @staticmethod
9898            def forward(ctx, x):
9899                return x
9900
9901            @staticmethod
9902            def backward(ctx, gO):
9903                raise RuntimeError("error message")
9904
9905        out = Func.apply(t2) * t3
9906        with self.assertRaisesRegex(RuntimeError, "error message"):
9907            out.sum().backward(inputs=(t2, t3), retain_graph=True)
9908        self.assertEqual(count[0], 2)
9909
9910        handle.remove()
9911        out.sum().backward(inputs=(t1, t3), retain_graph=True)
9912        self.assertEqual(count[0], 2)
9913
9914    def test_multi_grad_any_hooks(self):
9915        hook_id = 0
9916        any_hook_handles: List[RemovableHandle] = []
9917
9918        class MultiOutputModule(nn.Module):
9919            def __init__(self) -> None:
9920                super().__init__()
9921                self.lin = nn.Linear(3, 3)
9922
9923            def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
9924                z = self.lin(x)
9925                out = torch.sin(z), torch.cos(z)
9926                nonlocal hook_id
9927                z.register_hook(partial(hook, hook_id))
9928                hook_id += 1
9929                any_hook_handles.append(
9930                    torch.autograd.graph.register_multi_grad_hook(
9931                        out, partial(hook, hook_id), mode="any"
9932                    )
9933                )
9934                hook_id += 1
9935                return out
9936
9937        class Model(nn.Module):
9938            def __init__(self) -> None:
9939                super().__init__()
9940                self.mod1 = MultiOutputModule()
9941                self.mod2 = MultiOutputModule()
9942
9943            def forward(self, x: torch.Tensor) -> torch.Tensor:
9944                y = self.mod1(x)
9945                z = y[0] + y[1]
9946                return self.mod2(z)
9947
9948        hook_order: List[int] = []
9949        hook_count = 0
9950
9951        def hook(hook_id: int, *unused):
9952            nonlocal hook_count
9953            nonlocal hook_order
9954            hook_count += 1
9955            hook_order.append(hook_id)
9956
9957        # Any hooks: IDs 1 and 3; regular hooks: IDs 0 and 2
9958        model = Model()
9959        inp = torch.randn((2, 3))
9960        out = model(inp)
9961        (out[0] + out[1]).sum().backward()
9962        # Check that the any-hook runs only once and before the regular hook
9963        # for each module
9964        self.assertEqual(len(any_hook_handles), 2)
9965        self.assertEqual(hook_order, [3, 2, 1, 0])
9966
9967        hook_id = 0
9968        hook_order.clear()
9969        any_hook_handles.clear()
9970        out = model(inp)
9971        for handle in any_hook_handles:
9972            handle.remove()
9973        (out[0] + out[1]).sum().backward()
9974        # Check that the any-hook does not run if removed
9975        self.assertEqual(hook_order, [2, 0])
9976
9977    def test_multi_grad_hooks_invalid_mode(self):
9978        t1 = torch.rand(2, requires_grad=True)
9979        t2 = torch.rand(2, requires_grad=True)
9980        regex = r"Expects mode to be one of \('all', 'any'\) but got foo"
9981        with self.assertRaisesRegex(ValueError, regex):
9982            torch.autograd.graph.register_multi_grad_hook(
9983                (t1, t2), lambda _: None, mode="foo"
9984            )
9985
9986    def test_pynode_destruction_deadlock(self):
9987        script = """
9988import torch
9989
9990class Foo(torch.autograd.Function):
9991    @staticmethod
9992    def forward(ctx, x):
9993        return x.clone()
9994
9995    @staticmethod
9996    def forward(ctx, gO):
9997        return gO.clone()
9998
9999def get_out():
10000    inp = torch.rand(2, requires_grad=True)
10001
10002    # The python function is first so that it runs
10003    # last in the backward pass
10004    right = Foo.apply(inp)
10005
10006    # An op that creates new memory
10007    left1 = inp.clone()
10008    # An op that saves its input
10009    left2 = left1 ** 2
10010
10011    # Inplace modify so that the backward for
10012    # left2 always raises an error
10013    left1 += 1
10014
10015    # An op that takes both side as input.
10016    # After running, both side's last op will be in
10017    # the ready queue
10018    # And the op for left will run first as it was
10019    # executed last during the forward
10020    out = left2 + right
10021
10022    return out
10023
10024# Nothing should be global variables here as, from what
10025# I can see, python leaks all the global objects
10026get_out().sum().backward()
10027
10028# This used to deadlock when the PyNode is being destroyed after
10029# the error is raised.
10030"""
10031        try:
10032            subprocess.check_output(
10033                [sys.executable, "-c", script],
10034                stderr=subprocess.STDOUT,
10035                # On Windows, opening the subprocess with the default CWD makes `import torch`
10036                # fail, so just set CWD to this script's directory
10037                cwd=os.path.dirname(os.path.realpath(__file__)),
10038                # It is ok to have an extra long timeout here as a timeout means the test failed
10039                timeout=20,
10040            )
10041        except subprocess.TimeoutExpired as e:
10042            self.fail(
10043                msg="Example code timed out! See the code sample in the test for details."
10044            )
10045        except subprocess.CalledProcessError as e:
10046            if e.returncode < 0:
10047                # Sometimes we segfault instead of deadlocking
10048                self.fail("Subprocess exited with a fatal signal")
10049            else:
10050                err_msg = (
10051                    "RuntimeError: one of the variables needed for gradient computation"
10052                )
10053                self.assertTrue(err_msg in e.output.decode("utf-8"))
10054
10055    def test_view_func_replay(self):
10056        with torch.autograd._force_original_view_tracking(True):
10057
10058            def _assert_match_metadata(a, b):
10059                self.assertEqual(a.size(), b.size())
10060                self.assertEqual(a.stride(), b.stride())
10061                self.assertEqual(a.storage_offset(), b.storage_offset())
10062                self.assertEqual(a.device, b.device)
10063                self.assertEqual(a.dtype, b.dtype)
10064
10065            def _test_fn(fn, inp, *args, use_unsafe_view_func=False):
10066                outs = fn(inp, *args)
10067                # handle functions that return multiple views (e.g. split)
10068                if isinstance(outs, torch.Tensor):
10069                    outs = [outs]
10070
10071                for out in outs:
10072                    self.assertTrue(out._is_view())
10073                    self.assertTrue(out._base is inp)
10074
10075                    # forward view_func
10076                    new_inp = inp.clone()
10077                    _assert_match_metadata(new_inp, inp)
10078                    if use_unsafe_view_func:
10079                        new_out = out._view_func_unsafe(new_inp)
10080                    else:
10081                        new_out = out._view_func(new_inp)
10082                    _assert_match_metadata(new_out, out)
10083                    self.assertEqual(new_out, out)
10084
10085                    # reverse view_func
10086                    new_out = out.detach()
10087                    new_inp = out._rev_view_func_unsafe(new_out)
10088                    _assert_match_metadata(new_inp, inp)
10089                    self.assertTrue(new_inp._is_view())
10090                    self.assertTrue(new_inp._base is new_out)
10091
10092            # test individual view ops
10093            _test_fn(torch.ops.aten.alias.default, torch.rand(2, 2))
10094            _test_fn(torch.as_strided, torch.rand(2, 2), (4,), (1,))
10095            _test_fn(torch.chunk, torch.rand(2, 4), 2, -1)
10096            _test_fn(torch.diagonal, torch.rand(4, 4))
10097            _test_fn(torch.ops.aten.expand.default, torch.rand(4, 1), (-1, 3))
10098            _test_fn(torch.narrow, torch.rand(2, 2), 0, 1, 1)
10099            _test_fn(torch.permute, torch.rand(2, 3, 4), (1, 0, 2))
10100            _test_fn(torch.select, torch.rand(2, 2), 0, 0)
10101            _test_fn(torch.ops.aten.slice.Tensor, torch.rand(2, 2), 1, 1, 2)
10102            _test_fn(torch.split, torch.rand(2, 2), 1)
10103            _test_fn(torch.split_with_sizes, torch.rand(2, 4), [1, 3], -1)
10104            _test_fn(torch.squeeze, torch.rand(2, 1, 4))
10105            _test_fn(torch.squeeze, torch.rand(2, 1, 4), 1)
10106            _test_fn(torch.squeeze, torch.rand(2, 1, 1, 4), [1, 2])
10107            _test_fn(torch.t, torch.rand(2, 4))
10108            _test_fn(torch.transpose, torch.rand(2, 4), 0, 1)
10109            _test_fn(torch.unbind, torch.rand(1, 5))
10110            _test_fn(torch.ops.aten.unfold.default, torch.rand(1, 5), 1, 3, 2)
10111            _test_fn(torch.unsqueeze, torch.rand(2, 4), -2)
10112            _test_fn(torch.ops.aten.view.default, torch.rand(2, 10), (-1, 5, 2))
10113            _test_fn(torch.view_as_complex, torch.rand(2, 2))
10114            _test_fn(torch.view_as_real, torch.rand(2, 2, dtype=torch.cfloat))
10115
10116            # test view chains
10117            _test_fn(
10118                lambda x: x.unsqueeze(-1).transpose(-1, -2).squeeze(1),
10119                torch.randn(2, 4),
10120            )
10121            _test_fn(
10122                lambda x: x.chunk(2, -1)[0].transpose(0, 1).unsqueeze(-1),
10123                torch.randn(2, 3, 4),
10124            )
10125            _test_fn(
10126                lambda x: x.split_with_sizes([1, 3], -1)[0].chunk(2, 0),
10127                torch.randn(2, 3, 4),
10128            )
10129
10130            # chains with missing view_func()s use as_strided() to cover the gaps
10131            def chain_with_only_parent_view_func(x):
10132                with torch.autograd._force_original_view_tracking(True):
10133                    x = x.split_with_sizes([1, 3], -1)[0]
10134
10135                with torch.autograd._force_original_view_tracking(False):
10136                    x = x.chunk(2, 0)
10137
10138                return x
10139
10140            _test_fn(chain_with_only_parent_view_func, torch.randn(2, 3, 4))
10141
10142            def chain_with_only_current_view_func(x):
10143                with torch.autograd._force_original_view_tracking(False):
10144                    x = x.split_with_sizes([1, 3], -1)[0]
10145
10146                with torch.autograd._force_original_view_tracking(True):
10147                    x = x.chunk(2, 0)
10148
10149                return x
10150
10151            _test_fn(chain_with_only_current_view_func, torch.randn(2, 3, 4))
10152
10153            # TODO: Move this somewhere else
10154            # test NT views
10155            from torch.nested._internal.nested_tensor import (
10156                nested_view_from_values_offsets,
10157            )
10158
10159            values = torch.randn(10, 5)
10160            offsets = torch.tensor([0, 3, 6, 10])
10161            _test_fn(nested_view_from_values_offsets, values, offsets)
10162
10163            nt = nested_view_from_values_offsets(values, offsets).clone().detach()
10164            _test_fn(
10165                torch.ops.aten._nested_get_values.default, nt, use_unsafe_view_func=True
10166            )
10167
10168            def chain_nt_to_dense_back_and_forth(nt):
10169                # NJT1 -> dense -> NJT2 -> dense
10170                offsets2 = nt.offsets().clone().detach()
10171                return nested_view_from_values_offsets(nt.values(), offsets2).values()
10172
10173            _test_fn(chain_nt_to_dense_back_and_forth, nt, use_unsafe_view_func=True)
10174
10175            def chain_dense_to_nt_back_and_forth(values, offsets):
10176                offsets2 = offsets.clone().detach()
10177                # dense -> NJT1 -> dense -> NJT2
10178                return nested_view_from_values_offsets(
10179                    nested_view_from_values_offsets(values, offsets).values(), offsets2
10180                )
10181
10182            _test_fn(
10183                chain_dense_to_nt_back_and_forth,
10184                values,
10185                offsets,
10186                use_unsafe_view_func=True,
10187            )
10188
10189    def test_view_func_replay_with_modified_state(self):
10190        with torch.autograd._force_original_view_tracking(True):
10191            base = torch.randn(3, 4, 5)
10192            view = base.select(1, 2)
10193
10194            def symint_visitor_fn(x):
10195                # modify saved index
10196                return x + 1
10197
10198            # ensure modifying state changes view replay
10199            new_base = torch.randn_like(base)
10200            new_view = view._view_func(new_base, symint_visitor_fn=symint_visitor_fn)
10201            self.assertEqual(new_view, new_base.select(1, 3))
10202
10203            # ensure saved state reverts back afterwards
10204            self.assertEqual(view._view_func(new_base), new_base.select(1, 2))
10205
10206            # check modifying tensor state. currently, slice_inverse() is the only
10207            # view that saves a tensor
10208            base = torch.randn(3, 4, 5)
10209            sliced = base[:, 2:3, :].detach()
10210            view = torch.ops.aten.slice_inverse(sliced, base, 1, 2, 3, 1)
10211
10212            replacement_shape = (1, 2, 3)
10213
10214            def tensor_visitor_fn(x):
10215                # return tensor with a smaller shape than the saved one
10216                return torch.randn(*replacement_shape)
10217
10218            # ensure modifying state changes view replay
10219            new_sliced = torch.ones_like(base)[:, 2:3, :].detach()
10220            new_view = view._view_func(new_sliced, tensor_visitor_fn=tensor_visitor_fn)
10221            self.assertEqual(new_view.shape, replacement_shape)
10222            self.assertEqual(
10223                new_view, new_sliced.as_strided(replacement_shape, (6, 3, 1))
10224            )
10225
10226            # ensure saved state reverts back afterwards
10227            self.assertEqual(view._view_func(sliced), base)
10228
10229    def test_setup_context_when_forward_has_default_args(self):
10230        class PowFunction(Function):
10231            @staticmethod
10232            def forward(x, y=3):
10233                return torch.pow(x, y)
10234
10235            @staticmethod
10236            def setup_context(ctx, inputs, output):
10237                x, y = inputs
10238                ctx.save_for_backward(x)
10239                ctx.y = y
10240
10241            @staticmethod
10242            def backward(ctx, gO):
10243                (x,) = ctx.saved_tensors
10244                y = ctx.y
10245                return gO * y * torch.pow(x, y - 1), None
10246
10247        class PowFunctionWithClassmethod(Function):
10248            @classmethod
10249            def forward(cls, x, y=3):
10250                return torch.pow(x, y)
10251
10252            @classmethod
10253            def setup_context(cls, ctx, inputs, output):
10254                x, y = inputs
10255                ctx.save_for_backward(x)
10256                ctx.y = y
10257
10258            @classmethod
10259            def backward(cls, ctx, gO):
10260                (x,) = ctx.saved_tensors
10261                y = ctx.y
10262                return gO * y * torch.pow(x, y - 1), None
10263
10264        x = torch.tensor(2.0, requires_grad=True)
10265
10266        y = torch.tensor(8.0)
10267        y_expected = torch.tensor(12.0)
10268
10269        y1 = PowFunction.apply(x)
10270        (y1_expected,) = torch.autograd.grad(y1, x)
10271
10272        y2 = PowFunctionWithClassmethod.apply(x)
10273        (y2_expected,) = torch.autograd.grad(y2, x)
10274
10275        self.assertEqual(y, y1)
10276        self.assertEqual(y_expected, y1_expected)
10277        self.assertEqual(y, y2)
10278        self.assertEqual(y_expected, y2_expected)
10279
10280    @unittest.skipIf(not TEST_CUDA, "test requires CUDA")
10281    def test_gradcheck_default_device_placement_context(self):
10282        # During gradcheck with fast_mode=True, we create a random vector on the CPU device using a CPU generator.
10283        # This test ensures that this still works when the default device is set to something else by the user.
10284        with torch.device("cuda"):
10285            x = torch.randn(3, dtype=torch.double, requires_grad=True)
10286
10287            def func(inp):
10288                return inp**2.0
10289
10290            self.assertTrue(gradcheck(func, x, fast_mode=True))
10291
10292
10293def index_perm_variable(shape, max_indices):
10294    if not isinstance(shape, tuple):
10295        shape = (shape,)
10296
10297    index = torch.randperm(max_indices).narrow(0, 0, reduce(mul, shape)).view(shape)
10298    return index
10299
10300
10301def bernoulli_scalar():
10302    return torch.tensor(0, dtype=torch.uint8).bernoulli_()
10303
10304
10305class TestAutogradForwardModeBatchedGrad(TestCase):
10306    def test_out_of_place_basic(self):
10307        a = torch.rand(4, 4, dtype=torch.double, requires_grad=True)
10308        b = torch.rand(4, 4, dtype=torch.double, requires_grad=True)
10309        self.assertTrue(
10310            gradcheck(
10311                torch.sin,
10312                a,
10313                check_forward_ad=True,
10314                check_batched_grad=True,
10315                check_batched_forward_grad=True,
10316            )
10317        )
10318        self.assertTrue(
10319            gradcheck(
10320                torch.add,
10321                (a, b),
10322                check_forward_ad=True,
10323                check_batched_grad=True,
10324                check_batched_forward_grad=True,
10325            )
10326        )
10327
10328    def test_out_of_place_not_same_layout(self):
10329        input = torch.zeros([2, 2]).transpose(0, 1)
10330        tangent = torch.zeros([2, 2, 2])
10331
10332        def jvp(tangent):
10333            with fwAD.dual_level():
10334                x = fwAD.make_dual(input, tangent)
10335                return fwAD.unpack_dual(x)[1]
10336
10337        x_tangent = torch._vmap_internals._vmap(jvp, 0, 0)(tangent)
10338
10339        self.assertIsNot(x_tangent, tangent)
10340
10341    def test_inplace_on_view_same_layout(self):
10342        input = torch.zeros([2, 2])
10343        tangent = torch.zeros([2, 2, 2])
10344        base = torch.zeros([2, 2])
10345        view = base.view_as(base)
10346
10347        def jvp(tangent):
10348            with fwAD.dual_level():
10349                x = fwAD.make_dual(input, tangent)
10350                view.copy_(x)
10351                return (
10352                    fwAD.unpack_dual(x)[1],
10353                    fwAD.unpack_dual(view)[1],
10354                    fwAD.unpack_dual(view._base)[1],
10355                )
10356
10357        x_tangent, view_tangent, base_tangent = torch._vmap_internals._vmap(jvp, 0, 0)(
10358            tangent
10359        )
10360
10361        self.assertFalse(
10362            view_tangent._is_view()
10363        )  # Optimization to share the same tensor!
10364        self.assertIs(view_tangent, base_tangent)
10365        self.assertIs(x_tangent, tangent)
10366
10367    def test_inplace_on_view_not_same_layout(self):
10368        input = torch.zeros([2, 2])
10369        tangent = torch.zeros([2, 2, 2])
10370        view = torch.zeros([2, 2]).transpose(0, 1)
10371
10372        def jvp(tangent):
10373            with fwAD.dual_level():
10374                x = fwAD.make_dual(input, tangent)
10375                view.copy_(x)
10376                return (
10377                    fwAD.unpack_dual(x)[1],
10378                    fwAD.unpack_dual(view)[1],
10379                    fwAD.unpack_dual(view._base)[1],
10380                )
10381
10382        x_tangent, view_tangent, base_tangent = torch._vmap_internals._vmap(jvp, 0, 0)(
10383            tangent
10384        )
10385
10386        self.assertIs(view_tangent._base, base_tangent)
10387        self.assertIs(x_tangent, tangent)
10388        self.assertIsNot(view_tangent, tangent)
10389
10390    def test_metadata_check_for_storage_numel_skipped(self):
10391        # See: test_metadata_check_checks_storage_numel for the reverse of this test
10392        primal = torch.randn(5)[:4].detach()
10393        self.assertEqual(len(primal.storage()), 5)
10394        tangent = torch.randn(10, 4)
10395
10396        def jvp(tangent):
10397            with fwAD.dual_level():
10398                dual = fwAD.make_dual(primal, tangent)
10399                _, unpacked_tangent = fwAD.unpack_dual(dual)
10400
10401                # No copy is made
10402                self.assertIs(tangent, unpacked_tangent)
10403
10404                # as_strided raises
10405                with self.assertRaisesRegex(
10406                    RuntimeError, "can access memory outside of `tensor`"
10407                ):
10408                    dual.as_strided((5,), (1,), 0)
10409            return unpacked_tangent
10410
10411        torch._vmap_internals._vmap(jvp, 0, 0)(tangent)
10412
10413
10414class TestAutogradForwardMode(TestCase):
10415    def tearDown(self):
10416        # Ensure that a failing test won't make others fail
10417        while fwAD._current_level >= 0:
10418            fwAD.exit_dual_level()
10419
10420        super().tearDown()
10421
10422    def test_forward_level_cleanup(self):
10423        def get_tensor_and_weak_ref():
10424            # Create a new Tensor and weak reference
10425            t = torch.rand(2, requires_grad=True)
10426            return t, torch._C._WeakTensorRef(t)
10427
10428        # Sanity check that the helper function works as expected
10429        t, t_ref = get_tensor_and_weak_ref()
10430        self.assertFalse(t_ref.expired())
10431
10432        del t
10433        self.assertTrue(t_ref.expired())
10434
10435        # Main test code
10436        foo = torch.rand(2)
10437
10438        with fwAD.dual_level():
10439            tangent, tangent_ref = get_tensor_and_weak_ref()
10440            self.assertFalse(tangent_ref.expired())
10441
10442            dual = fwAD.make_dual(foo, tangent)
10443            self.assertFalse(tangent_ref.expired())
10444
10445            # Make sure that the tangent we provided has been re-used as is
10446            self.assertTrue(fwAD.unpack_dual(dual)[1] is tangent)
10447
10448            # Make sure that dual is keeping the tangent alive
10449            del tangent
10450            self.assertFalse(tangent_ref.expired())
10451
10452            # Make sure that the dual level does not keep the c++
10453            # version of the tangent alive
10454            del dual
10455            self.assertTrue(tangent_ref.expired())
10456
10457    def test_size_check(self):
10458        foo = torch.rand(2)
10459        tangent = torch.rand(3)
10460
10461        with fwAD.dual_level():
10462            with self.assertRaisesRegex(
10463                RuntimeError,
10464                "Trying to set a forward gradient that has a different size",
10465            ):
10466                dual = fwAD.make_dual(foo, tangent)
10467
10468            dual = fwAD.make_dual(foo, tangent[1:])
10469
10470    def test_metadata_check_checks_storage_numel(self):
10471        primal = torch.randn(5)[:4].detach()
10472        self.assertEqual(len(primal.storage()), 5)
10473        tangent = torch.randn(4)
10474
10475        with fwAD.dual_level():
10476            dual = fwAD.make_dual(primal, tangent)
10477            _, unpacked_tangent = fwAD.unpack_dual(dual)
10478
10479            # # Verify that mutating unpacked tangent does not affect the original tangent
10480            tangent_clone = tangent.clone()
10481            unpacked_tangent *= 2
10482            self.assertTrue(torch.allclose(tangent_clone, tangent))
10483
10484            # as_strided runs without error
10485            dual.as_strided((5,), (1,), 0)
10486
10487    def test_metadata_check_checks_ignores_size_zero(self):
10488        a = torch.ones(0).as_strided((0, 1), (1, 1), 0)
10489        b = torch.ones(0).as_strided((0, 1), (1, 0), 0)
10490
10491        with fwAD.dual_level():
10492            dual = fwAD.make_dual(a, b)
10493            torch.diagonal(dual, offset=0)
10494
10495        input = torch.rand([0, 1], dtype=torch.complex128, requires_grad=True)
10496        func = partial(torch.diagonal, offset=0)
10497        torch.autograd.gradcheck(func, (input,), check_forward_ad=True)
10498
10499    def test_metadata_check_when_primal_has_conj_bit(self):
10500        # Make sure the _has_same_storage_numel is a fallthrough, so that
10501        # conj bit does not materialize. If it materializes it would
10502        # cause the layout check to fail for views that do not index the
10503        # the entire storage.
10504        a = torch.randn(2, 2, dtype=torch.cdouble).conj()
10505        b = torch.rand_like(a)
10506
10507        self.assertTrue(torch.is_conj(a))
10508        self.assertEqual(len(a.storage()), len(b.storage()))
10509
10510        with fwAD.dual_level():
10511            dual = fwAD.make_dual(a, b)
10512            dual[1:]
10513
10514    def test_metadata_check_when_primal_has_neg_bit(self):
10515        # Make sure the _has_same_storage_numel is a fallthrough, so that
10516        # conj bit does not materialize. If it materializes it would
10517        # cause the layout check to fail for views that do not index the
10518        # the entire storage.
10519        a = torch.randn(2, 2, dtype=torch.cdouble).conj().imag
10520        b = torch.randn(2, 2, dtype=torch.cdouble).imag
10521
10522        self.assertTrue(torch.is_neg(a))
10523        self.assertEqual(len(a.storage()), len(b.storage()))
10524
10525        with fwAD.dual_level():
10526            dual = fwAD.make_dual(a, b)
10527            dual[1:]
10528
10529    def test_metadata_check_check_conj(self):
10530        keys = {
10531            "NEITHER": lambda x: x,
10532            "CONJ": lambda x: x.conj(),
10533            "NEG": lambda x: x._neg_view(),
10534        }
10535
10536        for primal_key, tangent_key in product(keys, keys):
10537            x = keys[primal_key](torch.randn(2, 3, 4, dtype=torch.cdouble))
10538            t = keys[tangent_key](torch.randn(2, 3, 4, dtype=torch.cdouble))
10539
10540            if primal_key == tangent_key:
10541                with fwAD.dual_level():
10542                    dual = fwAD.make_dual(x, t)
10543                    self.assertTrue(fwAD.unpack_dual(dual).tangent is t)
10544                    torch.real(dual)
10545                    torch.imag(dual)
10546            else:
10547                with fwAD.dual_level():
10548                    dual = fwAD.make_dual(x, t)
10549                    self.assertTrue(fwAD.unpack_dual(dual).tangent is not t)
10550                    torch.real(dual)
10551                    torch.imag(dual)
10552
10553    def test_metadata_check_ignore_storage_offset_for_zero_numel_tensor(self):
10554        # See https://github.com/pytorch/pytorch/issues/80507
10555        a = torch.tensor([1.0]).as_strided((0,), (1,), 1)
10556        b = torch.tensor([1.0]).as_strided((0,), (1,), 2)
10557
10558        with fwAD.dual_level():
10559            dual_input = fwAD.make_dual(a, b)
10560            # Check that no copy is made
10561            self.assertIs(fwAD.unpack_dual(dual_input).tangent, b)
10562
10563        a = torch.tensor([1.0]).as_strided((1,), (2,), 0)
10564        b = torch.tensor([1.0]).as_strided((1,), (1,), 0)
10565
10566        with fwAD.dual_level():
10567            dual_input = fwAD.make_dual(a, b)
10568            dual_input[1:]
10569
10570    # The following test functions want to ensure all the following behaviors:
10571    #   - Ensure that default level system in the python binding works
10572    #   - Ensure that only level 0 exists and nesting is properly disabled
10573    #   - Ensure that printing works fine
10574    #   - Ensure that basic packing/unpacking works
10575    #   - Ensure that advanced packing/unpacking works
10576    #     - For memory / version counter share
10577    #     - For backward AD (regular ops)
10578    #   - Ensure that view + inplace for both modes work fine
10579    #   - Ensure we do proper cleanup on exit of a level
10580
10581    def test_default_level(self):
10582        foo = torch.rand(2)
10583        bar = torch.rand(2)
10584
10585        with fwAD.dual_level():
10586            baz = fwAD.make_dual(foo, bar)
10587            baz_primal, baz_tangent = fwAD.unpack_dual(baz)
10588        self.assertEqual(baz_primal, foo)
10589        # We don't actually need to enforce that these two are the exact same python
10590        # object, feel free to relax in the future
10591        self.assertIs(baz_tangent, bar)
10592
10593        baz_primal, baz_tangent = fwAD.unpack_dual(baz)
10594        self.assertEqual(baz_primal, foo)
10595        self.assertEqual(baz_tangent, None)
10596
10597    def test_fwd_grad_enabled(self):
10598        # Tests some private helper functions to enable/disable fwd grad mode
10599        enabled = fwAD._is_fwd_grad_enabled()
10600        self.assertTrue(enabled)
10601
10602        try:
10603            torch._C._set_fwd_grad_enabled(False)
10604            enabled = fwAD._is_fwd_grad_enabled()
10605            self.assertFalse(enabled)
10606        finally:
10607            torch._C._set_fwd_grad_enabled(True)
10608
10609        enabled = fwAD._is_fwd_grad_enabled()
10610        self.assertTrue(enabled)
10611
10612    def test_set_fwd_grad_enabled(self):
10613        # Tests a private helper function
10614        try:
10615            torch._C._set_fwd_grad_enabled(False)
10616            enabled = fwAD._is_fwd_grad_enabled()
10617            self.assertFalse(enabled)
10618
10619            with fwAD._set_fwd_grad_enabled(True):
10620                enabled = fwAD._is_fwd_grad_enabled()
10621                self.assertTrue(enabled)
10622
10623            enabled = fwAD._is_fwd_grad_enabled()
10624            self.assertFalse(enabled)
10625        finally:
10626            torch._C._set_fwd_grad_enabled(True)
10627
10628    def test_nested_level(self):
10629        with fwAD.dual_level() as level:
10630            # For now only level 0 exists
10631            self.assertEqual(level, 0)
10632
10633        with fwAD.dual_level():
10634            with self.assertRaisesRegex(
10635                RuntimeError, "Nested forward mode AD is not supported at the moment"
10636            ):
10637                nest_level = fwAD.enter_dual_level()
10638
10639    def test_set_fw_grad_having_own_fw_grad_at_same_level(self):
10640        foo = torch.rand(2)
10641        bar = torch.rand(2)
10642        baz = torch.rand(2)
10643
10644        with fwAD.dual_level():
10645            dual = fwAD.make_dual(foo, bar)
10646            with self.assertRaisesRegex(
10647                RuntimeError, "has a forward gradient at the same level"
10648            ):
10649                fwAD.make_dual(baz, dual)
10650
10651    def test_codegen_ignores_undefined_outputs(self):
10652        # This test checks that codegen silently ignores undefined outputs
10653        # Below, grad_input is specified as False in grad_output_mask, so
10654        # convolution backward will return a undefined tensor in that position.
10655        # Note that for this test to work we need to make sure either grad_output
10656        # or weight to be a dual tensor, so grad_input requires forward grad
10657        weight = torch.randn(6, 1, 30, 30)
10658        inp = torch.rand((1, 1, 32, 32))
10659        out = torch.nn.functional.conv2d(inp, weight)
10660        grad_out = torch.ones_like(out)
10661
10662        with fwAD.dual_level():
10663            dual_weight = fwAD.make_dual(weight, torch.ones_like(weight))
10664            grad_input, _, _ = torch.ops.aten.convolution_backward(
10665                grad_out,
10666                inp,
10667                dual_weight,
10668                (0,),
10669                (1, 1),
10670                (0, 0),
10671                (1, 1),
10672                False,
10673                (0, 0),
10674                1,
10675                (False, True, False),
10676            )
10677        self.assertIsNone(grad_input)
10678
10679    def test_make_dual_inference_tensor_in_inference_mode(self):
10680        with torch.inference_mode():
10681            foo = torch.rand(2)
10682            bar = torch.rand(2)
10683            foo_copy = foo.clone()
10684
10685            with fwAD.dual_level():
10686                dual = fwAD.make_dual(foo, bar)
10687                self.assertFalse(dual._is_view())
10688
10689                dual += 1
10690                self.assertFalse(torch.allclose(foo, foo_copy))
10691
10692    def test_make_dual_torch_dispatch(self):
10693        counter = [0]
10694
10695        class MySubclass(torch.Tensor):
10696            def __new__(cls, data=None):
10697                return torch.Tensor._make_subclass(cls, data)
10698
10699            @classmethod
10700            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
10701                if func.overloadpacket == torch.ops.aten.alias:
10702                    counter[0] += 1
10703
10704                    # Make sure we can re-enable autograd here
10705                    with torch.overrides.enable_reentrant_dispatch():
10706                        foo = torch.rand(1, requires_grad=True)
10707                        self.assertIsNotNone(foo.exp().grad_fn)
10708
10709                with no_dispatch():
10710                    return func(*args, **kwargs)
10711
10712        a = torch.tensor(1.0)
10713        s = MySubclass(a)
10714
10715        with fwAD.dual_level():
10716            # Only the primal has "alias" called on it
10717            fwAD.make_dual(s, torch.rand_like(s))
10718            self.assertEqual(counter[0], 1)
10719            fwAD.make_dual(torch.rand_like(s), s)
10720            self.assertEqual(counter[0], 1)
10721
10722    def test_make_dual_forbid_integral_dtype(self):
10723        primal_f = torch.ones(2, 2, dtype=torch.float)
10724        primal_l = torch.ones(2, 2, dtype=torch.long)
10725
10726        tangent_f = torch.ones(2, 2, dtype=torch.float)
10727        tangent_l = torch.ones(2, 2, dtype=torch.long)
10728
10729        with fwAD.dual_level():
10730            # Float Primal and Long Tangent
10731            with self.assertRaisesRegex(
10732                ValueError, "Expected tangent to be floating point or complex"
10733            ):
10734                fwAD.make_dual(primal_f, tangent_l)
10735
10736            # Long Primal and Long Tangent
10737            with self.assertRaisesRegex(
10738                ValueError, "Expected primal to be floating point or complex"
10739            ):
10740                fwAD.make_dual(primal_l, tangent_l)
10741
10742            # Long Primal and Float Tangent
10743            with self.assertRaisesRegex(
10744                ValueError, "Expected primal to be floating point or complex"
10745            ):
10746                fwAD.make_dual(primal_l, tangent_f)
10747
10748    def test_print(self):
10749        with fwAD.dual_level() as level:
10750            a = torch.rand(3)
10751            self.assertFalse("tangent=" in str(a))
10752
10753            b = fwAD.make_dual(a, torch.rand(3))
10754            self.assertFalse("tangent=" in str(a))
10755            self.assertTrue("tangent=" in str(b))
10756
10757            b_primal, b_tangent = fwAD.unpack_dual(b)
10758            self.assertFalse("tangent=" in str(b_primal))
10759            self.assertFalse("tangent=" in str(b_tangent))
10760
10761    def test_basic_packing_unpacking(self):
10762        foo = torch.rand(2)
10763        bar = torch.rand(2)
10764
10765        with fwAD.dual_level():
10766            baz = fwAD.make_dual(foo, bar)
10767            baz_primal, baz_tangent = fwAD.unpack_dual(baz)
10768            self.assertEqual(baz_primal, foo)
10769            self.assertIs(baz_tangent, bar)
10770
10771            # Check unpacked dual is returned as a named tuple
10772            # NB: Every invocation of unpack_dual returns a new tensor view
10773            self.assertIsNot(baz_primal, fwAD.unpack_dual(baz).primal)
10774            self.assertEqual(baz_primal, fwAD.unpack_dual(baz).primal)
10775            self.assertIs(baz_tangent, fwAD.unpack_dual(baz).tangent)
10776
10777            # Check that packing/unpacking did not change the input
10778            foo_primal, foo_tangent = fwAD.unpack_dual(foo)
10779            self.assertEqual(foo_primal, foo)
10780            self.assertIsNone(foo_tangent)
10781
10782    def test_advanced_packing_unpacking(self):
10783        foo = torch.rand(2)
10784        bar = torch.ones(2)
10785
10786        # Memory and version counter check
10787        with fwAD.dual_level():
10788            dual = fwAD.make_dual(foo, bar)
10789
10790            # Ensure that they are sharing memory and version counter
10791            self.assertEqual(dual.storage().data_ptr(), foo.storage().data_ptr())
10792
10793            # Ensure we properly share the version counter
10794            self.assertEqual(foo._version, dual._version)
10795            foo.add_(1)
10796            self.assertEqual(foo._version, dual._version)
10797
10798            # Unpacking should only create aliases as well
10799            dual_primal, dual_tangent = fwAD.unpack_dual(dual)
10800            self.assertEqual(dual_primal.storage().data_ptr(), foo.storage().data_ptr())
10801            self.assertEqual(
10802                dual_tangent.storage().data_ptr(), bar.storage().data_ptr()
10803            )
10804            # And the tangent is actually re-used as-is so it is still the same Tensor
10805            self.assertIs(dual_tangent, bar)
10806
10807            # Ensure we properly share the version counter
10808            self.assertEqual(foo._version, dual_primal._version)
10809            foo.add_(1)
10810            self.assertEqual(foo._version, dual_primal._version)
10811            self.assertEqual(bar._version, dual_tangent._version)
10812            bar.add_(1)
10813            self.assertEqual(bar._version, dual_tangent._version)
10814
10815        # backward mode check
10816        with fwAD.dual_level():
10817            foo.requires_grad_()
10818            bar.requires_grad_()
10819
10820            # Check that backward gradients properly propagates through packing/unpacking
10821            dual = fwAD.make_dual(foo, bar)
10822            p, t = fwAD.unpack_dual(dual)
10823
10824            gfoo, gbar = torch.autograd.grad(
10825                p.sum(), (foo, bar), retain_graph=True, allow_unused=True
10826            )
10827            self.assertEqual(gfoo, torch.ones_like(foo))
10828            self.assertIsNone(gbar)
10829
10830            gfoo, gbar = torch.autograd.grad(
10831                t.sum(), (foo, bar), retain_graph=True, allow_unused=True
10832            )
10833            self.assertIsNone(gfoo)
10834            self.assertEqual(gbar, torch.ones_like(bar))
10835
10836            # Check that forward gradients are impacted by detach()
10837            detached_dual = dual.detach()
10838            out = detached_dual * 2
10839            p, t = fwAD.unpack_dual(out)
10840            self.assertFalse(p.requires_grad)
10841            self.assertEqual(p, foo * 2)
10842            self.assertIsNone(t)
10843
10844            # Check that forward gradients are not impacted by no_grad
10845            with torch.no_grad():
10846                out = dual * 3
10847            p, t = fwAD.unpack_dual(out)
10848            self.assertFalse(p.requires_grad)
10849            self.assertFalse(t.requires_grad)
10850            self.assertEqual(p, foo * 3)
10851            self.assertEqual(t, bar * 3)
10852
10853            # Check that forward gradients are not impacted by inplace detach
10854            dual = dual.clone()
10855            dual.detach_()
10856            out = dual * 2
10857            p, t = fwAD.unpack_dual(out)
10858            self.assertFalse(p.requires_grad)
10859            self.assertEqual(p, foo * 2)
10860            self.assertIsNone(t)
10861
10862    def test_view_inplace_non_differentiable_views(self):
10863        original_foo = torch.rand(2, dtype=torch.double)
10864        original_bar = torch.ones(2, dtype=torch.double)
10865
10866        # Do clones to be able to compare the values updated inplace
10867        # with the original content of these Tensors
10868        foo = original_foo.clone()
10869        bar = original_bar.clone()
10870
10871        with fwAD.dual_level():
10872            # Note that in this test, we use "update" to mean computing the right tangent for the dual
10873            # All the inplace operations here are expected to update the primal value of the Tensors but
10874            # not always their tangents.
10875            # Also all mentions of "non differentiable view" here means non forward differentiable view
10876            # unless specified otherwise.
10877            # See note [Forward Grad View/inplace] for more details on how these views work.
10878
10879            # Check that inplace ops do not update non-differentiable views
10880            # Non differentiable view
10881            dual = fwAD.make_dual(foo, bar)
10882            dual *= 2
10883            # Check that non differentiable view's tangent was not updated
10884            self.assertIsNone(fwAD.unpack_dual(foo)[1])
10885            # Check that the computed result is correct
10886            self.assertEqual(bar, original_bar * 2)
10887            self.assertEqual(fwAD.unpack_dual(dual)[1], original_bar * 2)
10888            self.assertEqual(foo, original_foo * 2)
10889            self.assertEqual(fwAD.unpack_dual(dual)[0], original_foo * 2)
10890            # Other non differentiable view
10891            dual_primal, dual_tangent = fwAD.unpack_dual(dual)
10892            self.assertIsNone(fwAD.unpack_dual(dual_primal)[1])
10893            self.assertIsNone(fwAD.unpack_dual(dual_tangent)[1])
10894            dual_primal *= 2
10895            # Ensure dual's tangent did not change
10896            self.assertEqual(fwAD.unpack_dual(dual)[0], original_foo * 4)
10897            self.assertEqual(fwAD.unpack_dual(dual)[1], original_bar * 2)
10898            dual_tangent *= 2
10899            # Ensure dual's primal did not change
10900            self.assertEqual(fwAD.unpack_dual(dual)[0], original_foo * 4)
10901            self.assertEqual(fwAD.unpack_dual(dual)[1], original_bar * 4)
10902
10903    def test_view_inplace_differentiable_views(self):
10904        original_foo = torch.rand(2)
10905        original_bar = torch.ones(2)
10906
10907        # Do clones to be able to compare the values updated inplace
10908        # with the original content of these Tensors
10909        foo = original_foo.clone()
10910        bar = original_bar.clone()
10911
10912        with fwAD.dual_level():
10913            # Check that inplace ops do update differentiable view but stop at non differentiable ones
10914            # A non differentiable view
10915            dual = fwAD.make_dual(foo, bar)
10916            # A differentiable view
10917            view = dual.narrow(0, 0, 1)
10918            view *= 2
10919            # Check that non differentiable view was not updated
10920            self.assertIsNone(fwAD.unpack_dual(foo)[1])
10921            # Check that differentiable view was updated
10922            self.assertEqual(fwAD.unpack_dual(dual)[1], torch.tensor([2.0, 1.0]))
10923            self.assertEqual(fwAD.unpack_dual(view)[1], torch.tensor([2.0]))
10924
10925            # Check that we track differentiable view even for Tensors that are not dual
10926            baz = torch.rand(2)
10927            baz += dual
10928            self.assertEqual(fwAD.unpack_dual(baz)[1], fwAD.unpack_dual(dual)[1])
10929            # Updates on view should as well
10930            baz = torch.rand(2)
10931            baz[0] = dual[0]
10932            self.assertEqual(fwAD.unpack_dual(baz)[1][0], fwAD.unpack_dual(dual)[1][0])
10933            # Unused values get a gradient of 0
10934            self.assertEqual(fwAD.unpack_dual(baz)[1][1], 0.0)
10935
10936            # Check that forward non-differentiable views do prevent gradient update
10937            baz = torch.rand(2)
10938            view = baz.detach()
10939            view += dual
10940            self.assertIsNone(fwAD.unpack_dual(baz)[1])
10941
10942    def test_view_inplace_always_creates_a_view(self):
10943        # See https://github.com/pytorch/pytorch/issues/67800
10944        # The codepath may depend on the op. At the time writing, when self is not a dual tensor
10945        # the resulting forward grad for self for...
10946        # - add_ has the same layout as self
10947        # - mul_ has the same layout as other
10948        # This is kind of fragile because the above depends on how the forward grad expression
10949        # is written. For add and mul at least, the output inherits the layout of LHS.
10950        # We want to handle at least these two cases.
10951        inplace_binary_ops = (  # Add more to this list?
10952            lambda x, y: x.add_(y),
10953            lambda x, y: x.mul_(y),
10954            lambda x, y: x.copy_(y),
10955        )
10956
10957        for inplace_binary_op in inplace_binary_ops:
10958            base = torch.randn(2, 2)
10959            view = base.transpose(0, 1)
10960
10961            primal = torch.randn(2, 2)
10962            tangent = torch.randn(2, 2)
10963
10964            with fwAD.dual_level():
10965                dual = fwAD.make_dual(primal, tangent)
10966                inplace_binary_op(view, dual)
10967
10968                # Verify that a view relationship is created for both the primal and tangent
10969                p, t = fwAD.unpack_dual(base)
10970                p_clone = p.clone()
10971                t_clone = t.clone()
10972                view *= 2
10973                p, t = fwAD.unpack_dual(base)
10974
10975                self.assertTrue(torch.allclose(p_clone * 2, p))
10976                self.assertTrue(torch.allclose(t_clone * 2, t))
10977
10978    def test_grad_cleanup(self):
10979        foo = torch.rand(2)
10980        bar = torch.rand(2)
10981        baz = torch.rand(2)
10982
10983        with fwAD.dual_level():
10984            dual = fwAD.make_dual(foo, bar)
10985            self.assertIsNone(fwAD.unpack_dual(foo)[1])
10986            self.assertIs(fwAD.unpack_dual(dual)[1], bar)
10987
10988        self.assertIsNone(fwAD.unpack_dual(dual)[1])
10989
10990        with fwAD.dual_level():
10991            self.assertIsNone(fwAD.unpack_dual(foo)[1])
10992            new_dual = fwAD.make_dual(foo, baz)
10993
10994            dual_primal, dual_tangent = fwAD.unpack_dual(dual)
10995            new_dual_primal, new_dual_tangent = fwAD.unpack_dual(new_dual)
10996            self.assertEqual(dual_primal, new_dual_primal)
10997            self.assertIsNone(dual_tangent)
10998            self.assertEqual(new_dual_tangent, baz)
10999
11000    def test_detach_view_tracking(self):
11001        # Default detach is both forward and backward non-differentiable
11002        foo = torch.rand(2)
11003        foo_weak = torch._C._WeakTensorRef(foo)
11004
11005        out = foo.detach()
11006
11007        del foo
11008        self.assertTrue(foo_weak.expired())
11009
11010    def test_out_variant(self):
11011        with fwAD.dual_level():
11012            foo = fwAD.make_dual(torch.rand(2), torch.rand(2))
11013            bar = torch.rand(2)
11014
11015            with self.assertRaisesRegex(RuntimeError, "out= function"):
11016                torch.add(bar, bar, out=foo)
11017
11018            with self.assertRaisesRegex(RuntimeError, "out= function"):
11019                torch.add(foo, bar, out=bar)
11020
11021    def test_non_differentiable(self):
11022        with fwAD.dual_level():
11023            foo = fwAD.make_dual(torch.rand(2), torch.rand(2))
11024            bar = torch.rand(2)
11025
11026            # No differentiable outputs, shouldn't error
11027            eq = foo == bar
11028
11029            # Inplace
11030            foo.eq_(bar)
11031
11032    def test_create_new_zeros_with_same_meta(self):
11033        new_zeroes_fn = torch.ops.aten._new_zeros_with_same_feature_meta
11034
11035        def check(a, b):
11036            def assert_same_meta(t, target):
11037                for num_bdim in range(t.dim()):
11038                    result = new_zeroes_fn(t, target, self_num_batch_dims=num_bdim)
11039
11040                    self.assertEqual(result.dim(), target.dim() + num_bdim)
11041
11042                    # Check size/strides match for feature dims only
11043                    for i in range(num_bdim, result.dim()):
11044                        self.assertEqual(result.size()[i], target.size()[i - num_bdim])
11045                        self.assertEqual(
11046                            result.stride()[i], target.stride()[i - num_bdim]
11047                        )
11048
11049                    # Check that we generate strides reasonably
11050                    if target.is_contiguous():
11051                        self.assertTrue(result.is_contiguous())
11052
11053                    self.assertEqual(result.storage_offset(), target.storage_offset())
11054
11055                    prod_of_t_bdims = reduce(operator.mul, t.size()[:num_bdim], 1)
11056                    self.assertEqual(
11057                        len(result.storage()), len(target.storage()) * prod_of_t_bdims
11058                    )
11059
11060                    # TensorOptions is same
11061                    self.assertEqual(result.dtype, target.dtype)
11062
11063            assert_same_meta(a, b)
11064            assert_same_meta(b, a)
11065
11066        a = torch.randn(5, dtype=torch.float)
11067        b = torch.randn(2, 3, 4, dtype=torch.double)
11068        check(a, b)
11069
11070        # non-contiguous case
11071        a = torch.randn(2, 3, 4).transpose(0, 1).contiguous().transpose(0, 1)
11072        b = torch.randn(2, 3, 4)
11073        check(a, b)
11074
11075        a = torch.randn(5).narrow(0, 1, 2)
11076        b = torch.randn(2)
11077        check(a, b)
11078
11079        # tensor is not a view, but still does not index entirety of storage
11080        a = torch.randn(5).resize_(4)
11081        b = torch.randn(4)
11082        check(a, b)
11083
11084        # Zero-numel tensors
11085        a = torch.randn(1, 0, 2)
11086        b = torch.randn(1, 2)
11087        check(a, b)
11088
11089        # Scalar tensor
11090        a = torch.tensor(1.0)
11091        b = torch.randn(1, 2)
11092        check(a, b)
11093
11094    def test_backward_graph_destruction(self):
11095        def fn():
11096            a = torch.rand(10, requires_grad=True)
11097
11098            da = fwAD.make_dual(torch.rand_like(a), a)
11099
11100            # Create an object with a c++ cycle as:
11101            # db -> AutogradMeta -> ForwardGrad -> db's grad
11102            # db's grad -> AutogradMeta -> MulBackward
11103            # MulBackward -> SavedVariable -> db
11104            db = da.exp()
11105
11106        with fwAD.dual_level():
11107            fn()
11108        # This test make sure that we don't deadlock on exit of this
11109        # context manager. If you do, there is something wrong with the
11110        # locking of the forward ad level most likely
11111
11112
11113# Generic device type autograd tests.
11114class TestAutogradDeviceType(TestCase):
11115    def test_min_max_median_backprops_to_all_values(self, device):
11116        for f in [torch.min, torch.max, torch.median, torch.nanmedian]:
11117            x1 = torch.tensor(
11118                [1.0, 0.0, 1.0, 0.0, 1.0, 0.0], device=device, requires_grad=True
11119            )
11120            x2 = torch.tensor(
11121                [float("nan"), float("nan"), float("nan")], requires_grad=True
11122            )
11123            for x in [x1, x2]:
11124                y = f(x)
11125                y.backward()
11126                self.assertEqual(x.grad.sum(), 1.0)
11127                self.assertEqual((x.grad == 1 / 3).sum(), 3)
11128
11129    def test_scatter_index_reduce_amin_amax_backprops_to_all_values(self, device):
11130        # tests that gradients are evenly distributed when there are multiple max/min values
11131        # tested here instead of adding a SampleInput as the backward for this case is non-differentiable for gradgrad
11132        # as is the case for test_min_max_median_backprops_to_all_values above
11133        fns = (torch.scatter_reduce, torch.index_reduce)
11134        reduces = ("amin", "amax")
11135        for fn, reduction in product(fns, reduces):
11136            input = torch.randn(
11137                (2, 3), device=device, dtype=torch.float64, requires_grad=True
11138            )
11139            src = input.clone().detach_().requires_grad_(True)
11140            idx = torch.arange(2).to(dtype=torch.long, device=device)
11141            if fn == torch.scatter_reduce:
11142                idx = idx.unsqueeze(-1).expand((2, 3))
11143
11144            gradcheck(fn, (input, 0, idx, src, reduction), check_batched_grad=False)
11145
11146    def test_scatter_index_reduce_prod_gradgrad_error(self, device):
11147        # test that double backward raises an error for the case where 2 zeros in src
11148        # are scattered to the same position in self
11149        input = torch.tensor(
11150            [1.0], device=device, dtype=torch.float64, requires_grad=True
11151        )
11152        src = torch.tensor(
11153            [0.0, 0.0], device=device, dtype=torch.float64, requires_grad=True
11154        )
11155        idx = torch.tensor([0, 0], device=device, dtype=torch.long)
11156
11157        for fn in (torch.scatter_reduce, torch.index_reduce):
11158            # check that this case passes on gradcheck
11159            gradcheck(fn, (input, 0, idx, src, "prod"), check_batched_grad=False)
11160            with self.assertRaisesRegex(
11161                RuntimeError, "Double backward is unsupported for"
11162            ):
11163                gradgradcheck(fn, (input, 0, idx, src, "prod"))
11164
11165    @skipIfMps  # the test doesn't work on MPS as double types are not supported
11166    def test_parameter_resize(self, device):
11167        asd = torch.nn.Parameter(torch.ones(16, dtype=torch.double, device=device))
11168
11169        for i in range(2):
11170            with torch.no_grad():
11171                asd.set_(asd[1:])
11172                asd.grad = None
11173
11174            m = torch.cat((asd, asd))
11175            m.sum().backward()
11176
11177    @skipIfMps  # the test doesn't work on MPS as double types are not supported
11178    @dtypes(torch.double, torch.cdouble)
11179    def test_sparse_ctor_getter_backward(self, device, dtype):
11180        # See NOTE [ Sparse: autograd and API ] on the expected behavior of this test
11181        def _test(size, sparse_dim, nnz, device):
11182            v_size = [nnz] + list(size[sparse_dim:])
11183            i = torch.rand(sparse_dim, nnz)
11184            i.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(i))
11185            i = i.to(torch.long)
11186
11187            inp = torch.randn(
11188                v_size, dtype=torch.double, device=device, requires_grad=True
11189            )
11190            other = self.genSparseTensor(
11191                size, sparse_dim, nnz, is_uncoalesced=True, device=device, dtype=dtype
11192            )[0]
11193
11194            def fn(v):
11195                x = torch.sparse_coo_tensor(i, v, size, dtype=dtype, device=device)
11196                y = (x + other).coalesce()
11197                yv = y.values()
11198                new_v = yv.tanh()
11199                z = torch.sparse_coo_tensor(y.indices(), new_v, y.size())
11200                return z.coalesce().values()
11201
11202            gradcheck(fn, (inp,), check_batched_grad=False)
11203            # FIXME: make gradgradcheck work.
11204            # gradgradcheck(fn, (inp,), check_batched_grad=False)
11205
11206            # assert that _values is non-differentiable
11207            with self.assertRaisesRegex(RuntimeError, "does not have a grad_fn"):
11208                other.detach().requires_grad_()._values().backward(
11209                    torch.ones_like(other._values())
11210                )
11211
11212        for empty_i, empty_v, empty_nnz in product([True, False], repeat=3):
11213            sparse_size = [] if empty_i else [2, 1]
11214            dense_size = [1, 0, 2] if empty_v else [1, 2]
11215            nnz = 0 if empty_nnz else 5
11216            _test(sparse_size + dense_size, len(sparse_size), nnz, device)
11217
11218    @skipMeta
11219    @skipIfMps
11220    @dtypes(torch.double, torch.cdouble)
11221    def test_sparse_backward(self, device, dtype):
11222        class FixedGradientFunction(Function):
11223            @staticmethod
11224            def forward(ctx, x, grad_x):
11225                ctx.save_for_backward(grad_x)
11226                return x
11227
11228            @staticmethod
11229            def backward(ctx, grad_x):
11230                (saved_grad_x,) = ctx.saved_tensors
11231                return saved_grad_x, None
11232
11233        size = torch.Size([6, 3, 2])
11234        i1 = torch.tensor([[0, 3, 4], [0, 2, 2]], dtype=torch.long)
11235        v1 = make_tensor([3, 2], dtype=dtype, device=device)
11236        sparse_grad1 = torch.sparse_coo_tensor(i1, v1, size, dtype=dtype, device=device)
11237        i2 = torch.tensor([[0, 1, 3, 4], [0, 1, 2, 2]], dtype=torch.long)
11238        v2 = make_tensor([4, 2], dtype=dtype, device=device)
11239        sparse_grad2 = torch.sparse_coo_tensor(i2, v2, size, dtype=dtype, device=device)
11240        dense_grad = torch.rand(size, device=device, dtype=dtype)
11241        fn = FixedGradientFunction
11242
11243        # sparse first
11244        x = torch.randn(size, dtype=dtype, device=device, requires_grad=True)
11245        (
11246            fn.apply(x, sparse_grad1)
11247            + fn.apply(x, dense_grad)
11248            + fn.apply(x, sparse_grad2)
11249        ).sum().abs().backward()
11250        self.assertEqual(x.grad, dense_grad + sparse_grad1 + sparse_grad2)
11251        # dense first
11252        x = torch.randn(size, dtype=dtype, device=device, requires_grad=True)
11253        (
11254            fn.apply(x, dense_grad)
11255            + fn.apply(x, sparse_grad1)
11256            + fn.apply(x, sparse_grad2)
11257        ).sum().abs().backward()
11258        self.assertEqual(x.grad, dense_grad + sparse_grad1 + sparse_grad2)
11259        # sparse only
11260        x = torch.randn(size, dtype=dtype, device=device, requires_grad=True)
11261        (fn.apply(x, sparse_grad1) + fn.apply(x, sparse_grad2)).sum().abs().backward()
11262        self.assertEqual(x.grad, sparse_grad1 + sparse_grad2)
11263
11264    @skipIfMps
11265    def test_sparse_mask_autograd(self, device):
11266        tensor = torch.randn(3, requires_grad=True, device=device)
11267        mask = torch.ones(3, device=device)
11268        mask[1] = 0
11269        mask = mask.to_sparse()
11270        converted = tensor.sparse_mask(mask).to_dense()
11271        converted.sum().backward()
11272        self.assertEqual(tensor.grad, mask.to_dense())
11273
11274    @skipIfMps  # the test doesn't work on MPS as double types are not supported
11275    def test_pyscalar_conversions(self, device):
11276        def _test_pyscalar_conversions(t, integral_conv):
11277            # integral -> integral
11278            l = t(torch.zeros(1, 1, 1, dtype=torch.long))
11279            pyscalar = -12345
11280            l[0] = pyscalar
11281            self.assertEqual(integral_conv(l), pyscalar)
11282
11283            # floating point -> floating point
11284            f = Variable(t(torch.randn(1, 1, dtype=torch.double)))
11285            pyscalar = -12345.1
11286            f[0] = pyscalar
11287            self.assertEqual(float(f), pyscalar)
11288            f[0] = nan
11289            self.assertTrue(math.isnan(float(f)))
11290            f[0] = inf
11291            self.assertEqual(float(f), inf)
11292            f[0] = -inf
11293            self.assertEqual(float(f), -inf)
11294
11295            # integral -> floating point
11296            # check we can convert something that loses precision
11297            pyscalar = 1234567890123456789
11298            self.assertNotEqual(pyscalar, integral_conv(float(pyscalar)))
11299            l[0] = pyscalar
11300            self.assertEqual(float(l), float(pyscalar))
11301
11302            # floating point -> integral
11303            f[0] = nan
11304            self.assertRaises(ValueError, lambda: integral_conv(f[0]))
11305            f[0] = inf
11306            self.assertRaises(OverflowError, lambda: integral_conv(f[0]))
11307            f[0] = -inf
11308            self.assertRaises(OverflowError, lambda: integral_conv(f[0]))
11309            f[0] = sys.float_info.max
11310            self.assertEqual(integral_conv(f), sys.float_info.max)
11311
11312            # bool, nonzero
11313            def test_nonzero(tensor, value, expected):
11314                tensor[0] = value
11315                self.assertEqual(expected, bool(tensor))
11316                self.assertEqual(expected, True if tensor else False)
11317
11318            test_nonzero(l, 0, False)
11319            test_nonzero(l, -2, True)
11320            test_nonzero(f, 0.0, False)
11321            test_nonzero(f, sys.float_info.min, True)
11322            test_nonzero(f, nan, bool(nan))
11323            test_nonzero(f, inf, bool(inf))
11324            test_nonzero(f, -inf, bool(-inf))
11325
11326        _test_pyscalar_conversions(lambda x: x.to(device), lambda x: int(x))
11327
11328    @dtypesIfMPS(torch.float32)
11329    @dtypesIfCUDA(
11330        torch.half,
11331        torch.float,
11332        torch.double,
11333        torch.int8,
11334        torch.int16,
11335        torch.int32,
11336        torch.int64,
11337    )
11338    @dtypes(
11339        torch.float, torch.double, torch.int8, torch.int16, torch.int32, torch.int64
11340    )
11341    def test_set_requires_grad_only_for_floats(self, device, dtype):
11342        def f1():
11343            a = torch.ones(1, dtype=dtype, device=device)
11344            a.requires_grad_()
11345
11346        def f2():
11347            a = torch.ones(1, dtype=dtype, device=device)
11348            a.requires_grad = True
11349
11350        def f3():
11351            torch.ones(1, dtype=dtype, device=device, requires_grad=True)
11352
11353        a = torch.ones(1, dtype=dtype, device=device)
11354        a.requires_grad = False  # should always work
11355        a.requires_grad_(False)
11356
11357        for f in [f1, f2, f3]:
11358            if dtype.is_floating_point:
11359                f()
11360            else:
11361                with self.assertRaisesRegex(
11362                    RuntimeError,
11363                    "floating point",
11364                    msg=f"dt: {a.dtype} device: {a.device}",
11365                ):
11366                    f()
11367
11368    @onlyCUDA
11369    def test_advanced_indexing_backwards_large(self, device):
11370        # See https://github.com/pytorch/pytorch/issues/22843
11371        n = 1 << 16
11372        x = torch.rand(n, 1, device=device, requires_grad=True)
11373        a = x[:, [0]]
11374        a.sum().backward()
11375        self.assertEqual(x.grad, torch.ones(n, 1, device=device))
11376
11377    def test_advanced_indexing_backwards_memory_format(self, device):
11378        # See https://github.com/pytorch/pytorch/issues/36956
11379        shape = (2, 8, 1, 2)
11380        i = torch.randint(1, shape, device=device).contiguous(
11381            memory_format=torch.channels_last
11382        )
11383        x = torch.randn(shape, requires_grad=True, device=device)
11384        x[i].sum().backward()
11385
11386    def _test_reentrant_parent_error_on_cpu(self, device):
11387        t1 = torch.rand([3, 3], requires_grad=True)
11388        t2 = torch.rand([3, 3], device=device, requires_grad=True)
11389        t3 = torch.rand([3, 3], device=device, requires_grad=True)
11390
11391        # Parent graph cpu graph.
11392        t4 = t1 * t1
11393        t5 = TestAutograd.SimulateBackwardError.apply(t4)
11394
11395        # Child gpu graph (much longer than parent graph).
11396        prev = t2 * t2
11397        for i in range(10):
11398            prev = prev * t2
11399        reentrant_root = prev
11400
11401        class ReentrantFunc(Function):
11402            @staticmethod
11403            def forward(ctx, inp):
11404                return inp.clone()
11405
11406            @staticmethod
11407            def backward(ctx, grad):
11408                # Reentrant backward in child will take much longer.
11409                reentrant_root.backward()
11410                return grad
11411
11412        # Parent gpu graph.
11413        t6 = ReentrantFunc.apply(t3)
11414        t7 = t6 * t6
11415
11416        # Parent graph will error out first, while child graph will continue executing.
11417        with self.assertRaisesRegex(Exception, "Simulate error"):
11418            torch.autograd.backward([t5.sum(), t7.sum()])
11419
11420        # No grads should be accumulated since child graph will stop execution
11421        # after parent receives error.
11422        self.assertIsNone(t2.grad)
11423        self.assertIsNone(t1.grad)
11424        self.assertIsNone(t3.grad)
11425
11426    @onlyCUDA
11427    def test_reentrant_parent_error_on_cpu(self, device):
11428        def _get_cuda_memory_usage():
11429            # we don't need CUDA synchronize because the statistics are not tracked at
11430            # actual freeing, but at when marking the block as free.
11431            num_devices = torch.cuda.device_count()
11432            gc.collect()
11433            return tuple(torch.cuda.memory_allocated(i) for i in range(num_devices))
11434
11435        before = _get_cuda_memory_usage()
11436
11437        # Run as separate function so that gc can clean up everything when we
11438        # check for memory usage.
11439        self._test_reentrant_parent_error_on_cpu(device)
11440
11441        # Wait for autograd thread to cleanup failed tasks.
11442        after = _get_cuda_memory_usage()
11443        start = time.time()
11444        while before != after and time.time() - start < 30:
11445            time.sleep(0.1)
11446            after = _get_cuda_memory_usage()
11447
11448        self.assertEqual(before, after)
11449
11450    @skipIfMps  # the test doesn't work on MPS
11451    # TODO: see if these tests can be ported to OpInfos or moved to where's test suite
11452    def test_where_functional(self, device):
11453        x = torch.randn(5, 5, dtype=torch.double, device=device, requires_grad=True)
11454        y = torch.randn(5, 5, dtype=torch.double, device=device, requires_grad=True)
11455        cond = mask_not_all_zeros((5, 5)).to(device=device)
11456
11457        def where(cond, x, y):
11458            return torch.where(cond, x, y)
11459
11460        gradcheck(where, [cond, x, y], raise_exception=True)
11461        gradgradcheck(where, [cond, x, y], [torch.randn(5, 5, device=device)])
11462
11463        x = torch.randn(5, 1, 5, dtype=torch.double, device=device, requires_grad=True)
11464        y = torch.randn(5, 5, 1, dtype=torch.double, device=device, requires_grad=True)
11465        gradcheck(where, [cond, x, y], raise_exception=True)
11466        gradgradcheck(where, [cond, x, y], [torch.randn(5, 5, 5, device=device)])
11467
11468    @skipIfMps  # the test doesn't work on MPS
11469    def test_where_scalar(self, device):
11470        x = torch.randn(5, 5, dtype=torch.double, device=device, requires_grad=True)
11471        scalar = 4.0
11472        cond = mask_not_all_zeros((5, 5)).to(device=device)
11473
11474        def where_scalar_first(cond, x):
11475            return torch.where(cond, scalar, x)
11476
11477        def where_scalar_second(cond, x):
11478            return torch.where(cond, x, scalar)
11479
11480        gradcheck(where_scalar_first, (cond, x))
11481        gradgradcheck(where_scalar_first, (cond, x))
11482
11483        gradcheck(where_scalar_second, (cond, x))
11484        gradgradcheck(where_scalar_second, (cond, x))
11485
11486    @onlyCUDA
11487    def test_free_unneeded_tensor(self, device):
11488        x = torch.randn(2, 3, 10, 10, device=device, requires_grad=True)
11489        m = torch.randn(1, 3, 1, 1, device=device)
11490
11491        z = x.sum()
11492        base_mem = torch.cuda.memory_allocated()
11493        z = ((x + 2) * m).sum()
11494        end_mem = torch.cuda.memory_allocated()
11495
11496        # In the end the memory usage should remain equal, because neither of
11497        # (x + 2) and ((x + 2) * m) should be kept alive for backward, while the
11498        # previous allocation of z had the same size as the current one.
11499        self.assertEqual(base_mem, end_mem)
11500
11501    @onlyCUDA
11502    def test_pin_memory(self, device):
11503        x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
11504        self.assertEqual(x, x.pin_memory())
11505        self.assertIsNot(x, x.pin_memory())
11506        self.assertTrue(x.pin_memory().requires_grad)
11507        gradcheck(lambda x: x.pin_memory(), [x])
11508        gradgradcheck(lambda x: x.pin_memory(), [x])
11509
11510    @onlyCUDA
11511    def test_profiler_emit_nvtx(self, device):
11512        # This test is not intended to ensure correctness of nvtx ranges.
11513        # That would require something a great deal more complex (you'd have to create a
11514        # profile in a subprocess, open it, and parse the sql somehow).
11515        # This test is merely intended to catch if emit_nvtx breaks on construction.
11516        a = torch.tensor([1, 2, 3], dtype=torch.float32, device=device)
11517        with torch.cuda.profiler.profile():
11518            with emit_nvtx():
11519                a.add(1.0)
11520
11521    @onlyCUDA
11522    def test_rnn_backward_to_input_but_not_parameters(self, device):
11523        # this checks whether it is possible to not require
11524        # weight parameters, but require inputs, see #7722
11525        l = torch.nn.LSTM(2, 3).to(device)
11526        for p in l.parameters():
11527            p.requires_grad = False
11528        s = torch.randn(1, 1, 2, requires_grad=True, device=device)
11529        out, _ = l(s)
11530        out.sum().backward()
11531        self.assertFalse(s.grad is None or s.grad.abs().sum().item() == 0)
11532
11533    @unittest.skipIf(not torch.profiler.itt.is_available(), "ITT is required")
11534    def test_profiler_emit_itt(self, device):
11535        # This test is not intended to ensure correctness of itt ranges.
11536        # That would require something a great deal more complex (you'd have to create a
11537        # profile in a subprocess, open it, and parse the sql somehow).
11538        # This test is merely intended to catch if emit_itt breaks on construction.
11539        a = torch.tensor([1, 2, 3], dtype=torch.float32, device=device)
11540        with emit_itt():
11541            a.add(1.0)
11542
11543    @skipIfMps  # the test doesn't work as randn is not supported with type long
11544    @deviceCountAtLeast(1)
11545    def test_grad_assignment(self, devices):
11546        x = torch.randn(5, 5, device=devices[0])
11547
11548        # Tests that the wrong type raises
11549        with self.assertRaisesRegex(TypeError, "expected to be a Tensor or None"):
11550            x.grad = 0
11551
11552        # Tests that the wrong shape raises
11553        with self.assertRaises(RuntimeError):
11554            x.grad = torch.randn(2, 2, device=devices[0])
11555
11556        # Tests that the wrong dtype raises
11557        with self.assertRaises(RuntimeError):
11558            x.grad = torch.randn(5, 5, dtype=torch.long, device=devices[0])
11559
11560        # Tests that self-assignment raises
11561        with self.assertRaises(RuntimeError):
11562            x.grad = x
11563
11564        # Tests device -> cpu grad assignment raises
11565        if self.device_type != "cpu":
11566            with self.assertRaises(RuntimeError):
11567                t_cpu = torch.rand(5, 5)
11568                t_cpu.grad = torch.randn(5, 5, device=devices[0])
11569
11570        # Tests half type on CUDA
11571        if self.device_type == "cuda":
11572            x = x.to(dtype=torch.half, device=devices[0])
11573            x.grad = torch.zeros_like(x)
11574
11575        # Tests cross-device assignment raises
11576        if len(devices) > 1:
11577            x = torch.randn(5, 5, device=devices[0])
11578            with self.assertRaises(RuntimeError):
11579                x.grad = torch.randn(5, 5, device=devices[1])
11580
11581    @dtypesIfMPS(torch.float32)
11582    @deviceCountAtLeast(1)
11583    @dtypes(torch.float, torch.double)
11584    def test_requires_grad_factory(self, devices, dtype):
11585        fns = [torch.ones_like, torch.randn_like]
11586        x = torch.randn(2, 3, dtype=dtype, device=devices[0])
11587
11588        for fn in fns:
11589            for requires_grad in [True, False]:
11590                output = fn(
11591                    x, dtype=dtype, device=devices[0], requires_grad=requires_grad
11592                )
11593                self.assertEqual(requires_grad, output.requires_grad)
11594                self.assertIs(dtype, output.dtype)
11595                self.assertEqual(devices[0], str(x.device))
11596
11597    @deviceCountAtLeast(2)
11598    def test_unused_output_device(self, devices):
11599        from torch.nn.parallel._functions import Broadcast
11600
11601        x = torch.randn(5, 5, dtype=torch.float, device=devices[0], requires_grad=True)
11602        outputs = Broadcast.apply(list(range(len(devices))), x)
11603        y = outputs[-1] * 2
11604        y.sum().backward()
11605        self.assertEqual(x.grad, torch.ones(5, 5) * 2)
11606
11607    @deviceCountAtLeast(2)
11608    def test_backward_device(self, devices):
11609        # check that current device matches the variable's device
11610        device = [None]
11611
11612        class Identity(torch.autograd.Function):
11613            @staticmethod
11614            def forward(ctx, x):
11615                return x.clone()
11616
11617            @staticmethod
11618            def backward(ctx, grad_output):
11619                device[0] = grad_output.device
11620                return grad_output.clone()
11621
11622        v = torch.randn(1, device=devices[1], requires_grad=True)
11623        Identity.apply(v).backward()
11624        self.assertEqual(str(device[0]), devices[1])
11625
11626    @deviceCountAtLeast(2)
11627    def test_inputbuffer_add_multidevice(self, devices):
11628        input = torch.randn(1, device=devices[0], requires_grad=True)
11629        output = input.to(device=devices[1]) + input.to(device=devices[1])
11630        output.backward()
11631
11632    @onlyCPU
11633    def test_copy_(self, device):
11634        # At the time of writing this test, copy_ is not generated from native_functions.yaml
11635        # there was a bug that bfloat16 was not recognized as floating.
11636        x = torch.randn(10, device=device, requires_grad=True)
11637        floating_dt = floating_types_and(torch.half, torch.bfloat16)
11638        for dt in floating_dt:
11639            y = torch.empty(10, device=device, dtype=dt)
11640            y.copy_(x)
11641            self.assertTrue(y.requires_grad)
11642            z = x.to(torch.bfloat16)
11643            self.assertTrue(z.requires_grad)
11644
11645    def test_copy_forward_ad_broadcasting(self, device):
11646        # copy_ allows the src to have a different shape from self as long as src is
11647        # broadcastable to self. Make sure forward AD handles this case.
11648        primal = torch.rand(3, 3, device=device)
11649        tangent = torch.rand(3, 3, device=device)
11650        non_dual = torch.rand(1, 3, 3, device=device)
11651
11652        with fwAD.dual_level():
11653            dual = fwAD.make_dual(primal, tangent)
11654            non_dual.copy_(dual)
11655
11656    def test_copy_forward_ad_same_layout_copies_grad(self, device):
11657        primal = torch.tensor([[3.0], [4.0]], device=device)
11658        tangent = torch.tensor([[5.0], [6.0]], device=device)
11659
11660        with fwAD.dual_level():
11661            x_dual = fwAD.make_dual(primal, tangent)
11662            non_dual = torch.tensor([[1.0], [2.0]])
11663            non_dual.copy_(x_dual)
11664            self.assertTrue(fwAD.unpack_dual(non_dual).tangent is not tangent)
11665
11666    @onlyCUDA
11667    def test_simple_reentrant_cross_device(self, device):
11668        class ReentrantFunc(Function):
11669            _cpu_mode = True
11670
11671            @staticmethod
11672            def forward(ctx, x):
11673                return x * (x + 2)
11674
11675            @staticmethod
11676            def backward(ctx, grad_output):
11677                with torch.enable_grad():
11678                    if ReentrantFunc._cpu_mode:
11679                        new_param = torch.randn(2, 2, requires_grad=True)
11680                        (new_param**2).sum().backward()
11681                    else:
11682                        new_param = torch.randn(2, 2, device=device, requires_grad=True)
11683                        (new_param**2).sum().backward()
11684                return grad_output
11685
11686        # Reentrant starts on GPU thread, finishs on GPU thread
11687        x = torch.randn(2, 2, device=device, requires_grad=True)
11688        out = ReentrantFunc.apply(x)
11689        out.sum().backward()
11690
11691        # Reentrant starts on CPU thread, finishs on GPU thread
11692        x = torch.randn(2, 2, requires_grad=True)
11693        # set ReentrantFunc node to GPU to emit tasks to GPU queue
11694        ReentrantFunc._cpu_mode = False
11695        out = ReentrantFunc.apply(x)
11696        out.sum().backward()
11697
11698        # Reentrant starts on GPU thread, finishs on CPU thread
11699        x = torch.randn(2, 2, device=device, requires_grad=True)
11700        # set ReentrantFunc node to CPU to emit tasks to CPU queue
11701        ReentrantFunc._cpu_mode = True
11702        out = ReentrantFunc.apply(x)
11703        out.sum().backward()
11704
11705    @onlyCUDA
11706    def test_cross_device_reentrant_autograd(self, device):
11707        # Output on gpu so that this task will be associated with the gpu thread
11708        def fn_on_gpu(inp):
11709            # Artificially increase the priority of the next op to make sure it runs
11710            # as soon as we reach it before the ops of branch1.
11711            dummy = inp * 2 * 2 * 2 * 2
11712            return inp.to(device=device)
11713
11714        def parent_on_cpu(inp):
11715            # Slow branch of ops on gpu so that the work queue for the gpu thread
11716            # won't empty too quickly. They also have smaller priorities than the
11717            # ones created by fn_on_gpu
11718            branch1 = inp.to(device=device)
11719            branch1 = branch1 / branch1
11720            branch1 = branch1 / branch1
11721            branch1 = branch1 / branch1
11722            # Perform checkpoint on cpu tensors. So the last op performed in the reentrant
11723            # autograd is an AccumulateGrad that runs on the cpu thread for the gpu thread.
11724            # So the cpu thread will notify the gpu thread with an empty NodeTask.
11725            branch2 = checkpoint(fn_on_gpu, inp, use_reentrant=True)
11726            out = branch2 + branch1
11727            return out
11728
11729        inp = torch.rand(2, requires_grad=True)
11730        out = parent_on_cpu(inp)
11731        # This will segfault if the empty NodeTask is not handled properly in the
11732        # gpu thread ReadyQueue
11733        out.sum().backward()
11734
11735    def test_inplace_on_view_backprop_base(self, device):
11736        # modify view and back-prop through base
11737        root = torch.randn(2, 2, device=device, requires_grad=True)
11738        x = root.clone()
11739        v1 = x.narrow(0, 0, 1)
11740        v1.mul_(2)
11741        x.sum().backward()
11742        self.assertEqual(root.grad.tolist(), [[2, 2], [1, 1]])
11743
11744    def test_inplace_on_view_backprop_view_of_view(self, device):
11745        # modify view and backprop through view-of-view
11746        root = torch.randn(2, 2, device=device, requires_grad=True)
11747        x = root.clone()
11748        v1 = x.narrow(0, 0, 1)
11749        v2 = x.narrow(0, 0, 1)
11750        v1.mul_(2)
11751        v2.sum().backward()
11752        self.assertEqual(root.grad.tolist(), [[2, 2], [0, 0]])
11753
11754    def test_inplace_on_view_of_view(self, device):
11755        # modify view-of-view and backprop through base
11756        root = torch.randn(2, 2, device=device, requires_grad=True)
11757        x = root.clone()
11758
11759        v1 = x.narrow(0, 0, 1)
11760        v2 = v1.narrow(1, 1, 1)
11761        v2.mul_(2)
11762        x.sum().backward()
11763        self.assertEqual(root.grad.tolist(), [[1, 2], [1, 1]])
11764
11765    @skipIfMps  # the test doesn't work on MPS as double types are not supported
11766    def test_inplace_on_view_then_no_grad(self, device):
11767        # Perform an in-place operation on a view of a non-leaf variable.
11768        a = torch.ones(3, 1, dtype=torch.double, device=device, requires_grad=True)
11769        b = a * 2
11770        c = b.view_as(b)
11771        c[0][0] = 3
11772
11773        # Force a graph update with grad disabled.
11774        with torch.no_grad():
11775            c.grad_fn
11776
11777        c.sum().backward()
11778
11779    @skipIfMps  # the test doesn't work on MPS as double types are not supported
11780    def test_inplace_on_view_gradcheck(self, device):
11781        # gradcheck modifications to views
11782        a = torch.randn(4, 4, dtype=torch.double, device=device, requires_grad=True)
11783        b = torch.randn(2, 2, dtype=torch.double, device=device, requires_grad=True)
11784
11785        def func(root, b):
11786            x = root.clone()
11787            x.narrow(1, 2, 2).narrow(0, 1, 2).mul_(b)
11788            x.narrow(1, 0, 2).narrow(0, 1, 2).mul_(b)
11789            return x
11790
11791        gradcheck(func, [a, b], raise_exception=True)
11792        go = torch.randn(
11793            a.size(), dtype=torch.double, device=device, requires_grad=True
11794        )
11795        gradgradcheck(func, (a, b), (go,))
11796
11797    def test_inplace_on_view_multiple_outputs(self, device):
11798        root = torch.arange(9.0, dtype=torch.double).reshape(3, 3).requires_grad_()
11799        x = root.clone()
11800        v1 = x.unbind()
11801        with self.assertRaises(RuntimeError):
11802            v1[0].mul_(2)
11803
11804    @skipIfMps  # the test doesn't work on MPS as double types are not supported
11805    def test_inplace_on_view_of_multiple_output_view(self, device):
11806        a = torch.rand(
11807            10, dtype=torch.double, device=device, requires_grad=True
11808        ).clone()
11809        b = a.unbind(0)
11810        c = b[0].view_as(b[0])
11811        with self.assertRaises(RuntimeError):
11812            c.mul_(2)
11813
11814    @skipIfMps  # MPS backend doesn't support double types
11815    def test_inplace_multiple_output_view_of_view(self, device):
11816        a = torch.rand(
11817            10, dtype=torch.double, device=device, requires_grad=True
11818        ).clone()
11819        b = a.view_as(a)
11820        c = b.unbind(0)
11821        with self.assertRaises(RuntimeError):
11822            c[0].mul_(2)
11823
11824    @skipIfMps  # MPS backend doesn't support double types
11825    def test_inplace_on_view_makes_base_require_grad(self, device):
11826        # in-place modification to view makes base require grad
11827        a = torch.randn(4, 4, dtype=torch.double, device=device, requires_grad=False)
11828        b = torch.randn(4, 2, dtype=torch.double, device=device, requires_grad=True)
11829
11830        def func(root, b):
11831            x = root.clone()
11832            self.assertFalse(x.requires_grad)
11833            x.narrow(1, 2, 2).mul_(b)
11834            self.assertTrue(x.requires_grad)
11835            return x
11836
11837        gradcheck(func, [a, b], raise_exception=True)
11838        go = torch.randn(
11839            a.size(), dtype=torch.double, device=device, requires_grad=True
11840        )
11841        gradgradcheck(func, (a, b), (go,))
11842
11843    def test_inplace_on_view_backprop_view(self, device):
11844        # modify view and backprop through view
11845        a = torch.tensor([2.0, 5.0], device=device, requires_grad=False)
11846        b = torch.tensor([3.0], device=device, requires_grad=True)
11847        res = a.narrow(0, 1, 1).mul_(b)
11848        res.sum().backward()
11849        self.assertEqual(b.grad.tolist(), [5])
11850        self.assertIsNone(a.grad)
11851
11852    @skipIfMps  # the test doesn't work on MPS as double types are not supported
11853    def test_inplace_on_view_modify_base(self, device):
11854        # Test that an in-place operation on a base that forced it to require
11855        # grad also forces any previous views to require grad and backprop
11856        # correctly
11857        r = torch.ones(1, dtype=torch.double, device=device, requires_grad=True)
11858
11859        def fn(r):
11860            x = torch.ones(5, dtype=torch.double, device=device)
11861            v = x.select(0, 1)
11862            self.assertFalse(v.requires_grad)
11863            self.assertIsNone(v.grad_fn)
11864            x.add_(r)  # v is now dependent on r due to the in-place op on x
11865            self.assertTrue(v.requires_grad)
11866            return v
11867
11868        gradcheck(fn, [r])
11869        gradgradcheck(fn, [r])
11870
11871    @skipIfMps  # the test doesn't work on MPS as double types are not supported
11872    def test_inplace_on_view_python(self, device):
11873        # in-place modifications of Python-autograd created view
11874        a = torch.randn(4, 4, dtype=torch.double, device=device, requires_grad=True)
11875        b = torch.randn(2, 2, dtype=torch.double, device=device, requires_grad=True)
11876
11877        class PyAdd(torch.autograd.Function):
11878            @staticmethod
11879            def forward(ctx, x, y):
11880                ctx.mark_dirty(x)
11881                x.add_(y)
11882                return x
11883
11884            @staticmethod
11885            def backward(ctx, grad):
11886                return grad, grad
11887
11888        def func(root, b):
11889            x = root.clone()
11890            PyAdd.apply(x.narrow(1, 2, 2).narrow(0, 1, 2), b)
11891            PyAdd.apply(x.narrow(1, 0, 2).narrow(0, 1, 2), b)
11892            return x
11893
11894        gradcheck(func, [a, b], raise_exception=True)
11895        go = torch.randn(
11896            a.size(), dtype=torch.double, device=device, requires_grad=True
11897        )
11898        gradgradcheck(func, (a, b), (go,))
11899
11900    def test_inplace_on_view_non_contig(self, device):
11901        root = torch.ones(2, 3, 2, device=device).select(2, 1).t().requires_grad_(True)
11902        x = root.clone()
11903        v1 = x.narrow(0, 0, 1)
11904        v2 = v1.narrow(1, 1, 1)
11905        v2.mul_(2)
11906        x.sum().backward()
11907        self.assertEqual(root.grad.tolist(), [[1, 2], [1, 1], [1, 1]])
11908
11909    def test_inplace_on_view_multi_output_unsafe(self, device):
11910        for f in [
11911            lambda t: t.unsafe_split(1),
11912            lambda t: t.unsafe_split_with_sizes((1, 1, 1)),
11913            lambda t: t.unsafe_chunk(3),
11914        ]:
11915            a = torch.randn(3, 3, device=device, requires_grad=True)
11916            b = a + a
11917            s1, s2, s3 = f(b)
11918            s1.mul_(s2)
11919            s1.sum().backward()
11920
11921    def test_inplace_on_view_multi_output_safe(self, device):
11922        for f in [
11923            lambda t: t.split(1),
11924            lambda t: t.split_with_sizes((1, 1, 1)),
11925            lambda t: t.chunk(3),
11926        ]:
11927            a = torch.randn(3, 3, device=device, requires_grad=True)
11928            b = a + a
11929            s1, s2, s3 = f(b)
11930            error_msg = (
11931                "This view is the output of a function that returns multiple views."
11932            )
11933            with self.assertRaisesRegex(RuntimeError, error_msg):
11934                s1.mul_(s2)
11935
11936    def test_inplace_on_view_undefined_grad_output(self, device):
11937        a = torch.tensor([1.0], requires_grad=True)
11938        c = a.clone()
11939        v = c[:]
11940        b = torch.tensor(1.0, requires_grad=True)
11941
11942        class InplaceFunc(torch.autograd.Function):
11943            @staticmethod
11944            def forward(ctx, x, other):
11945                ctx.mark_dirty(x)
11946                return x.mul_(2)
11947
11948            @staticmethod
11949            def backward(ctx, grad):
11950                return grad * 2, None
11951
11952        out = InplaceFunc.apply(v, b)
11953        out.backward()
11954        self.assertIsNone(b.grad)
11955        self.assertEqual(a.grad.item(), 2)
11956
11957    @skipIfMps  # the test doesn't work on MPS as double types are not supported
11958    def test_mv_grad_stride_0(self, device):
11959        # Reference: https://github.com/pytorch/pytorch/issues/38315
11960        mat = torch.randn(2, 2, dtype=torch.double, device=device)
11961        vec = torch.randn(1, dtype=torch.double, device=device).requires_grad_(True)
11962
11963        def fn(vec):
11964            # Expand inside the function to make sure the input to
11965            # gradcheck does not have overlapping memory
11966            vec = vec.expand(2)
11967            return (mat @ vec).sum()
11968
11969        gradcheck(fn, (vec))
11970        gradgradcheck(fn, (vec))
11971
11972    @onlyCUDA
11973    def test_gradcheck_input_output_different_device(self, device):
11974        x = torch.ones((1,), dtype=torch.double, device="cuda", requires_grad=True)
11975        gradcheck(lambda x: x.to("cpu"), (x,))
11976
11977        x = torch.ones((1,), dtype=torch.double, device="cpu", requires_grad=True)
11978        gradcheck(lambda x: x.to("cuda"), (x,))
11979
11980    def test_strided_leaf_grad_layout(self, device):
11981        # (1) If leaf is non-overlapping and dense, grad's layout should match its leaf.
11982        for fmt_a in (torch.contiguous_format, torch.channels_last):
11983            for fmt_b in (torch.contiguous_format, torch.channels_last):
11984                a = torch.rand((2, 3, 4, 5), device=device).to(memory_format=fmt_a)
11985                b = torch.rand((2, 3, 4, 5), device=device).to(memory_format=fmt_b)
11986                a.requires_grad_()
11987                b.requires_grad_()
11988                # checks (1) for broadcasted gradients
11989                a.sum().backward()
11990                self.assertEqual(a.grad.stride(), a.stride())
11991                b.sum().backward()
11992                self.assertEqual(b.grad.stride(), b.stride())
11993                # checks (1) for non-broadcasted gradients
11994                a.grad = None
11995                b.grad = None
11996                (a * b).sum().backward()
11997                self.assertEqual(a.grad.stride(), a.stride())
11998                self.assertEqual(b.grad.stride(), b.stride())
11999
12000        # (2) If leaf isn't dense, checks that grads are rowmajor contiguous.
12001        c = torch.empty_strided((2, 2), (4, 2), device=device).copy_(
12002            torch.rand((2, 2), device=device)
12003        )
12004        c.requires_grad_()
12005        d = torch.rand((2, 2), device=device)
12006        # checks (2) for broadcasted gradients
12007        c.sum().backward()
12008        self.assertEqual(c.grad.stride(), (2, 1))
12009        # checks (2) for non-broadcasted gradients
12010        c.grad = None
12011        (c * d).sum().backward()
12012        self.assertEqual(c.grad.stride(), (2, 1))
12013
12014    @skipIfMps
12015    def test_copy_r_to_c(self, device):
12016        out_c = torch.empty(3, 2, dtype=torch.cdouble, device=device)
12017        inp_r = torch.randn(3, 2, dtype=torch.double, device=device, requires_grad=True)
12018
12019        def do_test():
12020            out_c.copy_(inp_r)
12021            out_c_inter = out_c.sum()
12022            out_c_inter.abs().backward()
12023            with torch.no_grad():
12024                self.assertEqual(
12025                    inp_r.grad, torch.ones_like(inp_r) * torch.sgn(out_c_inter).real
12026                )
12027
12028        self.assertNotWarn(do_test)
12029
12030    def test_to_r_to_c(self, device):
12031        def do_test():
12032            inp_r = torch.randn(
12033                3, 2, dtype=torch.double, device=device, requires_grad=True
12034            )
12035            out = inp_r.to(torch.complex128)
12036            out_inter = out.sum()
12037            out_inter.abs().backward()
12038            with torch.no_grad():
12039                self.assertEqual(
12040                    inp_r.grad, torch.ones_like(inp_r) * torch.sgn(out_inter).real
12041                )
12042
12043        self.assertNotWarn(do_test)
12044
12045    def test_non_differentiable_ops(self, device):
12046        # Just make sure the op doesn't raise an error
12047        # and resulting tensor has requires_grad=False.
12048        x = torch.tensor([[1, 2], [3, 4.0]], requires_grad=True, device=device)
12049        out = torch.isin(x, torch.tensor([2, 3], device=device))
12050        self.assertFalse(out.requires_grad)
12051
12052        x = torch.randn(3, 3, requires_grad=True)
12053        out = torch.signbit(x)
12054        self.assertFalse(out.requires_grad)
12055
12056    def test_warning_in_backward(self, device):
12057        # Test warning during backward are always propagated as python warnings (gh-50209)
12058        # NOTE: For device=cuda, warning gets propagated from a worker thread
12059        a = torch.zeros((), device=device, requires_grad=True)
12060        b = torch._C._nn._test_warn_in_autograd(a)
12061
12062        with self.assertWarnsRegex(UserWarning, "Warn from backward"):
12063            b.backward()
12064
12065    def test_complex_scalar_backward(self, device):
12066        a = torch.zeros(1, device=device, requires_grad=True)
12067        b = a * 0.5j
12068
12069        msg = "grad can be implicitly created only for real scalar outputs"
12070        with self.assertRaisesRegex(RuntimeError, msg):
12071            b.backward()
12072
12073        with self.assertRaisesRegex(RuntimeError, msg):
12074            torch.autograd.grad(b, a)
12075
12076    def test_pow_real_negative_base_complex_exponent(self, device):
12077        # OpInfo doesn't naturally support input of mixed types, hence this test here.
12078        base = -torch.ones(2, device=device, dtype=torch.double)
12079        exponent = torch.randn(
12080            2, device=device, dtype=torch.cdouble, requires_grad=True
12081        )
12082
12083        def fn(exponent):
12084            return torch.pow(base, exponent)
12085
12086        torch.autograd.gradcheck(fn, (exponent,))
12087
12088        def fn(exponent):
12089            return torch.pow(-1, exponent)
12090
12091        torch.autograd.gradcheck(fn, (exponent,))
12092
12093    def test_resize_version_bump(self, device):
12094        x = torch.rand((1,), device=device)
12095        y = torch.randn((3,), device=device)
12096        x.resize_((1, 2))
12097        self.assertEqual(x._version, 1)
12098        x.resize_as_(y)
12099        self.assertEqual(x._version, 2)
12100
12101        # In the following cases, `resize` is no-op,
12102        # so no version bumps.
12103        x.resize_((3,))
12104        self.assertEqual(x._version, 2)
12105
12106        x.resize_as_(y)
12107        self.assertEqual(x._version, 2)
12108
12109
12110class TestAllowMutationOnSaved(TestCase):
12111    def assertClonedLenEqual(self, ctx, n):
12112        self.assertEqual(len(list(ctx.cloned.items())), n)
12113
12114    def assertTIDMapLenEqual(self, ctx, n):
12115        self.assertEqual(len(list(ctx.tid_to_weakhandle.items())), n)
12116
12117    def test_basic(self):
12118        a = torch.rand(2, 3, requires_grad=True)
12119
12120        def fn(a):
12121            b = a.clone()
12122            out = (b**2).sum()
12123            b.sin_()
12124            out.sum().backward()
12125            return a.grad
12126
12127        msg = (
12128            "variables needed for gradient computation has been modified by an inplace"
12129        )
12130        with self.assertRaisesRegex(RuntimeError, msg):
12131            fn(a)
12132
12133        with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
12134            da = fn(a)
12135
12136        self.assertTrue(torch.allclose(a * 2, da))
12137        self.assertClonedLenEqual(ctx, 0)
12138
12139    def test_views(self):
12140        a = torch.rand(2, 3, requires_grad=True)
12141
12142        def fn(a):
12143            b = a.clone()
12144            c = b.view_as(b)
12145            out = (b**2).sum()  # How does this work?
12146            c.sin_()
12147            out.sum().backward()
12148            return a.grad
12149
12150        msg = (
12151            "variables needed for gradient computation has been modified by an inplace"
12152        )
12153        with self.assertRaisesRegex(RuntimeError, msg):
12154            fn(a)
12155
12156        with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
12157            da = fn(a)
12158
12159        self.assertClonedLenEqual(ctx, 0)
12160        self.assertTrue(torch.allclose(a * 2, da))
12161
12162    def test_save_base_and_modify_view(self):
12163        with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
12164            a = torch.rand(2, 3, requires_grad=True)
12165            b = a.clone()
12166            c = b[:1]
12167            out = b**2
12168            # modify the view
12169            c *= 10
12170            # self.assertClonedLenEqual(ctx, 1)
12171            out.sum().backward()
12172            self.assertClonedLenEqual(ctx, 0)
12173
12174        self.assertClonedLenEqual(ctx, 0)
12175        self.assertTrue(torch.allclose(a * 2, a.grad))
12176
12177    def test_save_view_modify_base(self):
12178        with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
12179            a = torch.rand(2, 3, requires_grad=True)
12180            b = a.clone()
12181            c = b[:]
12182            out = (c**2).sum()
12183            b *= 2
12184            out.backward()
12185            self.assertTrue(torch.allclose(a * 2, a.grad))
12186
12187    def test_double_backward(self):
12188        with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
12189            a = torch.rand(2, 3, requires_grad=True)
12190            b = a.clone()
12191            out = (b**2).sum()
12192            b.sin_()
12193            torch.autograd.grad(out, a, create_graph=True)
12194            (da,) = torch.autograd.grad(out, a, create_graph=True)
12195            (d2a,) = torch.autograd.grad(da.sum(), a)
12196
12197        self.assertTrue(torch.allclose(torch.ones_like(a) * 2, d2a))
12198        self.assertClonedLenEqual(ctx, 0)
12199
12200    def test_saved_but_not_anymore(self):
12201        # Make sure we don't clone if the tensor was once saved, but
12202        # by the time we do in-place, it is no longer saved
12203        with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
12204            a = torch.randn(2, 3, requires_grad=True).clone()
12205            out = (a**2).sum()
12206            self.assertTIDMapLenEqual(ctx, 1)
12207            self.assertClonedLenEqual(ctx, 0)
12208            out.backward()
12209            a.sin_()
12210            self.assertClonedLenEqual(ctx, 0)
12211            out = (a**2).sum()
12212            a.sin_()
12213            self.assertClonedLenEqual(ctx, 1)
12214            del out
12215            self.assertClonedLenEqual(ctx, 0)
12216
12217    def test_saved_same_tensor_many_times(self):
12218        # We should only clone once
12219        with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
12220            a = torch.randn(2, 3, requires_grad=True).clone()
12221            b = a**2
12222            c = a**2
12223            a.sin_()
12224            self.assertClonedLenEqual(ctx, 1)
12225            del b, c
12226            self.assertClonedLenEqual(ctx, 0)
12227
12228    def test_saved_same_tensor_different_versions(self):
12229        with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
12230            a = torch.randn(2, 3, requires_grad=True).clone()
12231            b = a**2
12232            a.sin_()
12233            c = a**2
12234            a.sin_()
12235            self.assertClonedLenEqual(ctx, 2)
12236            del b
12237            self.assertClonedLenEqual(ctx, 1)
12238            del c
12239            self.assertClonedLenEqual(ctx, 0)
12240
12241    def test_with_math_views(self):
12242        with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
12243            a = torch.tensor([1 + 1j], requires_grad=True).clone()
12244            b = a.conj()
12245            out = (b**2).sum()
12246            a.sin_()
12247            out.abs().backward()
12248
12249            a = torch.tensor([1 + 1j], requires_grad=True).clone()
12250            b = a.conj()
12251            out = (b**2).sum()
12252            # in this case, it is no longer a view it seems
12253            b.sin_()
12254            out.abs().backward()
12255
12256    def test_with_out_variant(self):
12257        with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
12258            a = torch.tensor([1.0], requires_grad=True)
12259            b = torch.tensor([1.0])
12260            c = torch.tensor([2.0])
12261            out = a * b
12262            self.assertTIDMapLenEqual(ctx, 1)
12263            torch.sin(c, out=b)
12264            self.assertClonedLenEqual(ctx, 1)
12265            out.backward()
12266            self.assertClonedLenEqual(ctx, 0)
12267
12268    def test_backward_out_of_context(self):
12269        # Out of context
12270        with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
12271            a = torch.rand(2, 3, requires_grad=True)
12272            out = (a**2).sum()
12273
12274        msg = "Trying to backward outside of the 'allow_mutation_on_saved_tensors' context"
12275        with self.assertRaisesRegex(AssertionError, msg):
12276            out.backward()
12277
12278        # Different context
12279        with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
12280            a = torch.rand(2, 3, requires_grad=True)
12281            out = (a**2).sum()
12282
12283        with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
12284            with self.assertRaisesRegex(AssertionError, msg):
12285                out.backward()
12286
12287    def test_disallow_nesting(self):
12288        with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
12289            msg = "allow_mutation_on_saved_tensors contexts cannot be nested"
12290            with self.assertRaisesRegex(RuntimeError, msg):
12291                with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
12292                    pass
12293
12294
12295class TestAutogradInferenceMode(TestCase):
12296    def _is_inference_tensor(self, tensor):
12297        try:
12298            err_msg = "Inference tensors do not track version counter"
12299            with self.assertRaisesRegex(RuntimeError, err_msg):
12300                tensor._version
12301            return True
12302        except AssertionError as e:
12303            return False
12304
12305    def test_inference_mode_context_manager(self):
12306        self.assertFalse(torch.is_inference_mode_enabled())
12307        with torch.inference_mode():
12308            self.assertTrue(torch.is_inference_mode_enabled())
12309            with torch.inference_mode(False):
12310                self.assertFalse(torch.is_inference_mode_enabled())
12311            self.assertTrue(torch.is_inference_mode_enabled())
12312        self.assertFalse(torch.is_inference_mode_enabled())
12313
12314    def test_inference_mode_decorator(self):
12315        def func(x):
12316            self.assertEqual(torch.is_inference_mode_enabled(), mode)
12317            return x * x
12318
12319        for mode, use_kwarg in product((True, False, None), (True, False)):
12320            if mode is None:
12321                if use_kwarg:
12322                    decorated = torch.inference_mode(mode=func)
12323                else:
12324                    decorated = torch.inference_mode(func)
12325                mode = True
12326            else:
12327                if use_kwarg:
12328                    decorated = torch.inference_mode(mode=mode)(func)
12329                else:
12330                    decorated = torch.inference_mode(mode)(func)
12331
12332            for requires_grad in (True, False):
12333                c = torch.ones(1, 2, 3, requires_grad=requires_grad)
12334                d = decorated(c)
12335                self.assertTrue(not mode or torch.is_inference(d))
12336                self.assertEqual(d.requires_grad, requires_grad and not mode)
12337
12338    def test_inference_mode_tensor_creation(self):
12339        with torch.inference_mode():
12340            # new tensors created through constructors are inference tensors
12341            c = torch.ones(1, 2, 3)
12342            self.assertFalse(c.requires_grad)
12343            self.assertTrue(torch.is_inference(c))
12344
12345            # requires_grad doesn't change inference tensor behavior in InferenceMode
12346            tmp = torch.ones(1, 2, 3, requires_grad=True)
12347            self.assertTrue(tmp.requires_grad)
12348            self.assertTrue(torch.is_inference(tmp))
12349
12350            tmp = torch.ones(1, 2, 3).requires_grad_(False)
12351            self.assertFalse(tmp.requires_grad)
12352            self.assertTrue(torch.is_inference(tmp))
12353
12354    def test_inference_mode_existing_autograd_session(self):
12355        s = torch.ones(1, 2, 3, requires_grad=True)
12356        a = s.clone()
12357
12358        # `a` gets saved outside of inference mode
12359        out = a * a
12360        with torch.inference_mode():
12361            a.add_(2)
12362
12363        self.assertFalse(torch.is_inference(a))
12364        # tensors created outside of inference mode aren't
12365        # inference tensors, so they will still have their
12366        # version counters tracked
12367        err_msg = (
12368            "one of the variables needed for gradient computation has been "
12369            "modified by an inplace operation"
12370        )
12371        with self.assertRaisesRegex(RuntimeError, err_msg):
12372            out.backward(torch.ones_like(out))
12373
12374    def test_inference_mode_inf_tensor_in_inf_mode_functional_op(self):
12375        def functional_op(x):
12376            return x * x
12377
12378        with torch.inference_mode():
12379            for requires_grad in (True, False):
12380                c = torch.ones(1, 2, 3, requires_grad=requires_grad)
12381
12382                # performing a non-view operation produces a inference tensor
12383                # that does not require grad
12384                func_out = functional_op(c)
12385                self.assertTrue(torch.is_inference(func_out))
12386                self.assertFalse(func_out.requires_grad)
12387
12388    def test_inference_mode_inf_tensor_in_inf_mode_inplace_op(self):
12389        @torch.inference_mode()
12390        def run_test(fn):
12391            for requires_grad in (True, False):
12392                c = torch.ones(1, 2, 3, requires_grad=requires_grad)
12393
12394                # after performing inplace operation, tensor is still
12395                # an inference tensor
12396                fn(c)
12397                self.assertTrue(torch.is_inference(c))
12398                self.assertEqual(c.requires_grad, requires_grad)
12399
12400        run_test(lambda x: x.add_(2))
12401        run_test(lambda x: x.transpose_(0, 1))
12402
12403        # inplace ops with manual kernel for ADInplaceOrView key in VariableTypeManual.cpp
12404        run_test(lambda x: x.resize_(1, 2))
12405        run_test(lambda x: x.resize_as_(torch.ones(1, 2)))
12406        run_test(lambda x: x.copy_(torch.ones(1, 2, 3)))
12407
12408    def test_inference_mode_inf_tensor_in_inf_mode_view_op(self):
12409        with torch.inference_mode():
12410            for requires_grad in (True, False):
12411                c = torch.ones(1, 2, 3, requires_grad=requires_grad)
12412
12413                # perform view operation produces inference tensor
12414                # that does not require grad
12415                view_out = c.view(-1)
12416                self.assertTrue(torch.is_inference(view_out))
12417                self.assertFalse(view_out.requires_grad)
12418
12419    def test_inference_mode_inf_tensor_in_normal_mode_functional_op(self):
12420        def functional_op(x):
12421            return x * x
12422
12423        for requires_grad in (True, False):
12424            with torch.inference_mode():
12425                c = torch.ones(1, 2, 3, requires_grad=requires_grad)
12426
12427        func_out = functional_op(c)
12428        self.assertFalse(torch.is_inference(func_out))
12429        self.assertFalse(func_out.requires_grad)
12430        self.assertTrue(func_out.is_leaf)
12431
12432    def test_inference_mode_inf_tensor_in_normal_mode_inplace_op(self):
12433        def run_test(fn):
12434            for requires_grad in (False, True):
12435                with torch.inference_mode():
12436                    c = torch.ones(1, 2, 3, requires_grad=requires_grad)
12437
12438                if requires_grad:
12439                    # leaf variable that requires grad is being used in an inplace
12440                    # operation when requires_grad=True
12441                    pass
12442                else:
12443                    err_msg = "Inplace update to inference tensor outside InferenceMode"
12444                    with self.assertRaisesRegex(RuntimeError, err_msg):
12445                        fn(c)
12446
12447        run_test(lambda x: x.add_(2))
12448        run_test(lambda x: x.transpose_(0, 1))
12449
12450    def test_inference_mode_inf_tensor_in_normal_mode_view_op(self):
12451        for requires_grad in (True, False):
12452            with torch.inference_mode():
12453                c = torch.ones(1, 2, 3, requires_grad=requires_grad)
12454
12455            out = c.view(-1)
12456            self.assertTrue(torch.is_inference(out))
12457            self.assertFalse(out.requires_grad)
12458            self.assertFalse(out._is_view())
12459            self.assertTrue(out.is_leaf)
12460
12461    def test_normal_tensor_inplace_output_in_inference_mode(self):
12462        def run_test(fn):
12463            for requires_grad in (True, False):
12464                s = torch.ones(1, 2, 3, requires_grad=requires_grad)
12465                a = s.clone()
12466
12467                with torch.inference_mode():
12468                    fn(a)
12469                    self.assertFalse(torch.is_inference(a))
12470                    self.assertEqual(a.requires_grad, requires_grad)
12471
12472                    # inplace -> inplace
12473                    fn(a)
12474                    self.assertFalse(torch.is_inference(a))
12475                    self.assertEqual(a.requires_grad, requires_grad)
12476
12477                    # inplace -> inplace -> view
12478                    view_out = a.view(-1)
12479                    self.assertFalse(torch.is_inference(view_out))
12480                    self.assertEqual(view_out.requires_grad, requires_grad)
12481
12482        run_test(lambda x: x.add_(2))
12483        run_test(lambda x: x.transpose_(0, 1))
12484
12485    def test_normal_tensor_inplace_output_in_normal_mode(self):
12486        def run_test(fn):
12487            for requires_grad in (True, False):
12488                s = torch.ones(1, 2, 3, requires_grad=requires_grad)
12489                a = s.clone()
12490
12491                with torch.inference_mode():
12492                    fn(a)
12493                    self.assertFalse(torch.is_inference(a))
12494                    self.assertEqual(a.requires_grad, requires_grad)
12495
12496                fn(a)
12497                self.assertFalse(torch.is_inference(a))
12498                self.assertEqual(a.requires_grad, requires_grad)
12499
12500                # inplace -> inplace
12501                fn(a)
12502                self.assertFalse(torch.is_inference(a))
12503                self.assertEqual(a.requires_grad, requires_grad)
12504
12505                # inplace -> inplace -> view
12506                view_out = a.view(-1)
12507                self.assertFalse(torch.is_inference(view_out))
12508                self.assertEqual(view_out.requires_grad, requires_grad)
12509            run_test(lambda x: x.add_(2))
12510            run_test(lambda x: x.transpose_(0, 1))
12511
12512    def test_normal_tensor_view_output_in_inference_mode(self):
12513        for requires_grad in (True, False):
12514            s = torch.ones(1, 2, 3, requires_grad=requires_grad)
12515            a = s.clone()
12516
12517            with torch.inference_mode():
12518                out = a.view(-1)
12519                self.assertFalse(torch.is_inference(out))
12520                self.assertEqual(out.requires_grad, requires_grad)
12521                self.assertTrue(out._is_view())
12522
12523                # view -> view
12524                tmp = out.view(-1)
12525                self.assertFalse(torch.is_inference(tmp))
12526                self.assertEqual(tmp.requires_grad, requires_grad)
12527                self.assertTrue(tmp._is_view())
12528                self.assertTrue(tmp.is_leaf)
12529
12530                # view -> view -> inplace
12531                self.assertTrue(torch.is_inference_mode_enabled())
12532                tmp.add_(2)
12533                self.assertFalse(torch.is_inference(tmp))
12534                self.assertEqual(tmp.requires_grad, requires_grad)
12535                # Accessing is_leaf in python tries to update grad_fn and raises:
12536                # A view was created in inference mode and its base or
12537                # another view of its base has been modified inplace in normal mode
12538                # tmp.is_leaf
12539                self.assertEqual(a._version, tmp._version)
12540
12541    def test_normal_tensor_view_output_in_normal_mode(self):
12542        def functional_op(x):
12543            return x * x
12544
12545        for requires_grad in (True, False):
12546            s = torch.ones(1, 2, 3, requires_grad=requires_grad)
12547            a = s.clone()
12548
12549            with torch.inference_mode():
12550                out = a.view(-1)
12551                self.assertFalse(torch.is_inference(out))
12552                self.assertEqual(out.requires_grad, requires_grad)
12553                self.assertTrue(out._is_view())
12554                self.assertTrue(out.is_leaf)
12555
12556            tmp = functional_op(out)
12557            self.assertFalse(torch.is_inference(tmp))
12558            self.assertEqual(tmp.requires_grad, requires_grad)
12559
12560            if requires_grad:
12561                err_msg = (
12562                    "A view was created in inference mode and is being modified inplace"
12563                )
12564                with self.assertRaisesRegex(RuntimeError, err_msg):
12565                    out.add_(2)
12566            else:
12567                out.add_(2)
12568
12569            tmp = out.view(2, 3)
12570            self.assertFalse(torch.is_inference(tmp))
12571            self.assertEqual(tmp.requires_grad, requires_grad)
12572
12573    def test_mix_inference_and_normal_tensor_functional_op(self):
12574        for requires_grad in (True, False):
12575            s = torch.ones(1, 2, 3, requires_grad=requires_grad)
12576
12577            with torch.inference_mode():
12578                c = torch.ones(1, 2, 3, requires_grad=requires_grad)
12579
12580            # add is safe since it doesn't save any variable for backward
12581            out = c.add(s)
12582            self.assertFalse(torch.is_inference(out))
12583            self.assertEqual(out.requires_grad, requires_grad)
12584            if requires_grad:
12585                # leaf inference tensor with requires_grad=True can still have gradient
12586                out.backward(torch.ones_like(out))
12587                self.assertEqual(c.grad, torch.ones_like(c))
12588
12589            if requires_grad:
12590                err_msg = "Inference tensors cannot be saved for backward"
12591                with self.assertRaisesRegex(RuntimeError, err_msg):
12592                    c * s
12593
12594                # TODO: Test this with an autograd.Function when it works
12595                #       stack stopped capturing a TensorList input
12596                # # inference tensor in TensorList input
12597                # inputs = [s, c]
12598                # with self.assertRaisesRegex(RuntimeError, err_msg):
12599                #     torch.stack(inputs)
12600
12601    def test_mix_inference_and_normal_tensor_inplace_op(self):
12602        for requires_grad in (True, False):
12603            s = torch.ones(1, 2, 3, requires_grad=requires_grad)
12604            a = s.clone()
12605
12606            with torch.inference_mode():
12607                c = torch.ones(1, 2, 3)
12608
12609            self.assertTrue(torch.is_inference(c))
12610            if requires_grad:
12611                err_msg = "Inference tensors cannot be saved for backward"
12612                with self.assertRaisesRegex(RuntimeError, err_msg):
12613                    a.mul_(c)
12614
12615                # inference tensor in TensorList input
12616                err_msg = (
12617                    "out=... arguments don't support automatic differentiation, "
12618                    "but one of the arguments requires grad"
12619                )
12620                with self.assertRaisesRegex(RuntimeError, err_msg):
12621                    torch.mul(s, s, out=c)
12622            else:
12623                a.mul_(c)
12624                err_msg = "Inplace update to inference tensor outside InferenceMode is not allowed"
12625                with self.assertRaisesRegex(RuntimeError, err_msg):
12626                    torch.mul(s, s, out=c)
12627
12628    def test_mix_inference_and_normal_tensor_view_op(self):
12629        for requires_grad in (True, False):
12630            s = torch.ones(1, 2, 3, requires_grad=requires_grad)
12631
12632            with torch.inference_mode():
12633                c = torch.ones(1, 2, 3)
12634
12635            # view_as is a composite op which calls view with only one
12636            # tensor argument. So there isn't a mixed inference and normal
12637            # tensor inputs for view ops
12638            tmp1 = c.view_as(s)
12639            self.assertTrue(torch.is_inference(tmp1))
12640            self.assertFalse(tmp1.requires_grad)
12641
12642            # this is fine since its equivalent as s.view(c.sizes()) which
12643            # isn't a mixed input scenario
12644            tmp2 = s.view_as(c)
12645            self.assertFalse(torch.is_inference(tmp2))
12646            self.assertEqual(tmp2.requires_grad, requires_grad)
12647
12648    def test_inference_mode_handle_direct_view_on_rebase(self):
12649        def run_test(fn):
12650            for requires_grad in (True, False):
12651                s = torch.ones(1, 2, 3, requires_grad=requires_grad)
12652                a = s.clone()
12653
12654                with torch.inference_mode():
12655                    view_out = a.view_as(a)
12656
12657                if requires_grad:
12658                    err_msg = "A view was created in inference mode and is being modified inplace"
12659                    with self.assertRaisesRegex(RuntimeError, err_msg):
12660                        fn(view_out)
12661                else:
12662                    fn(view_out)
12663
12664        run_test(lambda x: x.add_(2))
12665        run_test(lambda x: x.transpose_(0, 1))
12666
12667    def test_inference_mode_handle_indirect_view_on_rebase(self):
12668        def run_test(fn):
12669            for requires_grad in (True, False):
12670                s = torch.ones(1, 2, 3, requires_grad=requires_grad)
12671                a = s.clone()
12672
12673                with torch.inference_mode():
12674                    view_out = a.view(-1)
12675
12676                fn(a)
12677                if requires_grad:
12678                    err_msg = "A view was created in inference mode and its base or another view "
12679                    with self.assertRaisesRegex(RuntimeError, err_msg):
12680                        view_out.grad_fn
12681                else:
12682                    view_out.grad_fn
12683
12684        run_test(lambda x: x.add_(2))
12685        run_test(lambda x: x.transpose_(0, 1))
12686
12687
12688class TestMultithreadAutograd(TestCase):
12689    def _run_py_multithread_fn(
12690        self, fn, args=(), num_threads=10, kwargs=None, pass_idx=False
12691    ):
12692        class PropagatingThread(threading.Thread):
12693            """Helper class to propagate exception from child
12694            thread to main thread on join.
12695
12696            Reference: https://stackoverflow.com/a/31614591/5602957
12697            """
12698
12699            def run(self):
12700                self.exception = None
12701                try:
12702                    self.ret = super().run()
12703                except Exception as e:
12704                    self.exception = e
12705
12706            def join(self, timeout=None):
12707                super().join(timeout)
12708                if self.exception:
12709                    raise self.exception from self.exception
12710                return self.ret
12711
12712        threads = []
12713        for idx in range(num_threads):
12714            p = PropagatingThread(target=fn, args=((idx, *args) if pass_idx else args))
12715            p.start()
12716            threads.append(p)
12717
12718        for p in threads:
12719            p.join()
12720
12721    def test_multithreaded_exception_propagation(self):
12722        # Test whether exception in child thread
12723        # are propagated to main thread.
12724        def fn():
12725            self.assertTrue(False)
12726
12727        with self.assertRaises(AssertionError):
12728            self._run_py_multithread_fn(fn)
12729
12730    def test_simple_backward(self):
12731        # simple multithreaded backward that create threads in the beginning of training
12732        # and everything else is training separately, i.e. inputs, operations, etc.
12733        def train_fn():
12734            x = torch.ones(5, 5, requires_grad=True)
12735            y = (x + 3) * (x + 4) * 0.5
12736            y.sum().backward()
12737            self.assertEqual(x.grad, x + 3.5)
12738
12739        self._run_py_multithread_fn(train_fn)
12740
12741    def test_simple_backward_same_input(self):
12742        # simple multithreaded backward with only shared inputs (i.e. This is common
12743        # for things like Hogwild multithreaded training with multiple CPU threads)
12744        def train_fn_backward(x):
12745            y = (x + 3) * (x + 4) * 0.5
12746            y.sum().backward()
12747
12748        x = torch.ones(5, 5, requires_grad=True)
12749        self._run_py_multithread_fn(train_fn_backward, (x,))
12750        # Since we are calling backward from multiple threads
12751        # and all threads share the same input, when we do backward
12752        # concurrently, different backwards will all accumulate to
12753        # the same .grad for each input, and the gradients should
12754        # be equal to num_threads * gradient
12755        self.assertEqual(x.grad, 10 * (x + 3.5))
12756
12757        def train_fn_grad(x):
12758            y = (x + 3) * (x + 4) * 0.5
12759            grads = torch.autograd.grad(y.sum(), x)
12760            self.assertEqual(len(grads), 1)
12761            self.assertEqual(grads[0], x + 3.5)
12762
12763        # since we use functional grad() api, gradients will not
12764        # be accumulate to the same place and should be the same
12765        self._run_py_multithread_fn(train_fn_grad, (x,))
12766
12767    def test_multi_grad_all_hooks(self):
12768        # Multihooks should behave independently per execution of backward
12769        # Test that the hook fired the number of times we ran backward
12770        # even if those executions occur concurrently on different threads
12771        t1 = torch.rand(2, requires_grad=True)
12772        t2 = torch.rand(2, requires_grad=True)
12773        t3 = torch.rand(2, requires_grad=True)
12774        t4 = torch.rand(2, requires_grad=True)
12775
12776        res = None
12777        count = [0]
12778        hook_lock = threading.Lock()
12779
12780        def hook(grads):
12781            nonlocal res
12782            with hook_lock:
12783                count[0] += 1
12784                grad_is_none = [g is not None for g in grads]
12785                if res is None:
12786                    res = grad_is_none
12787                else:
12788                    self.assertEqual(res, grad_is_none)
12789
12790        torch.autograd.graph.register_multi_grad_hook((t1, t2, t3, t4), hook)
12791
12792        out = (t2 * t3).sum()
12793
12794        def backward_retain_graph(out, t2, t3):
12795            out.backward(inputs=(t2, t3), retain_graph=True)
12796
12797        self._run_py_multithread_fn(backward_retain_graph, (out, t2, t3), num_threads=5)
12798
12799        self.assertEqual(count[0], 5)
12800        self.assertEqual(res, [False, True, True, False])
12801
12802        # Leave one hook partially applied
12803        res = None
12804        count = [0]
12805        err_count = [0]
12806        bw_count = [0]
12807        bw_count_lock = threading.Lock()
12808        err_count_lock = threading.Lock()
12809
12810        class Func(torch.autograd.Function):
12811            @staticmethod
12812            def forward(ctx, x):
12813                return x
12814
12815            @staticmethod
12816            def backward(ctx, gO):
12817                with bw_count_lock:
12818                    bw_count[0] += 1
12819                    if bw_count[0] == 1:
12820                        raise RuntimeError("error message")
12821                    else:
12822                        return gO
12823
12824        out = (Func.apply(t2) * t3).sum()
12825
12826        def backward_retain_graph(out, t2, t3):
12827            try:
12828                out.backward(inputs=(t2, t3), retain_graph=True)
12829            except RuntimeError:
12830                with err_count_lock:
12831                    err_count[0] += 1
12832
12833        self._run_py_multithread_fn(backward_retain_graph, (out, t2, t3), num_threads=5)
12834
12835        self.assertEqual(count[0], 4)
12836        self.assertEqual(err_count[0], 1)
12837        self.assertEqual(res, [False, True, True, False])
12838
12839    def test_multi_grad_any_hooks(self):
12840        # Multihooks should behave independently per execution of backward
12841        # Test that the hook fired the number of times we ran backward
12842        # even if those executions occur concurrently on different threads
12843        t1 = torch.rand(2, requires_grad=True)
12844        t2 = torch.rand(2, requires_grad=True)
12845        t3 = torch.rand(2, requires_grad=True)
12846        t4 = torch.rand(2, requires_grad=True)
12847
12848        res = None
12849        count = [0]
12850        hook_lock = threading.Lock()
12851
12852        def hook(grad):
12853            nonlocal res
12854            with hook_lock:
12855                count[0] += 1
12856                if res is None:
12857                    res = "foo"
12858                else:
12859                    self.assertEqual(res, "foo")
12860
12861        torch.autograd.graph.register_multi_grad_hook(
12862            (t1, t2, t3, t4), hook, mode="any"
12863        )
12864
12865        out = (t2 * t3).sum()
12866
12867        def backward_retain_graph(out, t2, t3):
12868            out.backward(inputs=(t2, t3), retain_graph=True)
12869
12870        self._run_py_multithread_fn(backward_retain_graph, (out, t2, t3), num_threads=5)
12871        self.assertEqual(count[0], 5)
12872        self.assertEqual(res, "foo")
12873
12874        # Raise an error in one thread's backward
12875        res = None
12876        count = [0]
12877        err_count = [0]
12878        bw_count = [0]
12879        bw_count_lock = threading.Lock()
12880        err_count_lock = threading.Lock()
12881
12882        class Func(torch.autograd.Function):
12883            @staticmethod
12884            def forward(ctx, x):
12885                return x
12886
12887            @staticmethod
12888            def backward(ctx, gO):
12889                with bw_count_lock:
12890                    bw_count[0] += 1
12891                    if bw_count[0] == 1:
12892                        raise RuntimeError("error message")
12893                    else:
12894                        return gO
12895
12896        out = (Func.apply(t2) * t3).sum()
12897
12898        def backward_retain_graph(out, t2, t3):
12899            try:
12900                out.backward(inputs=(t2, t3), retain_graph=True)
12901            except RuntimeError:
12902                with err_count_lock:
12903                    err_count[0] += 1
12904
12905        self._run_py_multithread_fn(backward_retain_graph, (out, t2, t3), num_threads=5)
12906
12907        # Expect all 5 threads to increment count since the hook runs before
12908        # the custom backward
12909        self.assertEqual(count[0], 5)
12910        self.assertEqual(err_count[0], 1)
12911        self.assertEqual(res, "foo")
12912
12913    def test_dataparallel_saved_tensors_hooks(self):
12914        def pack(x):
12915            warnings.warn("pack")
12916            return x
12917
12918        _self = self
12919
12920        class Model(torch.nn.Module):
12921            def forward(self, x):
12922                with warnings.catch_warnings(record=True) as w:
12923                    y = x * x
12924                    if torch.cuda.device_count() >= 2:
12925                        # DataParallel is calling the forward in different threads
12926                        # without progating TLS, so hooks should not be called here
12927                        _self.assertEqual(len(w), 0)
12928                    else:
12929                        # DataParallel only uses one thread
12930                        # so hooks should be called here
12931                        _self.assertGreater(len(w), 0)
12932
12933        x = torch.ones(5, 5, requires_grad=True)
12934        model = torch.nn.DataParallel(Model())
12935
12936        with torch.autograd.graph.saved_tensors_hooks(pack, lambda x: x):
12937            model(x)
12938            with warnings.catch_warnings(record=True) as w:
12939                y = x * x
12940                # hooks should be called here
12941                _self.assertGreater(len(w), 0)
12942
12943    def test_python_thread_in_middle(self):
12944        # User might write a network that starts on one CPU thread, then runs its second half
12945        # concurrently with other threads (either via python threading or fork/join calls),
12946        # then calls backward()/grad() on BOTH threads, like a Y pattern from input at the
12947        # bottom to output at the top. This way part of the GraphTask is being shared across
12948        # different threads and we need to ensure user specify retain_graph=True, otherwise
12949        # error out with the correct error message
12950
12951        # Case 1: multiple backward with python threads, retain_graph=False
12952        # should throw error in some threads with no retain_graph.
12953        success_vs_raises = [0, 0]
12954
12955        def train_fn_no_retain_graph(x):
12956            y = x + x**2
12957            try:
12958                y.sum().backward()
12959                success_vs_raises[0] += 1
12960            except RuntimeError as error:
12961                success_vs_raises[1] += 1
12962                self.assertRegex(str(error), "Specify retain_graph=True")
12963
12964        x_no_retain = torch.ones(5, 5, requires_grad=True)
12965        y_no_retain = x_no_retain + x_no_retain**2
12966        self._run_py_multithread_fn(
12967            train_fn_no_retain_graph, (y_no_retain,), num_threads=5
12968        )
12969        # at least one thread will be success in this case, all other threads should raise
12970        # with the error that throw to user to recommend them specify retain_graph=True
12971        self.assertTrue(success_vs_raises[0] >= 1)
12972
12973        # multiple backward with python threads, no error with retain_graph=True
12974        def train_fn_retain_graph(x):
12975            y = x + x**2
12976            y.sum().backward(retain_graph=True)
12977
12978        x_retain = torch.ones(5, 5, requires_grad=True)
12979        y_retain = x_retain + x_retain**2
12980        self._run_py_multithread_fn(train_fn_retain_graph, (y_retain,), num_threads=5)
12981        # result should equal to num_thread * gradients
12982        self.assertEqual(
12983            x_retain.grad,
12984            5 * (4 * x_retain**3 + 6 * (x_retain**2) + 4 * x_retain + 1),
12985        )
12986
12987    def test_fork_join_in_middle(self):
12988        # multiple backward with jit threads (fork/join primitive)
12989        # similar to test_python_thread_in_middle, we test with retain_graph=False/True
12990
12991        # Case 1: multiple grad() calls with jit threads, retain_graph=False
12992        # should throw error in some threads with no retain_graph.
12993        @torch.jit.script
12994        def train_fn_jit_no_retain(middle, orig_x):
12995            y = middle + middle**2
12996            return torch.autograd.grad([y.sum()], [orig_x])
12997
12998        @torch.jit.script
12999        def train_fn_fork_join_calls_no_retain(x):
13000            y_no_retain = (x + 3) * (x + 4) * 0.5
13001
13002            fut = torch.jit._fork(train_fn_jit_no_retain, y_no_retain, x)
13003            grad_hat = train_fn_jit_no_retain(y_no_retain, x)
13004            grad = torch.jit._wait(fut)
13005            return grad, grad_hat
13006
13007        try:
13008            train_fn_fork_join_calls_no_retain(torch.randn(5, 5, requires_grad=True))
13009        except RuntimeError as error:
13010            self.assertRegex(str(error), "Specify retain_graph=True")
13011
13012        # Case 2: no error with retain_graph=True
13013        @torch.jit.script
13014        def train_fn_jit_retain(middle, orig_x):
13015            y = middle + middle**2
13016            return torch.autograd.grad([y.sum()], [orig_x], retain_graph=True)
13017
13018        @torch.jit.script
13019        def train_fn_fork_join_calls_retain(x):
13020            y_retain = (x + 3) * (x + 4) * 0.5
13021            fut1 = torch.jit._fork(train_fn_jit_retain, y_retain, x)
13022            fut2 = torch.jit._fork(train_fn_jit_retain, y_retain, x)
13023            grad = train_fn_jit_retain(y_retain, x)
13024            grad1 = torch.jit._wait(fut1)
13025            grad2 = torch.jit._wait(fut2)
13026            return grad, grad1, grad2
13027
13028        grad, grad1, grad2 = train_fn_fork_join_calls_retain(
13029            torch.randn(5, 5, requires_grad=True)
13030        )
13031        self.assertEqual(grad, grad1)
13032        self.assertEqual(grad, grad2)
13033
13034    def test_preserve_backtrace(self):
13035        class Foo(torch.autograd.Function):
13036            @staticmethod
13037            def forward(ctx, input):
13038                return input
13039
13040            @staticmethod
13041            def backward(ctx, *grad):
13042                raise ValueError("something")
13043
13044        t = torch.rand(10, requires_grad=True)
13045        try:
13046            Foo.apply(t).sum().backward()
13047        except Exception:
13048            import traceback
13049
13050            tb = sys.exc_info()[2]
13051            tb_str = "\n".join(traceback.format_tb(tb))
13052            self.assertTrue('raise ValueError("something")' in tb_str)
13053
13054    # TODO(@anjali411): add an OpInfo based test for torch.cat
13055    # Issue: https://github.com/pytorch/pytorch/issues/51627
13056    #        https://github.com/pytorch/pytorch/issues/75852
13057    def test_cat_stack_r_to_c(self):
13058        inp_c = torch.rand(3, 2, dtype=torch.cdouble, requires_grad=True)
13059        inp_r = torch.randn(3, 2, dtype=torch.double, requires_grad=True)
13060
13061        def fn(x1, x2):
13062            return torch.cat((x1, x2), dim=-1)
13063
13064        def fn2(x1, x2):
13065            return torch.stack((x1, x2), dim=-1)
13066
13067        torch.autograd.gradcheck(fn, [inp_r, inp_c], check_forward_ad=True)
13068        torch.autograd.gradcheck(fn, [inp_c, inp_r], check_forward_ad=True)
13069
13070        torch.autograd.gradcheck(fn2, [inp_r, inp_c], check_forward_ad=True)
13071        torch.autograd.gradcheck(fn2, [inp_c, inp_r], check_forward_ad=True)
13072
13073    def test_set_multithreading_enabled_as_context_manager_and_function(self):
13074        # Test as a context manager
13075        with torch.autograd.set_multithreading_enabled(False):
13076            self.assertFalse(torch.autograd.is_multithreading_enabled())
13077        self.assertTrue(torch.autograd.is_multithreading_enabled())
13078
13079        with torch.autograd.set_multithreading_enabled(True):
13080            self.assertTrue(torch.autograd.is_multithreading_enabled())
13081        self.assertTrue(torch.autograd.is_multithreading_enabled())
13082
13083        with torch.autograd.set_multithreading_enabled(False):
13084            torch.autograd.set_multithreading_enabled(True)
13085            self.assertTrue(torch.autograd.is_multithreading_enabled())
13086        self.assertTrue(torch.autograd.is_multithreading_enabled())
13087
13088        torch.autograd.set_multithreading_enabled(False)
13089        self.assertFalse(torch.autograd.is_multithreading_enabled())
13090
13091        torch.autograd.set_multithreading_enabled(True)
13092        self.assertTrue(torch.autograd.is_multithreading_enabled())
13093
13094    @unittest.skipIf(not TEST_CUDA, "test requires CUDA")
13095    def test_custom_function_propagates_errors_from_device_thread(self):
13096        class MyFunc(Function):
13097            @staticmethod
13098            def forward(ctx, x):
13099                return x
13100
13101            @staticmethod
13102            def backward(ctx, gO):
13103                raise RuntimeError("blah")
13104                return gO
13105
13106        t = torch.tensor([1.0, 2.0], requires_grad=True, device=torch.device("cuda"))
13107        out = MyFunc.apply(t).sum()
13108
13109        with self.assertRaisesRegex(RuntimeError, "blah"):
13110            out.backward()
13111
13112
13113class TestNestedCheckpoint(TestCase):
13114    @staticmethod
13115    def grad(fn):
13116        def wrapper(x):
13117            with torch.enable_grad():
13118                out = fn(x)
13119                (grad_input,) = torch.autograd.grad(out, inputs=(x,), create_graph=True)
13120            return grad_input
13121
13122        return wrapper
13123
13124    @staticmethod
13125    def sum(fn):
13126        def wrapped(x):
13127            return fn(x).sum()
13128
13129        return wrapped
13130
13131    @staticmethod
13132    def checkpoint(fn):
13133        def wrapped(*args, **kwargs):
13134            return torch.utils.checkpoint.checkpoint(
13135                fn, *args, use_reentrant=False, **kwargs
13136            )
13137
13138        return wrapped
13139
13140    def get_tests(self, fn):
13141        grad, c = self.grad, self.checkpoint
13142
13143        tests = (
13144            # function <> tuple of function arbitrarily wrapped in checkpoint in various ways
13145            (fn, (c(fn), c(c(fn)))),
13146            (grad(fn), (grad(c(fn)), grad(c(c(fn))))),
13147            (
13148                grad(grad(fn)),
13149                (grad(c(grad(fn))), c(grad(grad(c(fn)))), grad(c(grad(c(fn))))),
13150            ),
13151            (
13152                grad(grad(grad(fn))),
13153                (grad(c(grad(grad(c(fn))))), grad(c(grad(c(grad(c(fn))))))),
13154            ),
13155        )
13156        return tests
13157
13158    def check_graph_dies(self, fn):
13159        def iter_graph(roots):
13160            if not roots:
13161                return
13162            seen = set()
13163            q = collections.deque()
13164            for node in roots:
13165                if node is not None:
13166                    seen.add(node)
13167                    q.append(node)
13168
13169            while q:
13170                node = q.popleft()
13171                for fn, _idx in node.next_functions:
13172                    if fn in seen or fn is None:
13173                        continue
13174                    seen.add(fn)
13175                    q.append(fn)
13176
13177                yield node
13178
13179        class Handle:
13180            __slot__ = ["node_name"]
13181
13182            def __init__(self, node_name):
13183                self.node_name = node_name
13184
13185        def scope():
13186            a = torch.randn((), requires_grad=True)
13187            out = fn(a)
13188            refs = []
13189            for node in iter_graph([out.grad_fn]):
13190                handle = Handle(node.name())
13191                refs.append(weakref.ref(handle))
13192                node.metadata["blah"] = handle
13193            return refs
13194
13195        refs = scope()
13196        node_names = [ref().node_name for ref in refs if ref() is not None]
13197        if len(node_names) > 0:
13198            print("Nodes still alive:", node_names)
13199
13200        self.assertEqual(len(node_names), 0)
13201
13202    @parametrize("early_stop", [True, False])
13203    def test_nested_checkpoint(self, early_stop):
13204        with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
13205            x = torch.randn((), requires_grad=True)
13206
13207            def f(x):
13208                out = x.sin().exp().sin()
13209                return out
13210
13211            def g(x):
13212                a = x.sin().exp().sin()
13213                b = x.sin().exp().sin()
13214                (ga,) = torch.autograd.grad(a, x)
13215                (gb,) = torch.autograd.grad(b, x)
13216                return x.sin()
13217
13218            for fn in (f, g):
13219                for expected_fn, actual_fns in self.get_tests(fn):
13220                    expected = expected_fn(x)
13221
13222                    for actual_fn in actual_fns:
13223                        actual = actual_fn(x)
13224                        self.assertTrue(torch.allclose(expected, actual))
13225                        self.check_graph_dies(actual_fn)
13226
13227    @parametrize("early_stop", [True, False])
13228    def test_nested_checkpoint_two_children(self, early_stop):
13229        with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
13230            grad, sum, c = self.grad, self.sum, self.checkpoint
13231
13232            def f(x):
13233                return x.sin().exp().sin()
13234
13235            def g(x):
13236                return x.cos().sin().exp()
13237
13238            def hc(x):
13239                return c(g)(c(f)(x))
13240
13241            def h(x):
13242                return g(f(x))
13243
13244            a = torch.randn(3, 3, requires_grad=True)
13245            expected = grad(sum(grad(sum(h))))(a)
13246            actual = grad(sum(grad(sum(c(hc)))))(a)
13247            self.assertTrue(torch.allclose(expected, actual))
13248
13249            actual = grad(sum(c(grad(sum(c(hc))))))(a)
13250            self.assertTrue(torch.allclose(expected, actual))
13251
13252            self.check_graph_dies(grad(c(hc)))
13253            self.check_graph_dies(grad(sum(grad(sum(c(hc))))))
13254            self.check_graph_dies(grad(sum(c(grad(sum(c(hc)))))))
13255
13256    @parametrize("early_stop", [True, False])
13257    def test_nested_checkpoint_non_tensor_inputs_and_outputs(self, early_stop):
13258        def fn(k, a, b, f):
13259            return f(k * a * b.exp()), 1, "abcd"
13260
13261        k = 3
13262        a = torch.tensor(2.0, requires_grad=True)
13263        b = torch.tensor(3.0, requires_grad=True)
13264
13265        def f(x):
13266            return x.sin()
13267
13268        with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
13269            out, _unused1, _unused2 = checkpoint(fn, k, a, b, f, use_reentrant=False)
13270        actual_grads = torch.autograd.grad(out, (a, b))
13271
13272        out, _unused1, _unused2 = fn(k, a, b, f)
13273        expected_grads = torch.autograd.grad(out, (a, b))
13274        for actual, expected in zip(actual_grads, expected_grads):
13275            self.assertTrue(torch.allclose(actual, expected))
13276
13277    @parametrize("early_stop", [True, False])
13278    def test_nested_checkpoint_kwargs(self, early_stop):
13279        def fn(a, blah=None):
13280            out = a.sin().exp()
13281            if blah is not None:
13282                out = out * blah
13283            return out.sin().exp()
13284
13285        a = torch.tensor(2.0, requires_grad=True)
13286        b = torch.tensor(3.0, requires_grad=True)
13287
13288        with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
13289            out = checkpoint(fn, a, blah=b, use_reentrant=False)
13290            actual_grads = torch.autograd.grad(out, (a, b))
13291
13292            out = fn(a, blah=b)
13293            expected_grads = torch.autograd.grad(out, (a, b))
13294            for actual, expected in zip(actual_grads, expected_grads):
13295                self.assertTrue(torch.allclose(actual, expected))
13296
13297    @parametrize("early_stop", [True, False])
13298    def test_nested_checkpoint_same_graph(self, early_stop):
13299        counter = [0]
13300
13301        def hook(*_unused_args):
13302            counter[0] += 1
13303
13304        def fn(a):
13305            return a.sin().cos().sin()
13306
13307        a = torch.tensor(1.0, requires_grad=True)
13308
13309        with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
13310            out = checkpoint(fn, a, use_reentrant=False)
13311        # The hook is registered on the original graph
13312        out.grad_fn.next_functions[0][0].register_hook(hook)
13313        # And backward is performed on the original graph
13314        out.backward()
13315
13316        self.assertEqual(counter[0], 1)
13317
13318    @parametrize("early_stop", [True, False])
13319    def test_nested_checkpoint_reentrant_backwards(self, early_stop):
13320        def fn(a):
13321            x = a.sin().cos()
13322            out = x.sin()
13323            return x, out
13324
13325        def hook(*_unused_args):
13326            # do backward again, but skip over the part of the graph where
13327            # the hook was registered
13328            x.backward(retain_graph=True)
13329
13330        a = torch.tensor(1.0, requires_grad=True)
13331        with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
13332            x, out = checkpoint(fn, a, use_reentrant=False)
13333        out.grad_fn.register_hook(hook)
13334        out.backward(retain_graph=True)
13335
13336    def test_nested_checkpoint_set_early_stop(self):
13337        counter = [0]
13338
13339        def clone(x):
13340            counter[0] += 1
13341            return x.clone()
13342
13343        def fn(x):
13344            # Since clone does not save anything, it is not recomputed iff
13345            # early stop is enabled.
13346            return clone(x.sin().cos())
13347
13348        # Early stopping is enabled by default
13349        a = torch.tensor(1.0, requires_grad=True)
13350        out = checkpoint(fn, a, use_reentrant=False)
13351        out.backward()
13352        self.assertEqual(counter[0], 1)
13353
13354        # Try using the context manager to set early stopping to False.
13355        # Expect early stopping to be disabled for all checkpoints ran under
13356        # the context manager, even though context manager is no longer active
13357        # when backward/recomputation is performed.
13358        counter = [0]
13359        a = torch.tensor(1.0, requires_grad=True)
13360        with torch.utils.checkpoint.set_checkpoint_early_stop(False):
13361            out = checkpoint(fn, a, use_reentrant=False)
13362
13363        out.backward()
13364        self.assertEqual(counter[0], 2)
13365
13366    def test_nested_checkpoint_set_early_stop_no_recompution_needed(self):
13367        # Case 1: We have one tensor saved and its the input
13368
13369        # We have two different counters here because in this case we actually
13370        # do call into x.sin() at the python level during recomputation whether
13371        # or not early stop is enabled. This is because the early stopping
13372        # only happens at the autograd level (preventing us from reaching the
13373        # backend).
13374        python_dispatch_counter = [0]
13375        counter = [0]
13376
13377        class SinCounterMode(TorchDispatchMode):
13378            def __init__(self) -> None:
13379                self.count = 0
13380
13381            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
13382                kwargs = {} if kwargs is None else kwargs
13383                if func is torch.ops.aten.sin.default:
13384                    self.count += 1
13385                return func(*args, **kwargs)
13386
13387        def fn(x):
13388            counter[0] += 1
13389            return x.sin()
13390
13391        # With early stopping (enabled by default)
13392        a = torch.tensor(1.0, requires_grad=True)
13393        with SinCounterMode() as python_dispatch_counter:  # noqa: F811
13394            out = checkpoint(fn, a, use_reentrant=False)
13395            out.backward()
13396        self.assertEqual(counter[0], 2)
13397        self.assertEqual(python_dispatch_counter.count, 1)
13398
13399        # Without early stopping
13400        counter = [0]
13401        a = torch.tensor(1.0, requires_grad=True)
13402        with SinCounterMode() as python_dispatch_counter:
13403            with torch.utils.checkpoint.set_checkpoint_early_stop(False):
13404                out = checkpoint(fn, a, use_reentrant=False)
13405            out.backward()
13406        self.assertEqual(counter[0], 2)
13407        self.assertEqual(python_dispatch_counter.count, 2)
13408
13409        # Case 2: Forward saves no tensors
13410
13411        # Since unpack isn't even called, counter is 1 whether or not early stop
13412        # is enabled!
13413        counter = [0]
13414
13415        def fn2(x):
13416            counter[0] += 1
13417            return x.clone()
13418
13419        # With early stopping (enabled by default)
13420        a = torch.tensor(1.0, requires_grad=True)
13421        out = checkpoint(fn2, a, use_reentrant=False)
13422        out.backward()
13423        self.assertEqual(counter[0], 1)
13424
13425        # Without early stopping
13426        counter = [0]
13427        a = torch.tensor(1.0, requires_grad=True)
13428        with torch.utils.checkpoint.set_checkpoint_early_stop(False):
13429            out = checkpoint(fn2, a, use_reentrant=False)
13430        out.backward()
13431        self.assertEqual(counter[0], 1)
13432
13433
13434class TestSelectiveActivationCheckpoint(TestCase):
13435    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
13436    def test_flops_and_mem(self):
13437        # From https://github.com/pytorch/pytorch/pull/126320
13438        def get_act_mem(f):
13439            out = f()
13440            out.backward()
13441            # Why do one forward and backward?
13442            start_mem = torch.cuda.memory_stats()["requested_bytes.all.current"]
13443            out = f()
13444            cur_mem = torch.cuda.memory_stats()["requested_bytes.all.current"]
13445            act_mem = (cur_mem - start_mem) / (1024 * 1024)
13446            out.backward()
13447            return act_mem
13448
13449        def get_bw_flops(f):
13450            # Normalized so that a 512 square matmul returns 1
13451            f().backward()
13452            out = f()
13453            # NB: FlopCounterMode is pushed onto the mode stack before CachedMode, so
13454            # it will be able to observe whether an op is cached or not.
13455            with FlopCounterMode(display=False) as mode:
13456                out.backward()
13457            return mode.get_total_flops() / (512**3 * 2)
13458
13459        x = torch.randn(512, 512, requires_grad=True, device="cuda")
13460        y = torch.randn(512, 512, requires_grad=True, device="cuda")
13461
13462        def fn(x, y):
13463            return torch.mm(x.cos(), y).sin().sum()
13464
13465        def fn_ac(x, y):
13466            return checkpoint(fn, x, y, use_reentrant=False)
13467
13468        def fn_sac(x, y):
13469            context_fn = functools.partial(
13470                create_selective_checkpoint_contexts,
13471                [torch.ops.aten.mm.default],
13472            )
13473            out = checkpoint(fn, x, y, use_reentrant=False, context_fn=context_fn)
13474            return out
13475
13476        def policy_fn(ctx, op, *args, **kwargs):
13477            if op == torch.ops.aten.mm.default:
13478                return CheckpointPolicy.MUST_SAVE
13479            else:
13480                return CheckpointPolicy.PREFER_RECOMPUTE
13481
13482        def fn_sac2(x, y):
13483            context_fn = functools.partial(
13484                create_selective_checkpoint_contexts,
13485                policy_fn,
13486            )
13487            out = checkpoint(fn, x, y, use_reentrant=False, context_fn=context_fn)
13488            return out
13489
13490        def policy_fn_bool(ctx, op, *args, **kwargs):
13491            return op == torch.ops.aten.mm.default
13492
13493        def fn_sac3(x, y):
13494            context_fn = functools.partial(
13495                create_selective_checkpoint_contexts,
13496                policy_fn_bool,
13497            )
13498            out = checkpoint(fn, x, y, use_reentrant=False, context_fn=context_fn)
13499            return out
13500
13501        act_mem_noac = get_act_mem(lambda: fn(x, y))
13502        bw_flops_noac = get_bw_flops(lambda: fn(x, y))
13503
13504        self.assertEqual(act_mem_noac, 2.0)
13505        self.assertEqual(bw_flops_noac, 2.0)
13506
13507        act_mem_ac = get_act_mem(lambda: fn_ac(x, y))
13508        bw_flops_ac = get_bw_flops(lambda: fn_ac(x, y))
13509
13510        self.assertEqual(act_mem_ac, 0.0)
13511        self.assertEqual(bw_flops_ac, 3.0)
13512
13513        act_mem_sac = get_act_mem(lambda: fn_sac(x, y))
13514        bw_flops_sac = get_bw_flops(lambda: fn_sac(x, y))
13515
13516        self.assertEqual(act_mem_sac, 1.0)
13517        self.assertEqual(bw_flops_sac, 2.0)
13518
13519        act_mem_sac2 = get_act_mem(lambda: fn_sac2(x, y))
13520        bw_flops_sac2 = get_bw_flops(lambda: fn_sac2(x, y))
13521
13522        self.assertEqual(act_mem_sac2, 1.0)
13523        self.assertEqual(bw_flops_sac2, 2.0)
13524
13525        act_mem_sac3 = get_act_mem(lambda: fn_sac3(x, y))
13526        bw_flops_sac3 = get_bw_flops(lambda: fn_sac3(x, y))
13527
13528        self.assertEqual(act_mem_sac3, 1.0)
13529        self.assertEqual(bw_flops_sac3, 2.0)
13530
13531    @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py")
13532    def test_output_already_has_autograd_meta(self):
13533        # View of tensor of non-differentiable dtype still has AutogradMeta
13534        def fn(x, y):
13535            return x.view(-1), y.sin().cos()
13536
13537        x = torch.tensor([1, 2, 3], dtype=torch.int64)
13538        y = torch.randn(3, requires_grad=True)
13539
13540        context_fn = functools.partial(
13541            create_selective_checkpoint_contexts,
13542            [torch.ops.aten.view.default],
13543        )
13544        out = checkpoint(fn, x, y, use_reentrant=False, context_fn=context_fn)
13545        out[1].sum().backward()
13546
13547    @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py")
13548    def test_subclass_dispatching_sizes(self):
13549        # Test that we ignore ops that grab metadata like torch.ops.aten.sym_size.default
13550        # Caching such metadata ops can be problematic when the following are satisfied:
13551        #
13552        # 1. size/strides are dispatched upon
13553        # 2. our policy saves sizes
13554        ta = torch.randn(6, 2)
13555
13556        class CustomSizeDynamicShapesTensor(torch.Tensor):
13557            @staticmethod
13558            def __new__(cls, inner):
13559                return torch.Tensor._make_wrapper_subclass(
13560                    # TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great.
13561                    # Calling the overload that has kwargs causes us to go down the first overload path,
13562                    # which will **always** specialize sizes.
13563                    # We should probably eventually fix this so that the first overload can just handle dynamic shapes.
13564                    cls,
13565                    inner.size(),
13566                    inner.stride(),
13567                    None,
13568                    None,
13569                    inner.dtype,
13570                    inner.layout,
13571                    inner.device,
13572                    False,
13573                    inner.requires_grad,
13574                    "sizes",
13575                )
13576
13577            def __init__(self, inner):
13578                self.inner = inner
13579
13580            @classmethod
13581            def __torch_dispatch__(cls, func, types, args, kwargs):
13582                if kwargs is None:
13583                    kwargs = {}
13584                args_inner = torch.utils._pytree.tree_map_only(
13585                    cls, lambda x: x.inner, args
13586                )
13587                out_inner = func(*args_inner, **kwargs)
13588                return torch.utils._pytree.tree_map_only(
13589                    torch.Tensor, lambda x: cls(x), out_inner
13590                )
13591
13592        def policy_fn(ctx, op, *args, **kwargs):
13593            if op is torch.ops.aten.sym_size.default:
13594                # Silently ignored!
13595                return CheckpointPolicy.MUST_SAVE
13596            else:
13597                return CheckpointPolicy.PREFER_RECOMPUTE
13598
13599        def fn(x):
13600            # We avoid the following case
13601            #
13602            # saved     :[4, 3], [], [], [4, 3], [4, 3], [4, 3], [12]
13603            # forward   :sum   ,sum,mul, mul   , mul   ,view   , view
13604            # recompute :sum   ,sum,mul, view  , view
13605            #
13606            # Views save the shape of their input, so we expect the second
13607            # view to save 12, but because during AC packing during forward
13608            # saves the shapes of the input for metadata checks later,
13609            # we would save the wrong shape during the recompute.
13610            view_out = (x * x.sum()).view(-1).view(4, 3)
13611            self.assertEqual(view_out.grad_fn._saved_self_sym_sizes, [12])
13612            return view_out.exp()
13613
13614        x = torch.randn(4, 3, requires_grad=True)
13615        x_wrapper = CustomSizeDynamicShapesTensor(x)
13616        context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
13617        out = checkpoint(fn, x_wrapper, use_reentrant=False, context_fn=context_fn)
13618        out.sum().backward()
13619
13620    def test_bad_inputs(self):
13621        bad_op_list1 = [2]
13622
13623        with self.assertRaisesRegex(
13624            ValueError, "Expected op in `op_list` to be an OpOverload"
13625        ):
13626            create_selective_checkpoint_contexts(bad_op_list1)
13627
13628        bad_op_list2 = [torch.ops.aten.sin]
13629
13630        with self.assertRaisesRegex(
13631            ValueError, "update the OpOverloadPacket to a specific OpOverload"
13632        ):
13633            create_selective_checkpoint_contexts(bad_op_list2)
13634
13635        with self.assertRaisesRegex(TypeError, "either a function or a list of ops."):
13636            create_selective_checkpoint_contexts(2)
13637
13638    # Dynamo fails for various reasons:
13639    # - some tests using custom op that does not implement Fake
13640    # - dynamo is trying to trace into saved variable hooks unpack hook for some reason
13641    @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py")
13642    def test_policy_with_state(self):
13643        # If I have a stateful callable, state is shared between the original
13644        # forward and the recompute.
13645        counters = []
13646
13647        class Policy:
13648            def __init__(self) -> None:
13649                self.counter = [0]
13650                self.recompute_counter = [0]
13651
13652            def __call__(self, ctx, func, *args, **kwargs):
13653                counter = self.recompute_counter if ctx.is_recompute else self.counter
13654                counter[0] += 1
13655                counters.append(counter[0])
13656                if counter == 1 and func is torch.ops.aten.mm.default:
13657                    return CheckpointPolicy.MUST_SAVE
13658                return CheckpointPolicy.PREFER_RECOMPUTE
13659
13660        def fn(x):
13661            return x.sin().sin().sin()
13662
13663        x = torch.randn(3, requires_grad=True)
13664        context_fn = functools.partial(
13665            create_selective_checkpoint_contexts,
13666            Policy(),
13667            allow_cache_entry_mutation=True,
13668        )
13669        out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
13670        out.sum().backward()
13671        # 1. counter properly reset to 0 for the recompute
13672        # 2. due to early-stop we do not recompute the final op
13673        self.assertEqual(counters, [1, 2, 3, 1, 2])
13674
13675    @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py")
13676    def test_storage_lifetime(self):
13677        from torch.utils._python_dispatch import _get_current_dispatch_mode
13678        from torch.utils.checkpoint import (
13679            _CachedTorchDispatchMode,
13680            _CachingTorchDispatchMode,
13681        )
13682
13683        def policy_fn(ctx, op, *args, **kwargs):
13684            return CheckpointPolicy.MUST_SAVE
13685
13686        ref = None
13687
13688        def fn(x):
13689            nonlocal ref
13690
13691            self.assertIsInstance(
13692                _get_current_dispatch_mode(),
13693                (_CachingTorchDispatchMode, _CachedTorchDispatchMode),
13694            )
13695
13696            out = x.cos().exp()
13697
13698            if isinstance(_get_current_dispatch_mode(), _CachingTorchDispatchMode):
13699                raw_val = (
13700                    _get_current_dispatch_mode()
13701                    .storage[torch.ops.aten.exp.default][0]
13702                    .val
13703                )
13704                # ref should've been detached
13705                # to avoid graph -> the saved variable hooks -> recompute_context -> storage -> graph
13706                self.assertFalse(raw_val.requires_grad)
13707                ref = weakref.ref(raw_val)
13708
13709            # Careful for early-stop
13710            return out.sin()
13711
13712        with disable_gc():
13713            # Case 1: If graph goes away without backward, make sure there's no reference cycle
13714            #         keeping storage alive.
13715            x = torch.randn(3, requires_grad=True)
13716            context_fn = functools.partial(
13717                create_selective_checkpoint_contexts, policy_fn
13718            )
13719            out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
13720            self.assertIsNotNone(ref())
13721            del out
13722            self.assertIsNone(ref())
13723
13724            # Case 2: After backward, even if retain_graph=True, the storage should go away
13725            x = torch.randn(3, requires_grad=True)
13726            context_fn = functools.partial(
13727                create_selective_checkpoint_contexts, policy_fn
13728            )
13729            out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
13730            self.assertIsNotNone(ref())
13731            out.sum().backward(retain_graph=True)
13732            # The dispatch mode's storage should still be alive, but the entries should've
13733            # been cleared.
13734            self.assertIsNone(ref())
13735
13736    @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py")
13737    def test_version_counter(self):
13738        def policy_fn(ctx, op, *args, **kwargs):
13739            if op == torch.ops.aten.sin.default:
13740                return CheckpointPolicy.MUST_SAVE
13741            else:
13742                return CheckpointPolicy.PREFER_RECOMPUTE
13743
13744        def fn(x):
13745            return x.sin().mul_(2).cos().exp()
13746
13747        x = torch.randn(3, requires_grad=True)
13748        context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
13749        out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
13750
13751        # 1) Error because the output of sin is saved and mutated by mul_
13752        with self.assertRaisesRegex(RuntimeError, "has been mutated"):
13753            out.sum().backward()
13754
13755        x = torch.randn(3, requires_grad=True)
13756        context_fn = functools.partial(
13757            create_selective_checkpoint_contexts,
13758            policy_fn,
13759            allow_cache_entry_mutation=True,
13760        )
13761        out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
13762
13763        # 2) No longer should be an error because of allow_cache_entry_mutation
13764        out.sum().backward()
13765
13766    @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py")
13767    def test_function_with_more_than_one_output(self):
13768        # maybe there is a more systematic way:
13769        counter = [0]
13770
13771        def policy_fn(ctx, op, *args, **kwargs):
13772            if op == torch.ops.aten.var_mean.correction:
13773                counter[0] += 1
13774                return CheckpointPolicy.MUST_SAVE
13775            else:
13776                return CheckpointPolicy.PREFER_RECOMPUTE
13777
13778        # var_mean has two outputs
13779        def fn(x):
13780            a, b = torch.var_mean(x)
13781            return a * b
13782
13783        x = torch.randn(3, requires_grad=True)
13784        context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
13785        out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
13786        x_grad = torch.autograd.grad(out.sum(), (x,))
13787        x_grad_ref = torch.autograd.grad(fn(x).sum(), (x,))
13788        self.assertEqual(x_grad, x_grad_ref)
13789        self.assertEqual(counter[0], 2)
13790
13791    @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py")
13792    def test_function_with_non_tensor_output(self):
13793        # When SAC is enabled, the op is not computed a second time
13794        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
13795            counter = [0]
13796
13797            @torch.library.custom_op("mylib::sin_with_extra", mutates_args=())
13798            def sin_with_extra(x: torch.Tensor) -> Tuple[torch.Tensor, int]:
13799                counter[0] += 1
13800                return x.sin(), 2
13801
13802            def setup_context(ctx, inputs, output) -> torch.Tensor:
13803                (x,) = inputs
13804                ctx.save_for_backward(x)
13805
13806            def backward(ctx, grad, _unused):
13807                (x,) = ctx.saved_tensors
13808                return grad * x.cos()
13809
13810            torch.library.register_autograd(
13811                "mylib::sin_with_extra", backward, setup_context=setup_context
13812            )
13813
13814            x = torch.randn(3, requires_grad=True)
13815
13816            def fn(x):
13817                return (torch.ops.mylib.sin_with_extra(x)[0] * x.sin().exp()).sin()
13818
13819            ops_list = [torch.ops.mylib.sin_with_extra.default]
13820
13821            x = torch.randn(3, requires_grad=True)
13822            context_fn = functools.partial(
13823                create_selective_checkpoint_contexts, ops_list
13824            )
13825            out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
13826            x_grad = torch.autograd.grad(out.sum(), (x,))
13827            self.assertEqual(counter[0], 1)
13828            x_grad_ref = torch.autograd.grad(fn(x).sum(), (x,))
13829            self.assertEqual(x_grad, x_grad_ref)
13830
13831    @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py")
13832    def test_can_only_trigger_recompute_once(self):
13833        # We don't support this to avoid adding extra complexity for now.
13834        # If there's a need, we could probably do some kind of use_count tracking.
13835        # TODO: have a nice error message here.
13836        def policy_fn(ctx, op, *args, **kwargs):
13837            if op == torch.ops.aten.sin.default:
13838                return CheckpointPolicy.MUST_SAVE
13839            else:
13840                return CheckpointPolicy.PREFER_RECOMPUTE
13841
13842        def fn(x):
13843            return x.sin().cos().exp()
13844
13845        x = torch.randn(3, requires_grad=True)
13846        context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
13847        out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
13848        out.sum().backward(retain_graph=True)
13849
13850        with self.assertRaisesRegex(RuntimeError, "Trying to backward an extra time"):
13851            out.sum().backward(retain_graph=True)
13852
13853
13854class TestAutogradMultipleDispatch(TestCase):
13855    def test_autograd_multiple_dispatch_registrations(self, device):
13856        t = torch.randn(3, 3, device=device, requires_grad=True)
13857        # using _test_autograd_multiple_dispatch.fullcoverage which has
13858        # registrations in derivatives.yaml for Default, AutogradCUDA and NestedTensorAutograd
13859        out = torch._test_autograd_multiple_dispatch(t)
13860        grad = torch.randn(3, 3, device=device)
13861        out.backward(grad)
13862
13863        if "cuda" not in device:
13864            # bogus default gradient registered for Autograd is grad + 1
13865            self.assertEqual(t.grad, grad + 1)
13866        else:
13867            # bogus gradient registered for AutogradCUDA is grad * 2
13868            self.assertEqual(t.grad, grad * 2)
13869
13870        # test registered AutogradNestedTensor formula
13871        a = (
13872            torch.arange(6, dtype=torch.float, device=device)
13873            .reshape(2, 3)
13874            .requires_grad_(True)
13875        )
13876        b = (
13877            torch.arange(8, dtype=torch.float, device=device)
13878            .reshape(2, 4)
13879            .requires_grad_(True)
13880        )
13881        nt = torch.nested.as_nested_tensor([a, b], dtype=torch.float, device=device)
13882
13883        nt_out = torch._test_autograd_multiple_dispatch(nt)
13884        c = torch.randn(2, 3, device=device)
13885        d = torch.randn(2, 4, device=device)
13886        nt_grad = torch.nested.nested_tensor([c, d], dtype=torch.float, device=device)
13887        nt_out.backward(nt_grad)
13888
13889        # bogus gradient for AutogradNestedTensor is grad * grad
13890        self.assertEqual(a.grad, c * c)
13891        self.assertEqual(b.grad, d * d)
13892
13893    def test_autograd_composite_implicit_and_dispatch_registration(self, device):
13894        t = torch.randn(3, 3, device=device, requires_grad=True)
13895        # using _test_autograd_multiple_dispatch.ntonly
13896        # which has registrations in derivatives.yaml for NestedTensorAutograd and otherwise is CompositeImplicit
13897        out = torch._test_autograd_multiple_dispatch(t, True)
13898        grad = torch.randn(3, 3, device=device)
13899        out.backward(grad)
13900
13901        # t.grad is just out.grad by composite op since _test_autograd_multiple_dispatch is just a clone
13902        self.assertEqual(t.grad, grad)
13903
13904        # test registered AutogradNestedTensor formula
13905        a = (
13906            torch.arange(6, dtype=torch.float, device=device)
13907            .reshape(2, 3)
13908            .requires_grad_(True)
13909        )
13910        b = (
13911            torch.arange(8, dtype=torch.float, device=device)
13912            .reshape(2, 4)
13913            .requires_grad_(True)
13914        )
13915        nt = torch.nested.as_nested_tensor([a, b], dtype=torch.float, device=device)
13916
13917        nt_out = torch._test_autograd_multiple_dispatch(nt, True)
13918        c = torch.randn(2, 3, device=device)
13919        d = torch.randn(2, 4, device=device)
13920        nt_grad = torch.nested.nested_tensor([c, d], dtype=torch.float, device=device)
13921        nt_out.backward(nt_grad)
13922
13923        # bogus gradient for AutogradNestedTensor is grad * grad + grad
13924        self.assertEqual(a.grad, c * c + c)
13925        self.assertEqual(b.grad, d * d + d)
13926
13927    def test_foward_mode_AD(self, device):
13928        # check that forward mode AD is only registered for the Default
13929        # dispatch for _test_autograd_multiple_dispatch.fullcoverage and not AutogradCUDA
13930
13931        primal = torch.randn(3, device=device)
13932        tangent = torch.randn(3, device=device)
13933
13934        with fwAD.dual_level():
13935            dual_input = fwAD.make_dual(primal, tangent)
13936
13937            err_msg = r"Trying to use forward AD with .* that does not support it"
13938            hint_msg = "Running forward AD for an OP that does not implement it should raise a NotImplementedError"
13939
13940            if "cuda" in device:
13941                with self.assertRaisesRegex(NotImplementedError, err_msg, msg=hint_msg):
13942                    torch._test_autograd_multiple_dispatch(dual_input)
13943            else:
13944                torch._test_autograd_multiple_dispatch(dual_input)
13945
13946    def test_view_copy(self, device):
13947        # tests that view_copy derivative formulas are also generated per dispatch key
13948        # from their respective view ops in derivatives.yaml
13949        t = torch.randn(2, 2, device=device, requires_grad=True)
13950        t_ref = t.clone().detach().requires_grad_()
13951        # _test_autograd_multiple_dispatch_view does a .view(-1) on the input
13952        t_view = torch._test_autograd_multiple_dispatch_view(t_ref)
13953        t_view_copy = torch._test_autograd_multiple_dispatch_view_copy(t)
13954
13955        grad = torch.randn(4, device=device)
13956        t_view_copy.backward(grad)
13957        t_view.backward(grad.clone())
13958
13959        # forward and backward give the same shape + result
13960        self.assertEqual(t_view_copy, t_view)
13961        self.assertEqual(t.grad, t_ref.grad)
13962        # backward results are per-dispatch-key in derivatives.yaml
13963        if "cuda" in device:
13964            # gradient registered to AutogradCUDA is grad.reshape_as(self) + 1
13965            self.assertEqual(t.grad, grad.reshape_as(t) + 1)
13966        else:
13967            # Default gradient registered is grad.reshape_as(self)
13968            self.assertEqual(t.grad, grad.reshape_as(t))
13969
13970    @onlyCPU
13971    def test_per_dispatch_key_input_saving(self, device):
13972        # Tests that sum.dim_IntList's input is not saved for regular tensors but is saved for nested tensors
13973        def foo(x):
13974            # Don't modify the input inplace
13975            x = x.clone()
13976            res = x.sum(-1, keepdim=True)
13977            x.add_(x)
13978            return res
13979
13980        inp = torch.rand(2, device=device, requires_grad=True)
13981        # sum's input is not saved for regular Tensors
13982        foo(inp).backward()
13983
13984        # sum's input is saved for Nested Tensors
13985        nt = torch.nested.nested_tensor(
13986            [torch.rand(2), torch.rand(2)], device=device, requires_grad=True
13987        )
13988        with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
13989            foo(nt).backward(
13990                torch.nested.nested_tensor(
13991                    [torch.rand(1), torch.rand(1)], device=device
13992                )
13993            )
13994
13995    @onlyCUDA
13996    def test_backward_single_threaded(self):
13997        threads_eq = None
13998
13999        class TestFn(Function):
14000            @staticmethod
14001            def forward(ctx, x, self):
14002                ctx.self = self
14003                ctx.tid = threading.get_ident()
14004                return x.clone()
14005
14006            @staticmethod
14007            def backward(ctx, gO):
14008                nonlocal threads_eq
14009                threads_eq = ctx.tid == threading.get_ident()
14010                return gO, None
14011
14012        inp = torch.rand(10, device="cuda", requires_grad=True)
14013
14014        with torch.autograd.set_multithreading_enabled(False):
14015            TestFn.apply(inp, None).sum().backward()
14016        self.assertTrue(threads_eq)
14017
14018        TestFn.apply(inp, None).sum().backward()
14019        self.assertFalse(threads_eq)
14020
14021    @onlyCUDA
14022    def test_backward_tls_stash(self):
14023        local = threading.local()
14024        local.my_obj = {}
14025        local.my_obj[10] = 10
14026        test_self = self
14027        torch._C._stash_obj_in_tls("my_obj", local.my_obj)
14028
14029        class TestFn(Function):
14030            @staticmethod
14031            def forward(ctx, x, self):
14032                return x.clone()
14033
14034            @staticmethod
14035            def backward(ctx, gO):
14036                test_self.assertTrue(torch._C._is_key_in_tls("my_obj"))
14037                test_self.assertTrue(torch._C._get_obj_in_tls("my_obj")[10] == 10)
14038                torch._C._get_obj_in_tls("my_obj")[10] = 5
14039                return gO, None
14040
14041        inp = torch.rand(10, device="cuda", requires_grad=True)
14042
14043        TestFn.apply(inp, None).sum().backward()
14044        self.assertEqual(local.my_obj[10], 5)
14045
14046    def test_is_retain_graph(self):
14047        retain_graph_set = False
14048
14049        class TestFn(Function):
14050            @staticmethod
14051            def forward(ctx, x):
14052                return x.clone()
14053
14054            @staticmethod
14055            def backward(ctx, gO):
14056                nonlocal retain_graph_set
14057                retain_graph_set = (
14058                    torch._C._autograd._get_current_graph_task_keep_graph()
14059                )
14060                return gO, None
14061
14062        inp = torch.rand(10, requires_grad=True)
14063
14064        out = TestFn.apply(inp)
14065        self.assertFalse(retain_graph_set)
14066        out.sum().backward(retain_graph=True)
14067        self.assertTrue(retain_graph_set)
14068        out.sum().backward(retain_graph=False)
14069        self.assertFalse(retain_graph_set)
14070
14071    def test_set_sequence_nr(self):
14072        x = torch.randn((10,), dtype=torch.float32, requires_grad=True)
14073        y = torch.randn((10,), dtype=torch.float32, requires_grad=True)
14074        z = torch.randn((10,), dtype=torch.float32, requires_grad=True)
14075
14076        a = x + y
14077        b = y + z
14078        c = a + b
14079
14080        self.assertIsNotNone(a.grad_fn)
14081        self.assertIsNotNone(b.grad_fn)
14082        self.assertIsNotNone(c.grad_fn)
14083
14084        a.grad_fn._set_sequence_nr(100)
14085        b.grad_fn._set_sequence_nr(99)
14086        c.grad_fn._set_sequence_nr(98)
14087
14088        self.assertEqual(a.grad_fn._sequence_nr(), 100)
14089        self.assertEqual(b.grad_fn._sequence_nr(), 99)
14090        self.assertEqual(c.grad_fn._sequence_nr(), 98)
14091
14092        def log_grad_order(grad: torch.Tensor, name: str, order):
14093            order.append(name)
14094            return grad
14095
14096        order = []
14097        a.register_hook(partial(log_grad_order, name="a", order=order))
14098        b.register_hook(partial(log_grad_order, name="b", order=order))
14099        c.register_hook(partial(log_grad_order, name="c", order=order))
14100
14101        c.sum().backward()
14102
14103        # Expect to see that even though c has the smallest sequence number, it is still the first node to get run in autograd.
14104        # Also check that although a comes first during the forward, after giving it priority with sequence_nr,
14105        # its autograd node is run before that of b.
14106        self.assertEqual(order, ["c", "a", "b"])
14107
14108        self.assertEqual(x.grad, torch.ones_like(x))
14109        self.assertEqual(y.grad, 2 * torch.ones_like(x))
14110        self.assertEqual(z.grad, torch.ones_like(x))
14111
14112
14113# Import test cases from below autograd/ here. These are found
14114# implicitly by the loader, so Flake8 thinks they are unused, hence
14115# the suppressions.
14116
14117from autograd.test_complex import TestAutogradComplex  # noqa: F401
14118from autograd.test_functional import TestAutogradFunctional  # noqa: F401
14119from autograd.test_logging import TestAutogradLogging  # noqa: F401
14120
14121
14122# e.g., TestAutogradDeviceTypeCPU and TestAutogradDeviceTypeCUDA
14123instantiate_device_type_tests(TestAutogradDeviceType, globals(), except_for=None)
14124
14125instantiate_device_type_tests(
14126    TestAutogradMultipleDispatch, globals(), only_for=("cpu", "cuda")
14127)
14128
14129instantiate_parametrized_tests(TestAutograd)
14130instantiate_parametrized_tests(TestNestedCheckpoint)
14131
14132if __name__ == "__main__":
14133    run_tests()
14134